Coverage for an_website/utils/static_file_from_traversable.py: 94.667%

150 statements  

« prev     ^ index     » next       coverage.py v7.10.6, created at 2025-09-01 02:01 +0000

1# This file is based on the StaticFileHandler from Tornado 

2# Source: https://github.com/tornadoweb/tornado/blob/b3f2a4bb6fb55f6b1b1e890cdd6332665cfe4a75/tornado/web.py # noqa: B950 # pylint: disable=line-too-long 

3# Licensed under the Apache License, Version 2.0 (the "License"); you may 

4# not use this file except in compliance with the License. You may obtain 

5# a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

12# License for the specific language governing permissions and limitations 

13# under the License. 

14 

15"""A static file handler for the Traversable abc.""" 

16from __future__ import annotations 

17 

18import contextlib 

19import logging 

20import sys 

21from collections.abc import Awaitable, Iterable, Mapping, Sequence 

22from importlib.resources.abc import Traversable 

23from types import MappingProxyType 

24from typing import Any, Final, Literal, override 

25from urllib.parse import urlsplit, urlunsplit 

26 

27from tornado import httputil, iostream 

28from tornado.web import GZipContentEncoding, HTTPError 

29from typed_stream import Stream 

30 

31from an_website.utils.utils import size_of_file 

32 

33from .base_request_handler import _RequestHandler 

34from .static_file_handling import content_type_from_path 

35 

36type Encoding = Literal["gz", "zst"] 

37 

38LOGGER: Final = logging.getLogger(__name__) 

39 

40ENCODINGS: Final[Sequence[tuple[str, Encoding]]] = ( 

41 # better first 

42 ("zstd", "zst"), 

43 ("gzip", "gz"), 

44) 

45REVERSE_ENCODINGS_MAP: Mapping[Encoding, str] = { 

46 value: key for key, value in ENCODINGS 

47} 

48 

49 

class TraversableStaticFileHandler(_RequestHandler):
    """A static file handler for the Traversable abc.

    Files are looked up below ``root``.  If the client accepts one of the
    content codings listed in ``ENCODINGS`` and a sibling file with the
    matching suffix exists, that pre-compressed file is served instead.
    """

    # directory-like object the requested paths are resolved against
    root: Traversable
    # maps request paths to pre-computed ETag values (read by compute_etag)
    file_hashes: Mapping[str, str] = {}
    # extra response headers, applied in initialize and set_default_headers
    headers: Iterable[tuple[str, str]] = ()

    @override
    def compute_etag(self) -> None | str:
        """Return the pre-computed ETag for the request path, if there is one."""
        return self.file_hashes.get(self.request.path)

    async def get(self, path: str, *, head: bool = False) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches, too-many-statements
        """Handle GET requests for files in the static file directory.

        ``path`` is the path of the file relative to ``root``.  With
        ``head=True`` only the headers are sent (used by ``head``).

        Raises ``HTTPError(404)`` if the path looks like a traversal attempt
        or no matching file exists.
        """
        # a trailing slash never names a file; redirect to the slash-less path
        if self.request.path.endswith("/"):
            self.replace_path_with_redirect(self.request.path.rstrip("/"))
            return

        # reject absolute paths, parent-directory segments and empty segments
        if path.startswith("/") or ".." in path.split("/") or "//" in path:
            raise HTTPError(404)

        absolute_path, encoding = self.get_absolute_path_encoded(path)

        if not absolute_path.is_file():
            # the file may exist under its lower-cased name; redirect there
            if self.get_absolute_path(path.lower()).is_file():
                if self.request.path.endswith(path):
                    self.replace_path_with_redirect(
                        self.request.path.removesuffix(path) + path.lower()
                    )
                    return
                LOGGER.error(
                    "Failed to fix casing of %s", self.request.full_url()
                )
            raise HTTPError(404)

        self.set_header("Accept-Ranges", "bytes")

        if encoding:
            self.set_header("Content-Encoding", REVERSE_ENCODINGS_MAP[encoding])
        # the content type is derived from the uncompressed path
        if content_type := self.get_content_type(
            path, self.get_absolute_path(path)
        ):
            self.set_header("Content-Type", content_type)
        del path

        request_range = None
        range_header = self.request.headers.get("Range")
        if range_header:
            # As per RFC 2616 14.16, if an invalid Range header is specified,
            # the request will be treated as if the header didn't exist.
            # pylint: disable-next=protected-access
            request_range = httputil._parse_request_range(range_header)

        size: int = size_of_file(absolute_path)

        if request_range:
            start, end = request_range
            if start is not None and start < 0:
                # negative start: count from the end, clamped to the beginning
                start += size
                start = max(start, 0)
            if (
                start is not None
                and (start >= size or (end is not None and start >= end))
            ) or end == 0:  # pylint: disable=use-implicit-booleaness-not-comparison-to-zero  # noqa: B950
                # As per RFC 2616 14.35.1, a range is not satisfiable only: if
                # the first requested byte is equal to or greater than the
                # content, or when a suffix with length 0 is specified.
                # https://tools.ietf.org/html/rfc7233#section-2.1
                # A byte-range-spec is invalid if the last-byte-pos value is present
                # and less than the first-byte-pos.
                self.set_status(416)  # Range Not Satisfiable
                self.set_header("Content-Type", "text/plain")
                self.set_header("Content-Range", f"bytes */{size}")
                return
            if end is not None and end > size:
                # Clients sometimes blindly use a large range to limit their
                # download size; cap the endpoint at the actual file size.
                end = size
            # Note: only return HTTP 206 if less than the entire range has been
            # requested. Not only is this semantically correct, but Chrome
            # refuses to play audio if it gets an HTTP 206 in response to
            # ``Range: bytes=0-``.
            if size != (end or size) - (start or 0):
                self.set_status(206)  # Partial Content
                self.set_header(
                    "Content-Range",
                    # pylint: disable-next=protected-access
                    httputil._get_content_range(start, end, size),
                )
        else:
            start = end = None

        # slicing a range object yields the number of bytes in [start, end)
        content_length = len(range(size)[start:end])
        self.set_header("Content-Length", content_length)

        if head:
            assert self.request.method == "HEAD"
            await self.finish()
            return

        for chunk in self.get_content(absolute_path, start=start, end=end):
            self.write(chunk)
            try:
                await self.flush()
            except iostream.StreamClosedError:
                # the client went away; nothing more to send
                return
            except httputil.HTTPOutputError:
                LOGGER.exception(
                    "Connection %s; %s",
                    self.request.connection,
                    getattr(self.request.connection, "__dict__", None),
                )
                raise

        with contextlib.suppress(iostream.StreamClosedError):
            await self.finish()

    def get_absolute_path(self, path: str) -> Traversable:
        """Get the absolute path of a file by joining ``path`` onto ``root``."""
        return self.root / path

    @staticmethod
    def _parse_accept_encoding(header_values: Iterable[str]) -> frozenset[str]:
        """Return the lower-cased content codings the client accepts.

        Codings listed with a quality value of zero are left out, because a
        weight of 0 marks a coding as not acceptable.  Other weights are
        ignored; the server-side preference order of ``ENCODINGS`` decides.
        """
        accepted: set[str] = set()
        for header_value in header_values:
            for item in header_value.split(","):
                coding, *params = item.split(";")
                coding = coding.strip().lower()
                if not coding:
                    continue
                rejected = False
                for param in params:
                    name, _, value = param.partition("=")
                    if name.strip().lower() != "q":
                        continue
                    try:
                        rejected = float(value.strip()) == 0
                    except ValueError:
                        # malformed weight: stay lenient and accept the coding
                        rejected = False
                if not rejected:
                    accepted.add(coding)
        return frozenset(accepted)

    def get_absolute_path_encoded(
        self, path: str
    ) -> tuple[Traversable, Encoding | None]:
        """Get the absolute path and the encoding.

        Returns the path of the best pre-compressed variant the client
        accepts together with its suffix, or the plain path and ``None``
        if no such variant exists.
        """
        # the response is either pre-compressed or sent as is, so switch off
        # on-the-fly gzip compression for this request
        for transform in self._transforms:
            if isinstance(transform, GZipContentEncoding):
                # pylint: disable=protected-access
                transform._gzipping = False

        accepted_encodings = self._parse_accept_encoding(
            self.request.headers.get_list("Accept-Encoding")
        )

        # ENCODINGS is ordered by preference, so the first hit wins
        for key, encoding in ENCODINGS:
            if key not in accepted_encodings:
                continue
            compressed_path = self.get_absolute_path(f"{path}.{encoding}")
            if compressed_path.is_file():
                return compressed_path, encoding

        return self.get_absolute_path(path), None

    @classmethod
    def get_content(
        cls,
        abspath: Traversable,
        start: int | None = None,
        end: int | None = None,
    ) -> Iterable[bytes]:
        """Read the content of a file in chunks.

        Yields the bytes from ``start`` (default: beginning of the file) up
        to, but not including, ``end`` (default: end of the file) in chunks
        of at most 64 KiB.
        """
        with abspath.open("rb") as file:
            if start is not None:
                file.seek(start)
            # number of bytes still to be sent; None means "until EOF"
            remaining: int | None = (
                (end - (start or 0)) if end is not None else None
            )

            while True:  # pylint: disable=while-used
                chunk_size = 64 * 1024
                if remaining is not None and remaining < chunk_size:
                    chunk_size = remaining
                chunk = file.read(chunk_size)
                if chunk:
                    if remaining is not None:
                        remaining -= len(chunk)
                    yield chunk
                else:
                    # an empty read means EOF or that the range is exhausted
                    assert not remaining
                    return

    @classmethod
    def get_content_type(
        cls, path: str, absolute_path: Traversable
    ) -> str | None:
        """Get the content-type of a file."""
        return content_type_from_path(path, absolute_path)

    def head(self, path: str) -> Awaitable[None]:
        """Handle HEAD requests for files in the static file directory."""
        return self.get(path, head=True)

    def initialize(
        self,
        root: Traversable,
        hashes: Mapping[str, str] = MappingProxyType({}),
        headers: Iterable[tuple[str, str]] = (),
    ) -> None:
        """Initialize this handler with a root directory and file hashes.

        ``headers`` are extra response headers; they are set right away and
        again whenever ``set_default_headers`` runs.
        """
        self.root = root
        self.file_hashes = hashes
        # the headers are iterated again in set_default_headers, so a
        # one-shot iterable (e.g. a generator) has to be materialized
        self.headers = tuple(headers)
        for name, value in self.headers:
            self.set_header(name, value)
        if not sys.flags.dev_mode:
            self.set_etag_header()

    def replace_path_with_redirect(
        self, new_path: str, *, status: int = 307
    ) -> None:
        """Redirect to the current URL with its path replaced by ``new_path``.

        Scheme, host and query string are kept; the fragment is dropped.
        """
        scheme, netloc, _, query, _ = urlsplit(self.request.full_url())
        self.redirect(
            urlunsplit(
                (
                    scheme,
                    netloc,
                    new_path,
                    query,
                    "",
                )
            ),
            status=status,
        )

    @override
    def set_default_headers(self) -> None:
        """Set the default headers for this handler."""
        super().set_default_headers()
        for name, value in self.headers:
            self.set_header(name, value)

        if not sys.flags.dev_mode:
            if "v" in self.request.arguments:
                self.set_header(  # never changes
                    "Cache-Control",
                    f"public, immutable, max-age={86400 * 365 * 10}",
                )
            else:
                self.set_etag_header()

    @override
    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Write a plain-text error response: status code, reason, newline."""
        self.set_header("Content-Type", "text/plain;charset=utf-8")
        self.write(str(status_code))
        self.write(" ")
        self.write(self._reason)
        self.finish("\n")