Coverage for an_website / utils / static_file_from_traversable.py: 94.595%

148 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-24 17:35 +0000

1# This file is based on the StaticFileHandler from Tornado 

2# Source: https://github.com/tornadoweb/tornado/blob/b3f2a4bb6fb55f6b1b1e890cdd6332665cfe4a75/tornado/web.py # noqa: B950 # pylint: disable=line-too-long 

3# Licensed under the Apache License, Version 2.0 (the "License"); you may 

4# not use this file except in compliance with the License. You may obtain 

5# a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

12# License for the specific language governing permissions and limitations 

13# under the License. 

14 

15"""A static file handler for the Traversable abc.""" 

16 

17import contextlib 

18import logging 

19import sys 

20from collections.abc import Awaitable, Iterable, Mapping, Sequence 

21from importlib.resources.abc import Traversable 

22from types import MappingProxyType 

23from typing import Any, Final, Literal, override 

24from urllib.parse import urlsplit, urlunsplit 

25 

26from tornado import httputil, iostream 

27from tornado.web import GZipContentEncoding, HTTPError 

28from typed_stream import Stream 

29 

30from an_website.utils.utils import size_of_file 

31 

32from .base_request_handler import _RequestHandler 

33from .static_file_handling import content_type_from_path 

34 

35type Encoding = Literal["gz", "zst"] 

36 

37LOGGER: Final = logging.getLogger(__name__) 

38 

39ENCODINGS: Final[Sequence[tuple[str, Encoding]]] = ( 

40 # better first 

41 ("zstd", "zst"), 

42 ("gzip", "gz"), 

43) 

44REVERSE_ENCODINGS_MAP: Mapping[Encoding, str] = { 

45 value: key for key, value in ENCODINGS 

46} 

47 

48 

class TraversableStaticFileHandler(_RequestHandler):
    """A static file handler for the Traversable abc.

    Serves files below ``root`` (an ``importlib.resources`` Traversable),
    preferring pre-compressed variants listed in ``ENCODINGS`` when the
    client accepts them, and supports ``Range`` requests as well as
    ETag-based conditional requests (``If-None-Match`` -> 304).
    """

    # Directory that request paths are resolved against.
    root: Traversable
    # Request path -> pre-computed ETag value (see ``compute_etag``).
    # Read-only default, so no mutable mapping is shared between handlers.
    file_hashes: Mapping[str, str] = MappingProxyType({})
    # Extra (name, value) headers added to every response of this handler.
    headers: Iterable[tuple[str, str]] = ()

    @override
    def compute_etag(self) -> None | str:
        """Return a pre-computed ETag."""
        return self.file_hashes.get(self.request.path)

    async def get(self, path: str, *, head: bool = False) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches, too-many-statements
        """Handle GET requests for files in the static file directory.

        Paths with a trailing slash and wrongly-cased paths are redirected,
        traversal attempts and missing files raise a 404. If ``head`` is
        true only the headers are sent.
        """
        if self.request.path.endswith("/"):
            self.replace_path_with_redirect(self.request.path.rstrip("/"))
            return

        # Reject anything that could escape the root directory.
        if path.startswith("/") or ".." in path.split("/") or "//" in path:
            raise HTTPError(404)

        absolute_path, encoding = self.get_absolute_path_encoded(path)

        if not absolute_path.is_file():
            # Redirect to the lower-cased path if only that one exists.
            if self.get_absolute_path(path.lower()).is_file():
                if self.request.path.endswith(path):
                    self.replace_path_with_redirect(
                        self.request.path.removesuffix(path) + path.lower()
                    )
                    return
                LOGGER.error(
                    "Failed to fix casing of %s", self.request.full_url()
                )
            raise HTTPError(404)

        self.set_header("Accept-Ranges", "bytes")

        if encoding:
            self.set_header("Content-Encoding", REVERSE_ENCODINGS_MAP[encoding])
            # NOTE(review): the response now depends on Accept-Encoding;
            # "Vary: Accept-Encoding" is only added if Tornado's GZip
            # transform (or the base handler) does so — confirm.
        # The content type comes from the uncompressed file's path, even if
        # a compressed variant is served.
        if content_type := self.get_content_type(
            path, self.get_absolute_path(path)
        ):
            self.set_header("Content-Type", content_type)
        del path

        # Conditional request: the client already has this version.
        # (Tornado's finish() skips its own ETag check because the Etag
        # header is already set, so this has to happen here, like in
        # Tornado's StaticFileHandler.) Evaluated before Range handling.
        if self.check_etag_header():
            self.set_status(304)  # Not Modified
            return

        request_range = None
        range_header = self.request.headers.get("Range")
        if range_header:
            # As per RFC 2616 14.16, if an invalid Range header is specified,
            # the request will be treated as if the header didn't exist.
            # pylint: disable-next=protected-access
            request_range = httputil._parse_request_range(range_header)

        size: int = size_of_file(absolute_path)

        if request_range:
            start, end = request_range
            if start is not None and start < 0:
                # Suffix range ("bytes=-N"): count from the end of the file.
                start += size
                start = max(start, 0)
            if (
                start is not None
                and (start >= size or (end is not None and start >= end))
            ) or end == 0:  # pylint: disable=use-implicit-booleaness-not-comparison-to-zero # noqa: B950
                # As per RFC 2616 14.35.1, a range is not satisfiable only: if
                # the first requested byte is equal to or greater than the
                # content, or when a suffix with length 0 is specified.
                # https://tools.ietf.org/html/rfc7233#section-2.1
                # A byte-range-spec is invalid if the last-byte-pos value is present
                # and less than the first-byte-pos.
                self.set_status(416)  # Range Not Satisfiable
                self.set_header("Content-Type", "text/plain")
                self.set_header("Content-Range", f"bytes */{size}")
                return
            if end is not None and end > size:
                # Clients sometimes blindly use a large range to limit their
                # download size; cap the endpoint at the actual file size.
                end = size
            # Note: only return HTTP 206 if less than the entire range has been
            # requested. Not only is this semantically correct, but Chrome
            # refuses to play audio if it gets an HTTP 206 in response to
            # ``Range: bytes=0-``.
            if size != (end or size) - (start or 0):
                self.set_status(206)  # Partial Content
                self.set_header(
                    "Content-Range",
                    # pylint: disable-next=protected-access
                    httputil._get_content_range(start, end, size),
                )
        else:
            start = end = None

        # Number of bytes in [start, end) of a file with ``size`` bytes.
        content_length = len(range(size)[start:end])
        self.set_header("Content-Length", content_length)

        if head:
            assert self.request.method == "HEAD"
            await self.finish()
            return

        for chunk in self.get_content(absolute_path, start=start, end=end):
            self.write(chunk)
            try:
                await self.flush()
            except iostream.StreamClosedError:
                # The client went away; nothing left to send to.
                return
            except httputil.HTTPOutputError:
                LOGGER.exception(
                    "Connection %s; %s",
                    self.request.connection,
                    getattr(self.request.connection, "__dict__", None),
                )
                raise

        with contextlib.suppress(iostream.StreamClosedError):
            await self.finish()

    def get_absolute_path(self, path: str) -> Traversable:
        """Get the absolute path of a file."""
        return self.root / path

    def get_absolute_path_encoded(
        self, path: str
    ) -> tuple[Traversable, Encoding | None]:
        """Get the absolute path and the encoding.

        Returns the first pre-compressed variant of ``path`` (in
        ``ENCODINGS`` order) that the client accepts and that exists,
        together with its encoding, or the plain path and ``None``.
        """
        # Files are served as stored; stop Tornado's GZip output transform
        # from compressing the response (again).
        for transform in self._transforms:
            if isinstance(transform, GZipContentEncoding):
                # pylint: disable=protected-access
                transform._gzipping = False

        accepted_encodings: frozenset[str] = frozenset(
            coding
            for item in Stream(
                self.request.headers.get_list("Accept-Encoding")
            ).flat_map(str.split, ",")
            if (coding := self._acceptable_coding(item))
        )

        for key, encoding in ENCODINGS:
            if key not in accepted_encodings:
                continue
            compressed_path = self.get_absolute_path(f"{path}.{encoding}")
            if compressed_path.is_file():
                return compressed_path, encoding

        return self.get_absolute_path(path), None

    @staticmethod
    def _acceptable_coding(item: str) -> str | None:
        """Return the lower-cased coding of one Accept-Encoding item.

        Returns ``None`` for empty items and for items with a quality value
        of 0, which RFC 9110 defines as "not acceptable". Content-coding
        names are case-insensitive, hence the lower-casing. Malformed
        quality values are ignored (the coding stays acceptable).
        """
        coding, *params = item.split(";")
        coding = coding.strip().lower()
        if not coding:
            return None
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                continue
            with contextlib.suppress(ValueError):
                if float(value.strip()) == 0:
                    return None
        return coding

    @classmethod
    def get_content(
        cls,
        abspath: Traversable,
        start: int | None = None,
        end: int | None = None,
    ) -> Iterable[bytes]:
        """Read the content of a file in chunks.

        Yields the bytes in ``[start, end)`` (``None`` meaning the start or
        end of the file) in chunks of at most 64 KiB. The file stays open
        until the generator is exhausted or closed.
        """
        with abspath.open("rb") as file:
            if start is not None:
                file.seek(start)
            # Bytes still to be sent, or None to read until EOF.
            remaining: int | None = (
                (end - (start or 0)) if end is not None else None
            )

            while True:  # pylint: disable=while-used
                chunk_size = 64 * 1024
                if remaining is not None and remaining < chunk_size:
                    chunk_size = remaining
                chunk = file.read(chunk_size)
                if chunk:
                    if remaining is not None:
                        remaining -= len(chunk)
                    yield chunk
                else:
                    # EOF (or nothing left): the requested range must be done.
                    assert not remaining
                    return

    @classmethod
    def get_content_type(
        cls, path: str, absolute_path: Traversable
    ) -> str | None:
        """Get the content-type of a file."""
        return content_type_from_path(path, absolute_path)

    def head(self, path: str) -> Awaitable[None]:
        """Handle HEAD requests for files in the static file directory."""
        return self.get(path, head=True)

    def initialize(
        self,
        root: Traversable,
        hashes: Mapping[str, str] = MappingProxyType({}),
        headers: Iterable[tuple[str, str]] = (),
    ) -> None:
        """Initialize this handler with a root directory and file hashes.

        Also applies ``headers`` and, outside of dev mode, the ETag header
        (the default headers were set before these attributes existed).
        """
        self.root = root
        self.file_hashes = hashes
        self.headers = headers
        for name, value in headers:
            self.set_header(name, value)
        if not sys.flags.dev_mode:
            self.set_etag_header()

    def replace_path_with_redirect(
        self, new_path: str, *, status: int = 307
    ) -> None:
        """Redirect to the current URL with its path replaced by new_path.

        The query string is kept, the fragment is dropped.
        """
        scheme, netloc, _, query, _ = urlsplit(self.request.full_url())
        self.redirect(
            urlunsplit(
                (
                    scheme,
                    netloc,
                    new_path,
                    query,
                    "",
                )
            ),
            status=status,
        )

    @override
    def set_default_headers(self) -> None:
        """Set the default headers for this handler.

        Outside of dev mode, requests with a ``v`` query argument are
        treated as versioned and marked immutable for ten years; all other
        requests get the ETag header instead.
        """
        super().set_default_headers()
        for name, value in self.headers:
            self.set_header(name, value)

        if not sys.flags.dev_mode:
            if "v" in self.request.arguments:
                self.set_header(  # never changes
                    "Cache-Control",
                    f"public,immutable,max-age={86400 * 365 * 10}",
                )
            else:
                self.set_etag_header()

    @override
    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Write a plain-text error response ("<status> <reason>\\n")."""
        self.set_header("Content-Type", "text/plain;charset=utf-8")
        self.write(str(status_code))
        self.write(" ")
        self.write(self._reason)
        _ = self.finish("\n")