Coverage for an_website/backdoor/backdoor.py: 83.626%

171 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-10 18:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""The backdoor API of the website.""" 

15 

16import io 

17import logging 

18import pickle # nosec: B403 

19import pickletools # nosec: B403 

20import pydoc 

21import traceback 

22from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT, PyCF_ONLY_AST, PyCF_TYPE_COMMENTS 

23from asyncio import Future 

24from base64 import b85decode, b85encode 

25from collections.abc import MutableMapping 

26from inspect import CO_COROUTINE # pylint: disable=no-name-in-module 

27from random import Random 

28from types import TracebackType 

29from typing import Any, ClassVar, Final, cast 

30 

31import dill # type: ignore[import-untyped] # nosec: B403 

32import jsonpickle # type: ignore[import-untyped] 

33import regex 

34from tornado.web import HTTPError 

35 

36from .. import EVENT_REDIS, EVENT_SHUTDOWN, pytest_is_running 

37from ..utils.decorators import requires 

38from ..utils.request_handler import APIRequestHandler 

39from ..utils.utils import Permission 

40 

# Logger used by this module.
LOGGER: Final[logging.Logger] = logging.getLogger(__name__)
# Splits a header value on every run of commas and/or whitespace.
SEPARATOR: Final = regex.compile(r"[,\s]+")

43 

44 

class PrintWrapper:  # pylint: disable=too-few-public-methods
    """Callable replacement for print() bound to a fixed output stream.

    Calls are forwarded to print(); when the caller does not pass a
    ``file`` keyword argument, the wrapped stream is used instead of
    sys.stdout. An explicitly passed ``file`` always wins.
    """

    def __init__(self, output: io.TextIOBase) -> None:  # noqa: D107
        self._output: io.TextIOBase = output

    def __call__(self, *args: Any, **kwargs: Any) -> None:  # noqa: D102
        if "file" not in kwargs:
            # Only fill in the destination; never override the caller's.
            kwargs["file"] = self._output
        print(*args, **kwargs)

54 

55 

class Backdoor(APIRequestHandler):
    """The request handler for the backdoor API.

    Python source sent in the request body is compiled and evaluated in a
    per-session namespace. The captured output and the result are sent
    back, serialized according to the negotiated content type.
    """

    POSSIBLE_CONTENT_TYPES: ClassVar[tuple[str, ...]] = (
        "application/vnd.uqfoundation.dill",
        "application/vnd.python.pickle",
        "application/json",
        "text/plain",
    )

    ALLOWED_METHODS: ClassVar[tuple[str, ...]] = ("POST",)

    # In-process session namespaces keyed by the X-Backdoor-Session header
    # value; a class attribute, so shared by all handler instances.
    sessions: ClassVar[dict[str, dict[str, Any]]] = {}

    @staticmethod
    def _remove_request_specific(session: MutableMapping[str, Any]) -> None:
        """Remove the names that update_session() adds to a session."""
        session.pop("self", None)
        session.pop("app", None)
        session.pop("settings", None)

    @staticmethod
    def _new_session() -> dict[str, Any]:
        """Create a fresh session namespace."""
        return {
            "__builtins__": __builtins__,
            "__name__": "this",
        }

    async def backup_session(self) -> bool:
        """Backup a session using Redis and return whether it succeeded.

        Nothing is done (and False is returned) if Redis is not available
        or the session named by the X-Backdoor-Session header is not in
        the in-process cache. Every value is pickled with dill on its own,
        so a value that cannot be pickled is dropped from the backup
        instead of failing the whole backup. The backup expires after one
        week.
        """
        session_id = self.request.headers.get("X-Backdoor-Session")
        if not (EVENT_REDIS.is_set() and session_id in self.sessions):
            return False
        session = self.sessions[session_id].copy()
        self._remove_request_specific(session)
        for key, value in tuple(session.items()):
            try:
                session[key] = pickletools.optimize(
                    dill.dumps(value, max(dill.DEFAULT_PROTOCOL, 5))
                )
            except BaseException:  # pylint: disable=broad-except # noqa: B036
                # Unpicklable values are simply not backed up.
                del session[key]
        return bool(
            await self.redis.setex(
                f"{self.redis_prefix}:backdoor-session:{session_id}",
                60 * 60 * 24 * 7,  # time to live in seconds (1 week)
                b85encode(pickletools.optimize(dill.dumps(session))),
            )
        )

    def ensure_serializable(self, obj: Any) -> Any:
        """Ensure that obj can be serialized.

        Return obj itself if serialize() succeeds for it, otherwise its
        safe_repr() string.
        """
        if self.serialize(obj) is None:
            return self.safe_repr(obj)
        return obj

    def finish_serialized_dict(self, **kwargs: Any) -> Future[None]:
        """Finish with a serialized dictionary built from the kwargs."""
        return self.finish(self.serialize(kwargs))

    def get_flags(self, flags: int) -> int:
        """Get compiler flags.

        The ``__future__`` features named in the X-Future-Feature header
        (separated by commas and/or whitespace) are OR-ed into ``flags``.
        Unknown feature names are ignored.
        """
        import __future__  # pylint: disable=import-outside-toplevel

        for feature in SEPARATOR.split(
            self.request.headers.get("X-Future-Feature", "")
        ):
            if feature in __future__.all_feature_names:
                flags |= getattr(__future__, feature).compiler_flag

        return flags

    def get_protocol_version(self) -> int:
        """Get the protocol version for the response.

        The X-Pickle-Protocol header is parsed with ``int(..., base=0)``
        (so prefixes like ``0x`` are accepted) and capped at
        pickle.HIGHEST_PROTOCOL. If the header is missing or invalid,
        protocol 5 is used.
        """
        try:
            return min(
                int(
                    self.request.headers.get("X-Pickle-Protocol"), base=0  # type: ignore[arg-type] # noqa: B950
                ),
                pickle.HIGHEST_PROTOCOL,
            )
        except (TypeError, ValueError):
            # TypeError: header missing (None); ValueError: not an integer.
            return 5

    async def load_session(self) -> dict[str, Any]:
        """Load the backup of a session or create a new one.

        Without an X-Backdoor-Session header, a fresh namespace that is not
        cached is used. Otherwise the in-process cache is checked first,
        then the Redis backup (if Redis is available). If neither has the
        session, a fresh namespace is created. Sessions loaded from Redis
        or newly created for an ID are stored in the cache. In every case
        the request-specific names are added before returning.
        """
        if not (session_id := self.request.headers.get("X-Backdoor-Session")):
            session = self._new_session()
        elif session_id in self.sessions:
            session = self.sessions[session_id]
        else:
            session_pickle = (
                await self.redis.get(
                    f"{self.redis_prefix}:backdoor-session:{session_id}"
                )
                if EVENT_REDIS.is_set()
                else None
            )
            if session_pickle:
                # SECURITY: unpickling can execute arbitrary code, so this
                # relies on the Redis data being trusted (it is written by
                # backup_session()).
                session = dill.loads(b85decode(session_pickle))  # nosec: B301
                for key, value in session.items():
                    try:
                        session[key] = dill.loads(value)  # nosec: B301
                    except BaseException:  # pylint: disable=broad-except # noqa: B036, B950 # fmt: skip
                        # The value is left in its pickled (bytes) form.
                        # NOTE(review): relies on a "BRAILLE" codec being
                        # registered elsewhere in the project — confirm.
                        LOGGER.exception(
                            "Error while loading %r in session %r. Data: %r",
                            key,
                            session,
                            value.decode("BRAILLE"),
                        )
                        if self.apm_client:
                            self.apm_client.capture_exception()  # type: ignore[no-untyped-call]
            else:
                session = self._new_session()
                if pytest_is_running():
                    session["session_id"] = session_id
            self.sessions[session_id] = session
        self.update_session(session)
        return session

    @requires(Permission.BACKDOOR, allow_cookie_auth=False)
    async def post(self, mode: str) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches
        # pylint: disable=too-many-statements
        """Handle POST requests to the backdoor API.

        The request body is compiled with ``mode`` as the compile() mode
        and evaluated in the session namespace. For ``text/plain`` the
        response is the traceback if an exception occurred. Otherwise it
        is the captured output in ``exec`` mode, or the repr of the result
        in other modes. For the other content types it is a serialized
        dict with ``success``, ``output`` and ``result`` keys.

        SECURITY: this deliberately evaluates arbitrary code taken from
        the request body. It is guarded by the BACKDOOR permission, and
        cookie authentication is not accepted.
        """
        source, output = self.request.body, io.StringIO()
        exception: None | BaseException = None
        output_str: None | str
        result: Any
        try:
            # Fixed seed: the "optimize" levels drawn below are the same
            # on every request.
            random = Random(335573788461)
            parsed = compile(
                source,
                "",
                mode,
                self.get_flags(PyCF_ONLY_AST | PyCF_TYPE_COMMENTS),
                cast(bool, 0x5F3759DF),
                random.randrange(3),
                _feature_version=12,
            )
            code = compile(
                parsed,
                "",
                mode,
                self.get_flags(PyCF_ALLOW_TOP_LEVEL_AWAIT),
                cast(bool, 0x5F3759DF),
                random.randrange(3),
                _feature_version=12,
            )
        except SyntaxError as exc:
            exception = exc
            result = exc
        else:
            session = await self.load_session()
            # Point print() and help() at this request's output buffer,
            # unless the session rebound them to something else.
            if "print" not in session or isinstance(
                session["print"], PrintWrapper
            ):
                session["print"] = PrintWrapper(output)
            if "help" not in session or isinstance(
                session["help"], pydoc.Helper
            ):
                session["help"] = pydoc.Helper(io.StringIO(), output)
            try:
                try:
                    result = eval(  # pylint: disable=eval-used # nosec: B307
                        code, session
                    )
                    # With top-level await, the code object is flagged as a
                    # coroutine and eval() returns a coroutine to await.
                    if code.co_flags & CO_COROUTINE:
                        result = await result
                except KeyboardInterrupt:
                    EVENT_SHUTDOWN.set()
                    raise SystemExit("Shutdown initiated.") from None
            except SystemExit as exc:
                if self.content_type == "text/plain":
                    return await self.finish(
                        traceback.format_exception_only(exc)[0]
                    )
                self._remove_request_specific(session)
                # The finally clause below also runs (and backs up again).
                await self.backup_session()
                exc.args = [self.ensure_serializable(arg) for arg in exc.args]  # type: ignore[assignment] # noqa: B950
                output_str = output.getvalue() if not output.closed else None
                output.close()
                # success=... (Ellipsis) distinguishes a SystemExit from
                # success (True) and failure (False).
                return await self.finish_serialized_dict(
                    success=..., output=output_str, result=exc
                )
            except BaseException as exc:  # pylint: disable=broad-except # noqa: B036, B950 # fmt: skip
                exception = exc  # pylint: disable=redefined-variable-type
                result = exc
            else:
                # Report the real builtins instead of the session wrappers.
                if result is session.get("print") and isinstance(
                    result, PrintWrapper
                ):
                    result = print
                elif result is session.get("help") and isinstance(
                    result, pydoc.Helper
                ):
                    result = help
                if result is not None:
                    session["_"] = result
            finally:
                self._remove_request_specific(session)
                await self.backup_session()
        output_str = output.getvalue() if not output.closed else None
        output.close()
        exception_text = (
            "".join(traceback.format_exception(exception)).strip()
            if exception is not None
            else None
        )
        if self.content_type == "text/plain":
            if mode == "exec":
                return await self.finish(exception_text or output_str)
            return await self.finish(exception_text or self.safe_repr(result))
        serialized_result = self.serialize(result)
        result_tuple: tuple[None | str, None | bytes] = (
            exception_text or self.safe_repr(result),
            serialized_result,
        )
        return await self.finish_serialized_dict(
            success=exception is None,
            output=output_str,
            result=(
                None if exception is None and result is None else result_tuple
            ),
        )

    def safe_repr(self, obj: Any) -> str:  # pylint: disable=no-self-use
        """Safe version of repr().

        Fall back to object.__repr__() if the object's own repr() raises.
        """
        try:
            return repr(obj)
        except BaseException:  # pylint: disable=broad-except # noqa: B036
            return object.__repr__(obj)

    def serialize(self, data: Any, protocol: None | int = None) -> None | bytes:
        """Serialize the data and return it.

        JSON requests use jsonpickle, the dill content type uses dill, and
        anything else uses pickle. If ``protocol`` is None, the protocol
        from get_protocol_version() is used. An explicit ``0`` is honoured.
        Return None if serialization fails for any reason.
        """
        try:
            if self.content_type == "application/json":
                return cast(bytes, jsonpickle.encode(data))
            if protocol is None:
                # Previously `protocol or ...`, which discarded protocol 0.
                protocol = self.get_protocol_version()
            if self.content_type == "application/vnd.uqfoundation.dill":
                return cast(bytes, dill.dumps(data, protocol))
            return pickle.dumps(data, protocol)
        except BaseException:  # pylint: disable=broad-except # noqa: B036
            return None

    def update_session(self, session: MutableMapping[str, Any]) -> None:
        """Add request-specific stuff (self, app, settings) to the session."""
        session.update(self=self, app=self.application, settings=self.settings)

    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Respond with error message.

        Content types other than pickle/dill are handled by the base class.
        For pickle/dill: if an exception that is not an HTTPError is
        given, the serialized result of get_error_message() is sent.
        Otherwise the serialized (status_code, reason) tuple is sent.
        """
        if self.content_type not in {
            "application/vnd.python.pickle",
            "application/vnd.uqfoundation.dill",
        }:
            super().write_error(status_code, **kwargs)
            return
        if "exc_info" in kwargs:
            exc_info: tuple[
                type[BaseException], BaseException, TracebackType
            ] = kwargs["exc_info"]
            if not issubclass(exc_info[0], HTTPError):
                # pylint: disable=line-too-long
                self.finish(self.serialize(self.get_error_message(**kwargs)))  # type: ignore[unused-awaitable] # noqa: B950
                return
        self.finish(self.serialize((status_code, self._reason)))  # type: ignore[unused-awaitable] # noqa: B950