Coverage for an_website/backdoor/backdoor.py: 83.140%
172 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""The backdoor API of the website."""
16from __future__ import annotations
18import io
19import logging
20import pickle # nosec: B403
21import pickletools # nosec: B403
22import pydoc
23import traceback
24from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT, PyCF_ONLY_AST, PyCF_TYPE_COMMENTS
25from asyncio import Future
26from base64 import b85decode, b85encode
27from collections.abc import MutableMapping
28from inspect import CO_COROUTINE # pylint: disable=no-name-in-module
29from random import Random
30from types import TracebackType
31from typing import Any, ClassVar, Final, cast
33import dill # type: ignore[import-untyped] # nosec: B403
34import jsonpickle # type: ignore[import-untyped]
35import regex
36from tornado.web import HTTPError
38from .. import EVENT_REDIS, EVENT_SHUTDOWN, pytest_is_running
39from ..utils.decorators import requires
40from ..utils.request_handler import APIRequestHandler
41from ..utils.utils import Permission
# Logger for this module, named after the module's import path.
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)

# Matches a run of commas and/or whitespace; used to split the value of
# the X-Future-Feature request header into individual feature names.
SEPARATOR: Final = regex.compile(r"[,\s]+")
class PrintWrapper:  # pylint: disable=too-few-public-methods
    """Callable stand-in for print() that writes to a fixed stream by default."""

    def __init__(self, output: io.TextIOBase) -> None:
        """Remember the stream that receives output when no file is given."""
        self._output: io.TextIOBase = output

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Forward to print(), defaulting the ``file`` keyword to the stream."""
        # An explicitly passed ``file`` (even ``file=None``) is left untouched.
        if "file" not in kwargs:
            kwargs["file"] = self._output
        print(*args, **kwargs)
class Backdoor(APIRequestHandler):
    """The request handler for the backdoor API.

    The body of a POST request is compiled and evaluated as Python source
    inside a session namespace. The response carries the captured output
    and the result, serialized according to the negotiated content type.

    Request headers read by this handler:

    - ``X-Backdoor-Session``: id of the session namespace to (re)use.
    - ``X-Future-Feature``: comma/whitespace separated ``__future__``
      feature names to enable when compiling.
    - ``X-Pickle-Protocol``: pickle protocol to use for the response.

    SECURITY: this handler runs arbitrary code and unpickles stored data
    by design; access is gated by ``Permission.BACKDOOR`` on ``post()``.
    """

    POSSIBLE_CONTENT_TYPES: ClassVar[tuple[str, ...]] = (
        "application/vnd.uqfoundation.dill",
        "application/vnd.python.pickle",
        "application/json",
        "text/plain",
    )

    ALLOWED_METHODS: ClassVar[tuple[str, ...]] = ("POST",)

    # In-memory session namespaces keyed by the X-Backdoor-Session header
    # value. This is a class attribute, so it is shared by all instances.
    sessions: ClassVar[dict[str, dict[str, Any]]] = {}

    @staticmethod
    def _drop_request_objects(session: MutableMapping[str, Any]) -> None:
        """Remove the entries that update_session() adds to a session."""
        for key in ("self", "app", "settings"):
            session.pop(key, None)

    async def backup_session(self) -> bool:
        """Backup a session using Redis and return whether it succeeded.

        Returns False without touching Redis when EVENT_REDIS is not set
        or when the request's session id has no in-memory session.
        Values that cannot be dumped with dill are left out of the backup.
        """
        session_id = self.request.headers.get("X-Backdoor-Session")
        if not (EVENT_REDIS.is_set() and session_id in self.sessions):
            return False
        # Work on a shallow copy so the live session keeps its objects.
        session = self.sessions[session_id].copy()
        self._drop_request_objects(session)
        # Iterate over a snapshot because entries may be deleted below.
        for key, value in tuple(session.items()):
            try:
                session[key] = pickletools.optimize(
                    dill.dumps(value, max(dill.DEFAULT_PROTOCOL, 5))
                )
            except BaseException:  # pylint: disable=broad-except  # noqa: B036
                # Best effort: skip whatever cannot be serialized.
                del session[key]
        return bool(
            await self.redis.setex(
                f"{self.redis_prefix}:backdoor-session:{session_id}",
                60 * 60 * 24 * 7,  # time to live in seconds (1 week)
                b85encode(pickletools.optimize(dill.dumps(session))),
            )
        )

    def ensure_serializable(self, obj: Any) -> Any:
        """Return obj if it can be serialized, otherwise its safe repr."""
        if self.serialize(obj) is None:
            return self.safe_repr(obj)
        return obj

    def finish_serialized_dict(self, **kwargs: Any) -> Future[None]:
        """Finish the request with the keyword arguments serialized as a dict."""
        return self.finish(self.serialize(kwargs))

    def get_flags(self, flags: int) -> int:
        """Return flags combined with the requested ``__future__`` features.

        Feature names come from the X-Future-Feature header; names that
        are not listed in ``__future__.all_feature_names`` are ignored.
        """
        import __future__  # pylint: disable=import-outside-toplevel

        for feature in SEPARATOR.split(
            self.request.headers.get("X-Future-Feature", "")
        ):
            if feature in __future__.all_feature_names:
                flags |= getattr(__future__, feature).compiler_flag

        return flags

    def get_protocol_version(self) -> int:
        """Get the protocol version for the response.

        The X-Pickle-Protocol header is parsed with ``int(..., base=0)``
        and capped at ``pickle.HIGHEST_PROTOCOL``. A missing or invalid
        header yields protocol 5.
        """
        try:
            return min(
                int(
                    self.request.headers.get("X-Pickle-Protocol"), base=0  # type: ignore[arg-type]  # noqa: B950
                ),
                pickle.HIGHEST_PROTOCOL,
            )
        except (TypeError, ValueError):
            # TypeError: header missing (None); ValueError: not a number.
            return 5

    async def load_session(self) -> dict[str, Any]:
        """Load the backup of a session or create a new one.

        Without a session id a fresh namespace is created and not stored.
        With a session id the in-memory session is used if present;
        otherwise the Redis backup is restored (if EVENT_REDIS is set and
        a backup exists) or a new namespace is created, and the result is
        stored in ``sessions``. The returned session always contains the
        request-specific entries added by update_session().
        """
        if not (session_id := self.request.headers.get("X-Backdoor-Session")):
            session: dict[str, Any] = {
                "__builtins__": __builtins__,
                "__name__": "this",
            }
        elif session_id in self.sessions:
            session = self.sessions[session_id]
        else:
            session_pickle = (
                await self.redis.get(
                    f"{self.redis_prefix}:backdoor-session:{session_id}"
                )
                if EVENT_REDIS.is_set()
                else None
            )
            if session_pickle:
                # SECURITY: unpickling executes code; the data is whatever
                # is stored under this key in Redis.
                session = dill.loads(b85decode(session_pickle))  # nosec: B301
                # Only existing keys are reassigned, so mutating the dict
                # while iterating over it is safe here.
                for key, value in session.items():
                    try:
                        session[key] = dill.loads(value)  # nosec: B301
                    except BaseException:  # pylint: disable=broad-except  # noqa: B036, B950  # fmt: skip
                        # Log the session id, not the whole session dict.
                        # NOTE(review): "BRAILLE" is not a stdlib codec;
                        # presumably registered elsewhere in the project.
                        LOGGER.exception(
                            "Error while loading %r in session %r. Data: %r",
                            key,
                            session_id,
                            value.decode("BRAILLE"),
                        )
                        if self.apm_client:
                            self.apm_client.capture_exception()  # type: ignore[no-untyped-call]
            else:
                session = {
                    "__builtins__": __builtins__,
                    "__name__": "this",
                }
                if pytest_is_running():
                    session["session_id"] = session_id
            self.sessions[session_id] = session
        self.update_session(session)
        return session

    @requires(Permission.BACKDOOR, allow_cookie_auth=False)
    async def post(self, mode: str) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches
        # pylint: disable=too-many-statements
        """Handle POST requests to the backdoor API.

        The request body is compiled with the given compile() mode and
        evaluated in the session namespace. Awaitable top-level code is
        awaited. For ``text/plain`` the response is the traceback, the
        captured output (mode ``exec``) or the repr of the result; for
        the other content types it is a serialized dict with the keys
        ``success``, ``output`` and ``result``.
        """
        source, output = self.request.body, io.StringIO()
        exception: None | BaseException = None
        output_str: None | str
        result: Any
        try:
            # Seeded, so the two optimize levels drawn below are the same
            # on every request.
            random = Random(335573788461)
            # First pass: source -> AST (PyCF_ONLY_AST).
            parsed = compile(
                source,
                "",
                mode,
                self.get_flags(PyCF_ONLY_AST | PyCF_TYPE_COMMENTS),
                cast(bool, 0x5F3759DF),  # truthy, i.e. dont_inherit=True
                random.randrange(3),
                _feature_version=12,
            )
            # Second pass: AST -> code object, allowing top-level await.
            code = compile(
                parsed,
                "",
                mode,
                self.get_flags(PyCF_ALLOW_TOP_LEVEL_AWAIT),
                cast(bool, 0x5F3759DF),
                random.randrange(3),
                _feature_version=12,
            )
        except SyntaxError as exc:
            exception = exc
            result = exc
        else:
            session = await self.load_session()
            # Redirect print() and help() to the captured output unless
            # the session replaced them with something of its own.
            if "print" not in session or isinstance(
                session["print"], PrintWrapper
            ):
                session["print"] = PrintWrapper(output)
            if "help" not in session or isinstance(
                session["help"], pydoc.Helper
            ):
                session["help"] = pydoc.Helper(io.StringIO(), output)
            try:
                # The inner try turns KeyboardInterrupt into SystemExit so
                # that the outer SystemExit handler can deal with both.
                try:
                    # SECURITY: evaluates code taken from the request body.
                    result = eval(  # pylint: disable=eval-used  # nosec: B307
                        code, session
                    )
                    if code.co_flags & CO_COROUTINE:
                        result = await result
                except KeyboardInterrupt:
                    EVENT_SHUTDOWN.set()
                    raise SystemExit("Shutdown initiated.") from None
            except SystemExit as exc:
                if self.content_type == "text/plain":
                    return await self.finish(
                        traceback.format_exception_only(exc)[0]
                    )
                # Back up before responding; the finally clause below
                # still runs when returning from here.
                self._drop_request_objects(session)
                await self.backup_session()
                exc.args = [self.ensure_serializable(arg) for arg in exc.args]  # type: ignore[assignment]  # noqa: B950
                output_str = output.getvalue() if not output.closed else None
                output.close()
                return await self.finish_serialized_dict(
                    success=..., output=output_str, result=exc
                )
            except BaseException as exc:  # pylint: disable=broad-except  # noqa: B036, B950  # fmt: skip
                exception = exc  # pylint: disable=redefined-variable-type
                result = exc
            else:
                # Report the real builtins instead of the wrappers.
                if result is session.get("print") and isinstance(
                    result, PrintWrapper
                ):
                    result = print
                elif result is session.get("help") and isinstance(
                    result, pydoc.Helper
                ):
                    result = help
                if result is not None:
                    session["_"] = result
            finally:
                self._drop_request_objects(session)
                await self.backup_session()
        output_str = output.getvalue() if not output.closed else None
        output.close()
        exception_text = (
            "".join(traceback.format_exception(exception)).strip()
            if exception is not None
            else None
        )
        if self.content_type == "text/plain":
            if mode == "exec":
                return await self.finish(exception_text or output_str)
            return await self.finish(exception_text or self.safe_repr(result))
        serialized_result = self.serialize(result)
        result_tuple: tuple[None | str, None | bytes] = (
            exception_text or self.safe_repr(result),
            serialized_result,
        )
        return await self.finish_serialized_dict(
            success=exception is None,
            output=output_str,
            result=(
                None if exception is None and result is None else result_tuple
            ),
        )

    def safe_repr(self, obj: Any) -> str:  # pylint: disable=no-self-use
        """Return repr(obj), falling back to object.__repr__ if that raises."""
        try:
            return repr(obj)
        except BaseException:  # pylint: disable=broad-except  # noqa: B036
            return object.__repr__(obj)

    def serialize(self, data: Any, protocol: None | int = None) -> None | bytes:
        """Serialize the data and return it, or None if that fails.

        ``application/json`` uses jsonpickle, the dill content type uses
        dill and every other content type uses pickle. When protocol is
        None, get_protocol_version() decides; an explicit protocol,
        including 0, is used as given.
        """
        try:
            if self.content_type == "application/json":
                # jsonpickle.encode() returns text; the cast only
                # satisfies the declared return type.
                return cast(bytes, jsonpickle.encode(data))
            if protocol is None:
                protocol = self.get_protocol_version()
            if self.content_type == "application/vnd.uqfoundation.dill":
                return cast(bytes, dill.dumps(data, protocol))
            return pickle.dumps(data, protocol)
        except BaseException:  # pylint: disable=broad-except  # noqa: B036
            return None

    def update_session(self, session: MutableMapping[str, Any]) -> None:
        """Add request-specific stuff to the session."""
        session.update(self=self, app=self.application, settings=self.settings)

    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Respond with error message.

        Content types other than pickle and dill are delegated to the
        base class. For pickle and dill the response is the serialized
        error message when the exception is not an HTTPError, and the
        serialized ``(status_code, reason)`` tuple otherwise.
        """
        if self.content_type not in {
            "application/vnd.python.pickle",
            "application/vnd.uqfoundation.dill",
        }:
            super().write_error(status_code, **kwargs)
            return
        if "exc_info" in kwargs:
            exc_info: tuple[
                type[BaseException], BaseException, TracebackType
            ] = kwargs["exc_info"]
            if not issubclass(exc_info[0], HTTPError):
                # pylint: disable=line-too-long
                self.finish(self.serialize(self.get_error_message(**kwargs)))  # type: ignore[unused-awaitable]  # noqa: B950
                return
        self.finish(self.serialize((status_code, self._reason)))  # type: ignore[unused-awaitable]  # noqa: B950
320 return
321 self.finish(self.serialize((status_code, self._reason))) # type: ignore[unused-awaitable] # noqa: B950