Coverage for an_website / backdoor / backdoor.py: 83.626%
171 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-15 14:36 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""The backdoor API of the website."""
17import io
18import logging
19import pickle # nosec: B403
20import pickletools # nosec: B403
21import pydoc
22import traceback
23from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT, PyCF_ONLY_AST, PyCF_TYPE_COMMENTS
24from asyncio import Future
25from base64 import b85decode, b85encode
26from collections.abc import MutableMapping
27from inspect import CO_COROUTINE # pylint: disable=no-name-in-module
28from random import Random
29from types import TracebackType
30from typing import Any, ClassVar, Final, cast
32import dill # type: ignore[import-untyped] # nosec: B403
33import jsonpickle # type: ignore[import-untyped]
34import regex
35from tornado.web import HTTPError
37from .. import EVENT_REDIS, EVENT_SHUTDOWN, pytest_is_running
38from ..utils.decorators import requires
39from ..utils.request_handler import APIRequestHandler
40from ..utils.utils import Permission
# Logger of this module.
LOGGER: Final = logging.getLogger(name=__name__)
# Splits the X-Future-Feature request header into feature names; names may
# be separated by commas, whitespace or any mix of both.
SEPARATOR: Final = regex.compile("[,\\s]+")
class PrintWrapper:  # pylint: disable=too-few-public-methods
    """Callable replacement for print() bound to a fixed text stream.

    Calling an instance behaves like print(), except that output goes to
    the stream given at construction time unless the caller passes an
    explicit ``file=`` argument, which always wins.
    """

    def __init__(self, output: io.TextIOBase) -> None:  # noqa: D107
        self._output: io.TextIOBase = output

    def __call__(self, *args: Any, **kwargs: Any) -> None:  # noqa: D102
        # Only fill in the target stream when the caller did not choose one.
        if "file" not in kwargs:
            kwargs["file"] = self._output
        print(*args, **kwargs)
class Backdoor(APIRequestHandler):
    """The request handler for the backdoor API.

    Python source sent as the body of a POST request is compiled and
    evaluated in a session namespace. The response contains the captured
    ``print``/``help`` output and the result (or the raised exception),
    serialized according to the negotiated content type.
    """

    POSSIBLE_CONTENT_TYPES: ClassVar[tuple[str, ...]] = (
        "application/vnd.uqfoundation.dill",
        "application/vnd.python.pickle",
        "application/json",
        "text/plain",
    )

    ALLOWED_METHODS: ClassVar[tuple[str, ...]] = ("POST",)

    # In-process session namespaces, keyed by the X-Backdoor-Session header.
    sessions: ClassVar[dict[str, dict[str, Any]]] = {}

    # The keys update_session() injects; they reference live request objects
    # and are removed before a session is backed up.
    _REQUEST_SPECIFIC_KEYS: ClassVar[tuple[str, ...]] = (
        "self",
        "app",
        "settings",
    )

    @classmethod
    def _remove_request_specific(
        cls, session: MutableMapping[str, Any]
    ) -> None:
        """Remove the entries added by update_session() (if present)."""
        for key in cls._REQUEST_SPECIFIC_KEYS:
            session.pop(key, None)

    async def backup_session(self) -> bool:
        """Backup a session using Redis and return whether it succeeded.

        The session is looked up in ``sessions`` by the X-Backdoor-Session
        request header. A copy of it without the request-specific entries
        is stored with a time to live of one week. Every value is
        dill-pickled on its own; values that cannot be pickled are left out
        of the backup (best effort) instead of failing the whole backup.

        Returns:
            ``False`` if Redis is not ready or the session is unknown,
            otherwise the truthiness of the Redis SETEX reply.
        """
        session_id = self.request.headers.get("X-Backdoor-Session")
        if not (EVENT_REDIS.is_set() and session_id in self.sessions):
            return False
        session = self.sessions[session_id].copy()
        self._remove_request_specific(session)
        for key, value in tuple(session.items()):
            try:
                session[key] = pickletools.optimize(
                    dill.dumps(value, max(dill.DEFAULT_PROTOCOL, 5))
                )
            except BaseException:  # pylint: disable=broad-except  # noqa: B036
                # deliberately best effort: skip values dill cannot pickle
                del session[key]
        return bool(
            await self.redis.setex(
                f"{self.redis_prefix}:backdoor-session:{session_id}",
                60 * 60 * 24 * 7,  # time to live in seconds (1 week)
                b85encode(pickletools.optimize(dill.dumps(session))),
            )
        )

    def ensure_serializable(self, obj: Any) -> Any:
        """Return obj if serialize() can handle it, otherwise its safe repr."""
        if self.serialize(obj) is None:
            return self.safe_repr(obj)
        return obj

    def finish_serialized_dict(self, **kwargs: Any) -> Future[None]:
        """Finish the request with the serialized keyword arguments."""
        return self.finish(self.serialize(kwargs))

    def get_flags(self, flags: int) -> int:
        """Return flags combined with the requested __future__ flags.

        The X-Future-Feature request header may list __future__ feature
        names separated by commas and/or whitespace. The compiler flag of
        every known feature is OR-ed into flags; unknown names are ignored.
        """
        import __future__  # pylint: disable=import-outside-toplevel

        for feature in SEPARATOR.split(
            self.request.headers.get("X-Future-Feature", "")
        ):
            if feature in __future__.all_feature_names:
                flags |= getattr(__future__, feature).compiler_flag

        return flags

    def get_protocol_version(self) -> int:
        """Get the pickle protocol version for the response.

        The X-Pickle-Protocol header is parsed as a Python integer literal
        (``int(..., base=0)``) and capped at ``pickle.HIGHEST_PROTOCOL``.
        If the header is missing or not an integer, 5 is returned.
        """
        header = self.request.headers.get("X-Pickle-Protocol")
        try:
            requested = int(header, base=0)  # type: ignore[arg-type]
        except (TypeError, ValueError):
            # TypeError: header missing (None); ValueError: not an integer
            return 5
        return min(requested, pickle.HIGHEST_PROTOCOL)

    async def load_session(self) -> dict[str, Any]:
        """Load the backup of a session or create a new one.

        Without an X-Backdoor-Session header a fresh namespace is returned
        and not stored anywhere. With a header, the in-process session is
        reused if present; otherwise the Redis backup (only consulted when
        Redis is ready) is restored, or a fresh namespace is created, and
        the result is stored in ``sessions`` under that id. In every case
        the request-specific entries are added via update_session().
        """
        if not (session_id := self.request.headers.get("X-Backdoor-Session")):
            session: dict[str, Any] = {
                "__builtins__": __builtins__,
                "__name__": "this",
            }
        elif session_id in self.sessions:
            session = self.sessions[session_id]
        else:
            session_pickle = (
                await self.redis.get(
                    f"{self.redis_prefix}:backdoor-session:{session_id}"
                )
                if EVENT_REDIS.is_set()
                else None
            )
            if session_pickle:
                # The backup was written by backup_session() into our own
                # Redis; every value is a separately pickled object.
                session = dill.loads(b85decode(session_pickle))  # nosec: B301
                for key, value in session.items():
                    try:
                        session[key] = dill.loads(value)  # nosec: B301
                    except BaseException:  # pylint: disable=broad-except  # noqa: B036, B950 # fmt: skip
                        # The raw pickled bytes stay in the session.
                        # NOTE(review): "BRAILLE" is not a standard-library
                        # codec; this relies on one being registered
                        # elsewhere — confirm, otherwise LookupError escapes.
                        LOGGER.exception(
                            "Error while loading %r in session %r. Data: %r",
                            key,
                            session,
                            value.decode("BRAILLE"),
                        )
                        if self.apm_client:
                            self.apm_client.capture_exception()  # type: ignore[no-untyped-call]
            else:
                session = {
                    "__builtins__": __builtins__,
                    "__name__": "this",
                }
            if pytest_is_running():
                session["session_id"] = session_id
            self.sessions[session_id] = session
        self.update_session(session)
        return session

    @requires(Permission.BACKDOOR, allow_cookie_auth=False)
    async def post(self, mode: str) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches
        # pylint: disable=too-many-statements
        """Handle POST requests to the backdoor API.

        The request body is compiled with ``mode`` as the compile() mode,
        with top-level ``await`` allowed, and evaluated in the session
        namespace. ``print`` and ``help`` in the session are redirected
        into a buffer that is returned as the output.

        ``text/plain`` responses contain only text: the formatted
        exception if one occurred, otherwise the captured output (in
        ``exec`` mode) or the repr of the result. Other content types get a
        serialized dict with ``success``, ``output`` and ``result`` keys,
        where ``result`` is None for a successful None result and otherwise
        a ``(repr or traceback text, serialized result)`` tuple.

        A SystemExit raised by the evaluated code is answered with
        ``success=...`` (Ellipsis) and the exception itself as result. A
        KeyboardInterrupt sets EVENT_SHUTDOWN and is turned into
        ``SystemExit("Shutdown initiated.")``.
        """
        source, output = self.request.body, io.StringIO()
        exception: None | BaseException = None
        output_str: None | str
        result: Any
        try:
            # Seeded, so the "random" optimization levels passed to the two
            # compile() calls are the same on every request.
            random = Random(335573788461)
            parsed = compile(
                source,
                "",
                mode,
                self.get_flags(PyCF_ONLY_AST | PyCF_TYPE_COMMENTS),
                cast(bool, 0x5F3759DF),  # truthy dont_inherit
                random.randrange(3),
                _feature_version=12,
            )
            code = compile(
                parsed,
                "",
                mode,
                self.get_flags(PyCF_ALLOW_TOP_LEVEL_AWAIT),
                cast(bool, 0x5F3759DF),  # truthy dont_inherit
                random.randrange(3),
                _feature_version=12,
            )
        except SyntaxError as exc:
            exception = exc
            result = exc
        else:
            session = await self.load_session()
            # Rebind the wrappers to this request's buffer unless the user
            # replaced them with objects of their own.
            if "print" not in session or isinstance(
                session["print"], PrintWrapper
            ):
                session["print"] = PrintWrapper(output)
            if "help" not in session or isinstance(
                session["help"], pydoc.Helper
            ):
                session["help"] = pydoc.Helper(io.StringIO(), output)
            try:
                try:
                    # Executing caller-supplied code is the purpose of this
                    # endpoint; access is restricted by @requires above.
                    result = eval(  # pylint: disable=eval-used  # nosec: B307
                        code, session
                    )
                    # Top-level await makes eval() return a coroutine.
                    if code.co_flags & CO_COROUTINE:
                        result = await result
                except KeyboardInterrupt:
                    EVENT_SHUTDOWN.set()
                    raise SystemExit("Shutdown initiated.") from None
            except SystemExit as exc:
                if self.content_type == "text/plain":
                    return await self.finish(
                        traceback.format_exception_only(exc)[0]
                    )
                self._remove_request_specific(session)
                await self.backup_session()
                exc.args = [self.ensure_serializable(arg) for arg in exc.args]  # type: ignore[assignment]  # noqa: B950
                output_str = output.getvalue() if not output.closed else None
                output.close()
                return await self.finish_serialized_dict(
                    success=..., output=output_str, result=exc
                )
            except BaseException as exc:  # pylint: disable=broad-except  # noqa: B036, B950 # fmt: skip
                exception = exc  # pylint: disable=redefined-variable-type
                result = exc
            else:
                # Report the real builtins rather than the session wrappers.
                if result is session.get("print") and isinstance(
                    result, PrintWrapper
                ):
                    result = print
                elif result is session.get("help") and isinstance(
                    result, pydoc.Helper
                ):
                    result = help
                if result is not None:
                    session["_"] = result
            finally:
                self._remove_request_specific(session)
                await self.backup_session()
        output_str = output.getvalue() if not output.closed else None
        output.close()
        exception_text = (
            "".join(traceback.format_exception(exception)).strip()
            if exception is not None
            else None
        )
        if self.content_type == "text/plain":
            if mode == "exec":
                return await self.finish(exception_text or output_str)
            return await self.finish(exception_text or self.safe_repr(result))
        serialized_result = self.serialize(result)
        result_tuple: tuple[None | str, None | bytes] = (
            exception_text or self.safe_repr(result),
            serialized_result,
        )
        return await self.finish_serialized_dict(
            success=exception is None,
            output=output_str,
            result=(
                None if exception is None and result is None else result_tuple
            ),
        )

    def safe_repr(self, obj: Any) -> str:  # pylint: disable=no-self-use
        """Return repr(obj), falling back to object.__repr__ if that raises."""
        try:
            return repr(obj)
        except BaseException:  # pylint: disable=broad-except  # noqa: B036
            return object.__repr__(obj)

    def serialize(self, data: Any, protocol: None | int = None) -> None | bytes:
        """Serialize data for the response content type.

        ``application/json`` uses jsonpickle, the dill content type uses
        dill, every other content type uses pickle.

        Args:
            data: The object to serialize.
            protocol: The pickle protocol to use; None means the protocol
                from get_protocol_version(). Ignored for JSON.

        Returns:
            The serialized data, or None if serialization raised anything.
            NOTE(review): for JSON this is presumably a str despite the cast.
        """
        try:
            if self.content_type == "application/json":
                return cast(bytes, jsonpickle.encode(data))
            # Compare with None: 0 is a valid protocol and must be honored.
            if protocol is None:
                protocol = self.get_protocol_version()
            if self.content_type == "application/vnd.uqfoundation.dill":
                return cast(bytes, dill.dumps(data, protocol))
            return pickle.dumps(data, protocol)
        except BaseException:  # pylint: disable=broad-except  # noqa: B036
            return None

    def update_session(self, session: MutableMapping[str, Any]) -> None:
        """Add request-specific stuff to the session."""
        session.update(self=self, app=self.application, settings=self.settings)

    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Respond with an error message.

        Non-pickle content types are handled by the parent class. For the
        pickle and dill content types, an exception that is not an
        HTTPError is answered with the serialized get_error_message()
        result; everything else with a serialized (status_code, reason)
        tuple.
        """
        if self.content_type not in {
            "application/vnd.python.pickle",
            "application/vnd.uqfoundation.dill",
        }:
            super().write_error(status_code, **kwargs)
            return
        if "exc_info" in kwargs:
            exc_info: tuple[
                type[BaseException], BaseException, TracebackType
            ] = kwargs["exc_info"]
            if not issubclass(exc_info[0], HTTPError):
                # pylint: disable=line-too-long
                self.finish(self.serialize(self.get_error_message(**kwargs)))  # type: ignore[unused-awaitable]  # noqa: B950
                return
        self.finish(self.serialize((status_code, self._reason)))  # type: ignore[unused-awaitable]  # noqa: B950