Coverage for an_website/backdoor/backdoor.py: 83.140%

172 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 19:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""The backdoor API of the website.""" 

15 

16from __future__ import annotations 

17 

18import io 

19import logging 

20import pickle # nosec: B403 

21import pickletools # nosec: B403 

22import pydoc 

23import traceback 

24from ast import PyCF_ALLOW_TOP_LEVEL_AWAIT, PyCF_ONLY_AST, PyCF_TYPE_COMMENTS 

25from asyncio import Future 

26from base64 import b85decode, b85encode 

27from collections.abc import MutableMapping 

28from inspect import CO_COROUTINE # pylint: disable=no-name-in-module 

29from random import Random 

30from types import TracebackType 

31from typing import Any, ClassVar, Final, cast 

32 

33import dill # type: ignore[import-untyped] # nosec: B403 

34import jsonpickle # type: ignore[import-untyped] 

35import regex 

36from tornado.web import HTTPError 

37 

38from .. import EVENT_REDIS, EVENT_SHUTDOWN, pytest_is_running 

39from ..utils.decorators import requires 

40from ..utils.request_handler import APIRequestHandler 

41from ..utils.utils import Permission 

42 

# Logger for everything in this module.
LOGGER: Final = logging.getLogger(name=__name__)

# Splits header values (e.g. X-Future-Feature) on runs of commas and/or
# whitespace.
SEPARATOR: Final = regex.compile(r"[,\s]+")

45 

46 

class PrintWrapper:  # pylint: disable=too-few-public-methods
    """Callable stand-in for print() that writes to a fixed stream by default.

    An explicit ``file=`` keyword passed by the caller still takes precedence.
    """

    def __init__(self, output: io.TextIOBase) -> None:  # noqa: D107
        # Stream used whenever the caller does not supply ``file=`` itself.
        self._output: io.TextIOBase = output

    def __call__(self, *values: Any, **options: Any) -> None:  # noqa: D102
        if "file" not in options:
            options["file"] = self._output
        print(*values, **options)

56 

57 

class Backdoor(APIRequestHandler):
    """The request handler for the backdoor API.

    Compiles and runs the Python source sent as the body of a POST request
    inside a per-session namespace and replies with the printed output and
    the result, serialized according to the negotiated content type.
    """

    POSSIBLE_CONTENT_TYPES: ClassVar[tuple[str, ...]] = (
        "application/vnd.uqfoundation.dill",
        "application/vnd.python.pickle",
        "application/json",
        "text/plain",
    )

    ALLOWED_METHODS: ClassVar[tuple[str, ...]] = ("POST",)

    # In-process session namespaces, keyed by the X-Backdoor-Session header.
    sessions: ClassVar[dict[str, dict[str, Any]]] = {}

    async def backup_session(self) -> bool:
        """Backup a session using Redis and return whether it succeeded.

        Nothing is backed up (and False is returned) if Redis is not
        available or the session named by the X-Backdoor-Session header is
        not known. The request-specific entries ``self``, ``app`` and
        ``settings`` are never persisted. Every remaining value is pickled
        on its own so that one unpicklable object only drops that key
        instead of the whole backup. The backup expires after one week.
        """
        session_id = self.request.headers.get("X-Backdoor-Session")
        if not (EVENT_REDIS.is_set() and session_id in self.sessions):
            return False
        session = self.sessions[session_id].copy()
        session.pop("self", None)
        session.pop("app", None)
        session.pop("settings", None)
        for key, value in tuple(session.items()):
            try:
                session[key] = pickletools.optimize(
                    dill.dumps(value, max(dill.DEFAULT_PROTOCOL, 5))
                )
            except BaseException:  # pylint: disable=broad-except  # noqa: B036
                # best effort: values that cannot be pickled are skipped
                del session[key]
        return bool(
            await self.redis.setex(
                f"{self.redis_prefix}:backdoor-session:{session_id}",
                60 * 60 * 24 * 7,  # time to live in seconds (1 week)
                b85encode(pickletools.optimize(dill.dumps(session))),
            )
        )

    def ensure_serializable(self, obj: Any) -> Any:
        """Ensure that obj can be serialized.

        Return ``obj`` unchanged if :meth:`serialize` can handle it,
        otherwise return its (safe) ``repr()`` string instead.
        """
        if self.serialize(obj) is None:
            return self.safe_repr(obj)
        return obj

    def finish_serialized_dict(self, **kwargs: Any) -> Future[None]:
        """Finish with a serialized dictionary.

        The keyword arguments form the dictionary, which is serialized with
        :meth:`serialize`; if that fails the response is finished with
        ``None``.
        """
        return self.finish(self.serialize(kwargs))

    def get_flags(self, flags: int) -> int:
        """Get compiler flags.

        Add the compiler flag of every ``__future__`` feature named in the
        X-Future-Feature header (comma- and/or whitespace-separated) to
        ``flags`` and return the result. Unknown names are ignored.
        """
        import __future__  # pylint: disable=import-outside-toplevel

        for feature in SEPARATOR.split(
            self.request.headers.get("X-Future-Feature", "")
        ):
            if feature in __future__.all_feature_names:
                flags |= getattr(__future__, feature).compiler_flag

        return flags

    def get_protocol_version(self) -> int:
        """Get the protocol version for the response.

        The version comes from the X-Pickle-Protocol header (parsed with
        ``base=0``, so prefixes like ``0x`` are accepted) and is capped at
        ``pickle.HIGHEST_PROTOCOL``. If the header is missing or not an
        integer, protocol 5 is used.
        """
        try:
            return min(
                int(
                    self.request.headers.get("X-Pickle-Protocol"), base=0  # type: ignore[arg-type]  # noqa: B950
                ),
                pickle.HIGHEST_PROTOCOL,
            )
        except (TypeError, ValueError):
            return 5

    async def load_session(self) -> dict[str, Any]:
        """Load the backup of a session or create a new one.

        Without an X-Backdoor-Session header a fresh session is returned
        that is not registered in :attr:`sessions`. A known session ID
        reuses the in-process session. Otherwise the Redis backup is
        restored (if Redis is available and a backup exists) or a new
        session is created; either way it is registered under its ID. In
        every case the request-specific entries are (re)added via
        :meth:`update_session`.
        """
        if not (session_id := self.request.headers.get("X-Backdoor-Session")):
            session: dict[str, Any] = {
                "__builtins__": __builtins__,
                "__name__": "this",
            }
        elif session_id in self.sessions:
            session = self.sessions[session_id]
        else:
            session_pickle = (
                await self.redis.get(
                    f"{self.redis_prefix}:backdoor-session:{session_id}"
                )
                if EVENT_REDIS.is_set()
                else None
            )
            if session_pickle:
                # SECURITY NOTE: unpickles data read back from Redis; this is
                # only as trustworthy as whoever can write to that Redis key.
                session = dill.loads(b85decode(session_pickle))  # nosec: B301
                # Each value was pickled separately by backup_session().
                for key, value in session.items():
                    try:
                        session[key] = dill.loads(value)  # nosec: B301
                    except BaseException:  # pylint: disable=broad-except  # noqa: B036, B950  # fmt: skip
                        # The raw pickle stays in the session so the rest of
                        # it still loads. Log the session ID, not the whole
                        # (possibly huge, partially unpickled) session dict.
                        LOGGER.exception(
                            "Error while loading %r in session %r. Data: %r",
                            key,
                            session_id,
                            value.decode("BRAILLE"),
                        )
                        if self.apm_client:
                            self.apm_client.capture_exception()  # type: ignore[no-untyped-call]
            else:
                session = {
                    "__builtins__": __builtins__,
                    "__name__": "this",
                }
                if pytest_is_running():
                    session["session_id"] = session_id
            self.sessions[session_id] = session
        self.update_session(session)
        return session

    @requires(Permission.BACKDOOR, allow_cookie_auth=False)
    async def post(self, mode: str) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches
        # pylint: disable=too-many-statements
        """Handle POST requests to the backdoor API."""
        source, output = self.request.body, io.StringIO()
        exception: None | BaseException = None
        output_str: None | str
        result: Any
        try:
            # Fixed seed, so the "random" optimization levels are deterministic.
            random = Random(335573788461)
            # First pass: parse to an AST, honouring type comments and any
            # __future__ features requested via the X-Future-Feature header.
            parsed = compile(
                source,
                "",
                mode,
                self.get_flags(PyCF_ONLY_AST | PyCF_TYPE_COMMENTS),
                cast(bool, 0x5F3759DF),  # dont_inherit (truthy)
                random.randrange(3),  # optimize level
                _feature_version=12,
            )
            # Second pass: compile the AST to code, allowing top-level await.
            code = compile(
                parsed,
                "",
                mode,
                self.get_flags(PyCF_ALLOW_TOP_LEVEL_AWAIT),
                cast(bool, 0x5F3759DF),
                random.randrange(3),
                _feature_version=12,
            )
        except SyntaxError as exc:
            exception = exc
            result = exc
        else:
            session = await self.load_session()
            # Route print()/help() to this request's output, unless the user
            # replaced them with something of their own.
            if "print" not in session or isinstance(
                session["print"], PrintWrapper
            ):
                session["print"] = PrintWrapper(output)
            if "help" not in session or isinstance(
                session["help"], pydoc.Helper
            ):
                session["help"] = pydoc.Helper(io.StringIO(), output)
            try:
                try:
                    result = eval(  # pylint: disable=eval-used  # nosec: B307
                        code, session
                    )
                    # Top-level await makes eval() return a coroutine.
                    if code.co_flags & CO_COROUTINE:
                        result = await result
                except KeyboardInterrupt:
                    EVENT_SHUTDOWN.set()
                    raise SystemExit("Shutdown initiated.") from None
            except SystemExit as exc:
                if self.content_type == "text/plain":
                    return await self.finish(
                        traceback.format_exception_only(exc)[0]
                    )
                # Back up before responding. The finally clause below still
                # runs after this return and repeats the (idempotent) pops
                # and the backup.
                session.pop("self", None)
                session.pop("app", None)
                session.pop("settings", None)
                await self.backup_session()
                exc.args = [self.ensure_serializable(arg) for arg in exc.args]  # type: ignore[assignment]  # noqa: B950
                output_str = output.getvalue() if not output.closed else None
                output.close()
                # success=... (Ellipsis) marks an exit request rather than a
                # plain success/failure.
                return await self.finish_serialized_dict(
                    success=..., output=output_str, result=exc
                )
            except BaseException as exc:  # pylint: disable=broad-except  # noqa: B036, B950  # fmt: skip
                exception = exc  # pylint: disable=redefined-variable-type
                result = exc
            else:
                # Don't hand out the per-request wrappers; report the real
                # builtins instead.
                if result is session.get("print") and isinstance(
                    result, PrintWrapper
                ):
                    result = print
                elif result is session.get("help") and isinstance(
                    result, pydoc.Helper
                ):
                    result = help
                if result is not None:
                    session["_"] = result
            finally:
                session.pop("self", None)
                session.pop("app", None)
                session.pop("settings", None)
                await self.backup_session()
        output_str = output.getvalue() if not output.closed else None
        output.close()
        exception_text = (
            "".join(traceback.format_exception(exception)).strip()
            if exception is not None
            else None
        )
        if self.content_type == "text/plain":
            if mode == "exec":
                return await self.finish(exception_text or output_str)
            return await self.finish(exception_text or self.safe_repr(result))
        serialized_result = self.serialize(result)
        result_tuple: tuple[None | str, None | bytes] = (
            exception_text or self.safe_repr(result),
            serialized_result,
        )
        return await self.finish_serialized_dict(
            success=exception is None,
            output=output_str,
            result=(
                None if exception is None and result is None else result_tuple
            ),
        )

    def safe_repr(self, obj: Any) -> str:  # pylint: disable=no-self-use
        """Safe version of repr().

        Falls back to ``object.__repr__`` if the object's own ``__repr__``
        raises anything (including BaseException subclasses).
        """
        try:
            return repr(obj)
        except BaseException:  # pylint: disable=broad-except  # noqa: B036
            return object.__repr__(obj)

    def serialize(self, data: Any, protocol: None | int = None) -> None | bytes:
        """Serialize the data and return it.

        The format follows the negotiated content type: jsonpickle for
        ``application/json``, dill for ``application/vnd.uqfoundation.dill``
        and the standard pickle module otherwise. ``protocol`` selects the
        pickle protocol; if it is ``None``, the protocol requested by the
        client (see :meth:`get_protocol_version`) is used. Returns ``None``
        if the data cannot be serialized.
        """
        try:
            if self.content_type == "application/json":
                # NOTE(review): jsonpickle.encode presumably returns str; the
                # cast only satisfies the type checker — confirm.
                return cast(bytes, jsonpickle.encode(data))
            if protocol is None:
                # Only fall back when no protocol was given: 0 is a valid
                # explicit protocol and must not be overridden.
                protocol = self.get_protocol_version()
            if self.content_type == "application/vnd.uqfoundation.dill":
                return cast(bytes, dill.dumps(data, protocol))
            return pickle.dumps(data, protocol)
        except BaseException:  # pylint: disable=broad-except  # noqa: B036
            return None

    def update_session(self, session: MutableMapping[str, Any]) -> None:
        """Add request-specific stuff to the session.

        Exposes the current handler as ``self``, the application as ``app``
        and its settings as ``settings``. These entries are removed again
        before a session is backed up.
        """
        session.update(self=self, app=self.application, settings=self.settings)

    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Respond with error message.

        Pickle/dill clients receive a serialized object: the error message
        for exceptions that are not HTTPErrors, otherwise a
        ``(status_code, reason)`` tuple. All other content types use the
        parent class's implementation.
        """
        if self.content_type not in {
            "application/vnd.python.pickle",
            "application/vnd.uqfoundation.dill",
        }:
            super().write_error(status_code, **kwargs)
            return
        if "exc_info" in kwargs:
            exc_info: tuple[
                type[BaseException], BaseException, TracebackType
            ] = kwargs["exc_info"]
            if not issubclass(exc_info[0], HTTPError):
                # pylint: disable=line-too-long
                self.finish(self.serialize(self.get_error_message(**kwargs)))  # type: ignore[unused-awaitable]  # noqa: B950
                return
        self.finish(self.serialize((status_code, self._reason)))  # type: ignore[unused-awaitable]  # noqa: B950