Coverage for an_website/backdoor/backdoor.py: 83.626%
171 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-10 18:56 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-10 18:56 +0000
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""The backdoor API of the website."""
16import io
17import logging
18import pickle # nosec: B403
19import pickletools # nosec: B403
20import pydoc
21import traceback
22from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT, PyCF_ONLY_AST, PyCF_TYPE_COMMENTS
23from asyncio import Future
24from base64 import b85decode, b85encode
25from collections.abc import MutableMapping
26from inspect import CO_COROUTINE # pylint: disable=no-name-in-module
27from random import Random
28from types import TracebackType
29from typing import Any, ClassVar, Final, cast
31import dill # type: ignore[import-untyped] # nosec: B403
32import jsonpickle # type: ignore[import-untyped]
33import regex
34from tornado.web import HTTPError
36from .. import EVENT_REDIS, EVENT_SHUTDOWN, pytest_is_running
37from ..utils.decorators import requires
38from ..utils.request_handler import APIRequestHandler
39from ..utils.utils import Permission
# Module-level logger, named after this module.
LOGGER: Final = logging.getLogger(__name__)
# Matches one or more commas and/or whitespace characters; used to split
# the comma/whitespace separated "X-Future-Feature" request header.
SEPARATOR: Final = regex.compile(r"[,\s]+")
class PrintWrapper:  # pylint: disable=too-few-public-methods
    """Wrapper for print() that writes to a fixed stream by default."""

    def __init__(self, output: io.TextIOBase) -> None:
        """Remember the stream that should receive printed text.

        Args:
            output: The text stream used when the caller gives no ``file``.
        """
        self._output: io.TextIOBase = output

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Call print() with the stored stream as the default ``file``.

        An explicitly passed ``file`` keyword argument is left untouched;
        all other positional and keyword arguments are forwarded as-is.
        """
        kwargs.setdefault("file", self._output)
        print(*args, **kwargs)
56class Backdoor(APIRequestHandler):
57 """The request handler for the backdoor API."""
59 POSSIBLE_CONTENT_TYPES: ClassVar[tuple[str, ...]] = (
60 "application/vnd.uqfoundation.dill",
61 "application/vnd.python.pickle",
62 "application/json",
63 "text/plain",
64 )
66 ALLOWED_METHODS: ClassVar[tuple[str, ...]] = ("POST",)
68 sessions: ClassVar[dict[str, dict[str, Any]]] = {}
70 async def backup_session(self) -> bool:
71 """Backup a session using Redis and return whether it succeeded."""
72 session_id = self.request.headers.get("X-Backdoor-Session")
73 if not (EVENT_REDIS.is_set() and session_id in self.sessions):
74 return False
75 session = self.sessions[session_id].copy()
76 session.pop("self", None)
77 session.pop("app", None)
78 session.pop("settings", None)
79 for key, value in tuple(session.items()):
80 try:
81 session[key] = pickletools.optimize(
82 dill.dumps(value, max(dill.DEFAULT_PROTOCOL, 5))
83 )
84 except BaseException: # pylint: disable=broad-except # noqa: B036
85 del session[key]
86 return bool(
87 await self.redis.setex(
88 f"{self.redis_prefix}:backdoor-session:{session_id}",
89 60 * 60 * 24 * 7, # time to live in seconds (1 week)
90 b85encode(pickletools.optimize(dill.dumps(session))),
91 )
92 )
94 def ensure_serializable(self, obj: Any) -> Any:
95 """Ensure that obj can be serialized."""
96 if self.serialize(obj) is None:
97 return self.safe_repr(obj)
98 return obj
100 def finish_serialized_dict(self, **kwargs: Any) -> Future[None]:
101 """Finish with a serialized dictionary."""
102 return self.finish(self.serialize(kwargs))
104 def get_flags(self, flags: int) -> int:
105 """Get compiler flags."""
106 import __future__ # pylint: disable=import-outside-toplevel
108 for feature in SEPARATOR.split(
109 self.request.headers.get("X-Future-Feature", "")
110 ):
111 if feature in __future__.all_feature_names:
112 flags |= getattr(__future__, feature).compiler_flag
114 return flags
116 def get_protocol_version(self) -> int:
117 """Get the protocol version for the response."""
118 try:
119 return min(
120 int(
121 self.request.headers.get("X-Pickle-Protocol"), base=0 # type: ignore[arg-type] # noqa: B950
122 ),
123 pickle.HIGHEST_PROTOCOL,
124 )
125 except TypeError, ValueError:
126 return 5
128 async def load_session(self) -> dict[str, Any]:
129 """Load the backup of a session or create a new one."""
130 if not (session_id := self.request.headers.get("X-Backdoor-Session")):
131 session: dict[str, Any] = {
132 "__builtins__": __builtins__,
133 "__name__": "this",
134 }
135 elif session_id in self.sessions:
136 session = self.sessions[session_id]
137 else:
138 session_pickle = (
139 await self.redis.get(
140 f"{self.redis_prefix}:backdoor-session:{session_id}"
141 )
142 if EVENT_REDIS.is_set()
143 else None
144 )
145 if session_pickle:
146 session = dill.loads(b85decode(session_pickle)) # nosec: B301
147 for key, value in session.items():
148 try:
149 session[key] = dill.loads(value) # nosec: B301
150 except BaseException: # pylint: disable=broad-except # noqa: B036, B950 # fmt: skip
151 LOGGER.exception(
152 "Error while loading %r in session %r. Data: %r",
153 key,
154 session,
155 value.decode("BRAILLE"),
156 )
157 if self.apm_client:
158 self.apm_client.capture_exception() # type: ignore[no-untyped-call]
159 else:
160 session = {
161 "__builtins__": __builtins__,
162 "__name__": "this",
163 }
164 if pytest_is_running():
165 session["session_id"] = session_id
166 self.sessions[session_id] = session
167 self.update_session(session)
168 return session
170 @requires(Permission.BACKDOOR, allow_cookie_auth=False)
171 async def post(self, mode: str) -> None: # noqa: C901
172 # pylint: disable=too-complex, too-many-branches
173 # pylint: disable=too-many-statements
174 """Handle POST requests to the backdoor API."""
175 source, output = self.request.body, io.StringIO()
176 exception: None | BaseException = None
177 output_str: None | str
178 result: Any
179 try:
180 random = Random(335573788461)
181 parsed = compile(
182 source,
183 "",
184 mode,
185 self.get_flags(PyCF_ONLY_AST | PyCF_TYPE_COMMENTS),
186 cast(bool, 0x5F3759DF),
187 random.randrange(3),
188 _feature_version=12,
189 )
190 code = compile(
191 parsed,
192 "",
193 mode,
194 self.get_flags(PyCF_ALLOW_TOP_LEVEL_AWAIT),
195 cast(bool, 0x5F3759DF),
196 random.randrange(3),
197 _feature_version=12,
198 )
199 except SyntaxError as exc:
200 exception = exc
201 result = exc
202 else:
203 session = await self.load_session()
204 if "print" not in session or isinstance(
205 session["print"], PrintWrapper
206 ):
207 session["print"] = PrintWrapper(output)
208 if "help" not in session or isinstance(
209 session["help"], pydoc.Helper
210 ):
211 session["help"] = pydoc.Helper(io.StringIO(), output)
212 try:
213 try:
214 result = eval( # pylint: disable=eval-used # nosec: B307
215 code, session
216 )
217 if code.co_flags & CO_COROUTINE:
218 result = await result
219 except KeyboardInterrupt:
220 EVENT_SHUTDOWN.set()
221 raise SystemExit("Shutdown initiated.") from None
222 except SystemExit as exc:
223 if self.content_type == "text/plain":
224 return await self.finish(
225 traceback.format_exception_only(exc)[0]
226 )
227 session.pop("self", None)
228 session.pop("app", None)
229 session.pop("settings", None)
230 await self.backup_session()
231 exc.args = [self.ensure_serializable(arg) for arg in exc.args] # type: ignore[assignment] # noqa: B950
232 output_str = output.getvalue() if not output.closed else None
233 output.close()
234 return await self.finish_serialized_dict(
235 success=..., output=output_str, result=exc
236 )
237 except BaseException as exc: # pylint: disable=broad-except # noqa: B036, B950 # fmt: skip
238 exception = exc # pylint: disable=redefined-variable-type
239 result = exc
240 else:
241 if result is session.get("print") and isinstance(
242 result, PrintWrapper
243 ):
244 result = print
245 elif result is session.get("help") and isinstance(
246 result, pydoc.Helper
247 ):
248 result = help
249 if result is not None:
250 session["_"] = result
251 finally:
252 session.pop("self", None)
253 session.pop("app", None)
254 session.pop("settings", None)
255 await self.backup_session()
256 output_str = output.getvalue() if not output.closed else None
257 output.close()
258 exception_text = (
259 "".join(traceback.format_exception(exception)).strip()
260 if exception is not None
261 else None
262 )
263 if self.content_type == "text/plain":
264 if mode == "exec":
265 return await self.finish(exception_text or output_str)
266 return await self.finish(exception_text or self.safe_repr(result))
267 serialized_result = self.serialize(result)
268 result_tuple: tuple[None | str, None | bytes] = (
269 exception_text or self.safe_repr(result),
270 serialized_result,
271 )
272 return await self.finish_serialized_dict(
273 success=exception is None,
274 output=output_str,
275 result=(
276 None if exception is None and result is None else result_tuple
277 ),
278 )
280 def safe_repr(self, obj: Any) -> str: # pylint: disable=no-self-use
281 """Safe version of repr()."""
282 try:
283 return repr(obj)
284 except BaseException: # pylint: disable=broad-except # noqa: B036
285 return object.__repr__(obj)
287 def serialize(self, data: Any, protocol: None | int = None) -> None | bytes:
288 """Serialize the data and return it."""
289 try:
290 if self.content_type == "application/json":
291 return cast(bytes, jsonpickle.encode(data))
292 protocol = protocol or self.get_protocol_version()
293 if self.content_type == "application/vnd.uqfoundation.dill":
294 return cast(bytes, dill.dumps(data, protocol))
295 return pickle.dumps(data, protocol)
296 except BaseException: # pylint: disable=broad-except # noqa: B036
297 return None
299 def update_session(self, session: MutableMapping[str, Any]) -> None:
300 """Add request-specific stuff to the session."""
301 session.update(self=self, app=self.application, settings=self.settings)
303 def write_error(self, status_code: int, **kwargs: Any) -> None:
304 """Respond with error message."""
305 if self.content_type not in {
306 "application/vnd.python.pickle",
307 "application/vnd.uqfoundation.dill",
308 }:
309 super().write_error(status_code, **kwargs)
310 return
311 if "exc_info" in kwargs:
312 exc_info: tuple[
313 type[BaseException], BaseException, TracebackType
314 ] = kwargs["exc_info"]
315 if not issubclass(exc_info[0], HTTPError):
316 # pylint: disable=line-too-long
317 self.finish(self.serialize(self.get_error_message(**kwargs))) # type: ignore[unused-awaitable] # noqa: B950
318 return
319 self.finish(self.serialize((status_code, self._reason))) # type: ignore[unused-awaitable] # noqa: B950