Coverage for an_website/utils/static_file_from_traversable.py: 80.576%

139 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 19:56 +0000

1# This file is based on the StaticFileHandler from Tornado 

2# Source: https://github.com/tornadoweb/tornado/blob/b3f2a4bb6fb55f6b1b1e890cdd6332665cfe4a75/tornado/web.py # noqa: B950 # pylint: disable=line-too-long 

3# Licensed under the Apache License, Version 2.0 (the "License"); you may 

4# not use this file except in compliance with the License. You may obtain 

5# a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 

11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 

12# License for the specific language governing permissions and limitations 

13# under the License. 

14 

15"""A static file handler for the Traversable abc.""" 

16from __future__ import annotations 

17 

18import contextlib 

19import logging 

20import sys 

21from collections.abc import Awaitable, Iterable, Mapping, Sequence 

22from importlib.resources.abc import Traversable 

23from typing import Any, Final, Literal, override 

24from urllib.parse import urlsplit, urlunsplit 

25 

26from tornado import httputil, iostream 

27from tornado.web import GZipContentEncoding, HTTPError 

28from typed_stream import Stream 

29 

30from an_website.utils.utils import size_of_file 

31 

32from .base_request_handler import _RequestHandler 

33from .static_file_handling import content_type_from_path 

34 

35type Encoding = Literal["gz", "zst"] 

36 

37LOGGER: Final = logging.getLogger(__name__) 

38 

39ENCODINGS: Final[Sequence[tuple[str, Encoding]]] = ( 

40 # better first 

41 ("zstd", "zst"), 

42 ("gzip", "gz"), 

43) 

44REVERSE_ENCODINGS_MAP: Mapping[Encoding, str] = { 

45 value: key for key, value in ENCODINGS 

46} 

47 

48 

49class TraversableStaticFileHandler(_RequestHandler): 

50 """A static file handler for the Traversable abc.""" 

51 

52 root: Traversable 

53 file_hashes: Mapping[str, str] 

54 

55 @override 

56 def compute_etag(self) -> None | str: 

57 """Return a pre-computed ETag.""" 

58 return getattr(self, "file_hashes", {}).get(self.request.path) 

59 

async def get(self, path: str, *, head: bool = False) -> None:  # noqa: C901
    # pylint: disable=too-complex, too-many-branches
    """Handle GET requests for files in the static file directory.

    ``path`` is the requested file relative to ``self.root``. With
    ``head=True`` only the headers are sent (used by ``head()``).

    Steps: normalise the URL (trailing slash, casing) via redirects,
    reject suspicious paths with 404, pick a pre-compressed variant if
    one is available, evaluate an optional ``Range`` header and finally
    stream the content in chunks.

    Raises ``HTTPError(404)`` for rejected paths and missing files.
    """
    # URLs with a trailing slash are redirected to the slash-less form.
    if self.request.path.endswith("/"):
        self.replace_path_with_redirect(self.request.path.rstrip("/"))
        return

    # Refuse absolute paths, parent-directory segments and empty segments.
    if path.startswith("/") or ".." in path.split("/") or "//" in path:
        raise HTTPError(404)

    # absolute_path may point at a pre-compressed variant of the file;
    # encoding is then its file suffix, otherwise None.
    absolute_path, encoding = self.get_absolute_path_encoded(path)

    if not absolute_path.is_file():
        # If only the casing is wrong, redirect to the lower-cased path.
        if self.get_absolute_path(path.lower()).is_file():
            if self.request.path.endswith(path):
                self.replace_path_with_redirect(
                    self.request.path.removesuffix(path) + path.lower()
                )
                return
            # The request path does not end with ``path``, so the URL
            # cannot be rewritten safely; log it and fall through to 404.
            LOGGER.error(
                "Failed to fix casing of %s", self.request.full_url()
            )
        raise HTTPError(404)

    self.set_header("Accept-Ranges", "bytes")

    # NOTE(review): no "Vary: Accept-Encoding" is set in this method even
    # though the response depends on Accept-Encoding -- confirm that the
    # base handler adds it.
    if encoding:
        self.set_header("Content-Encoding", REVERSE_ENCODINGS_MAP[encoding])
    # The content type is derived from the un-encoded path, not from the
    # compressed variant that may actually be served.
    if content_type := self.get_content_type(
        path, self.get_absolute_path(path)
    ):
        self.set_header("Content-Type", content_type)
    # Only absolute_path is used from here on; presumably deleted to
    # prevent accidental use of the un-encoded path below.
    del path

    request_range = None
    range_header = self.request.headers.get("Range")
    if range_header:
        # As per RFC 2616 14.16, if an invalid Range header is specified,
        # the request will be treated as if the header didn't exist.
        # pylint: disable-next=protected-access
        request_range = httputil._parse_request_range(range_header)

    # Size of the file that is actually served (the compressed variant
    # when one was selected), so ranges refer to the encoded bytes.
    size: int = size_of_file(absolute_path)

    if request_range:
        start, end = request_range
        if start is not None and start < 0:
            # Negative start is a suffix range: count from the end, but
            # never before the beginning of the file.
            start += size
            start = max(start, 0)
        if (
            start is not None
            and (start >= size or (end is not None and start >= end))
        ) or end == 0:  # pylint: disable=use-implicit-booleaness-not-comparison-to-zero # noqa: B950
            # As per RFC 2616 14.35.1, a range is not satisfiable only: if
            # the first requested byte is equal to or greater than the
            # content, or when a suffix with length 0 is specified.
            # https://tools.ietf.org/html/rfc7233#section-2.1
            # A byte-range-spec is invalid if the last-byte-pos value is present
            # and less than the first-byte-pos.
            self.set_status(416)  # Range Not Satisfiable
            self.set_header("Content-Type", "text/plain")
            self.set_header("Content-Range", f"bytes */{size}")
            return
        if end is not None and end > size:
            # Clients sometimes blindly use a large range to limit their
            # download size; cap the endpoint at the actual file size.
            end = size
        # Note: only return HTTP 206 if less than the entire range has been
        # requested. Not only is this semantically correct, but Chrome
        # refuses to play audio if it gets an HTTP 206 in response to
        # ``Range: bytes=0-``.
        if size != (end or size) - (start or 0):
            self.set_status(206)  # Partial Content
            self.set_header(
                "Content-Range",
                # pylint: disable-next=protected-access
                httputil._get_content_range(start, end, size),
            )
    else:
        start = end = None

    # Slicing a range object yields the number of bytes in [start, end)
    # without reading the file.
    content_length = len(range(size)[start:end])
    self.set_header("Content-Length", content_length)

    if head:
        assert self.request.method == "HEAD"
        await self.finish()
        return

    for chunk in self.get_content(absolute_path, start=start, end=end):
        self.write(chunk)
        try:
            await self.flush()
        except iostream.StreamClosedError:
            # The client went away; stop sending.
            return

    with contextlib.suppress(iostream.StreamClosedError):
        await self.finish()

158 

159 def get_absolute_path(self, path: str) -> Traversable: 

160 """Get the absolute path of a file.""" 

161 return self.root / path 

162 

def get_absolute_path_encoded(
    self, path: str
) -> tuple[Traversable, Encoding | None]:
    """Get the absolute path and the encoding.

    Looks for a pre-compressed sibling of *path* (``<path>.<suffix>``)
    for every coding in ``ENCODINGS`` that the client accepts, in order
    of preference.

    Returns the path of the first existing variant together with its
    suffix, or the plain path and ``None`` if no variant applies.

    Codings listed in ``Accept-Encoding`` with a quality value of zero
    are treated as not acceptable and coding names are compared
    case-insensitively (RFC 9110, sections 8.4.1 and 12.5.3).
    """
    # Switch off on-the-fly gzip for this response; presumably so that a
    # pre-compressed file is not compressed a second time.
    for transform in self._transforms:
        if isinstance(transform, GZipContentEncoding):
            # pylint: disable=protected-access
            transform._gzipping = False

    def acceptable(item: str) -> bool:
        """Return False if *item* carries a quality value of zero."""
        for param in item.split(";")[1:]:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                continue
            try:
                return float(value) != 0
            except ValueError:
                # Malformed weight: stay lenient and accept the coding.
                return True
        return True

    accepted_encodings: frozenset[str] = (
        Stream(self.request.headers.get_list("Accept-Encoding"))
        .flat_map(str.split, ",")
        .filter(acceptable)  # "coding;q=0" means "not acceptable"
        .map(lambda string: string.split(";")[0])  # drop parameters
        .map(str.strip)
        .map(str.lower)  # content codings are case-insensitive
        .collect(frozenset)
    )

    absolute_path = self.get_absolute_path(path)

    for coding, suffix in ENCODINGS:
        if coding not in accepted_encodings:
            continue
        compressed_path = self.get_absolute_path(f"{path}.{suffix}")
        if compressed_path.is_file():
            return compressed_path, suffix

    return absolute_path, None

192 

193 @classmethod 

194 def get_content( 

195 cls, 

196 abspath: Traversable, 

197 start: int | None = None, 

198 end: int | None = None, 

199 ) -> Iterable[bytes]: 

200 """Read the content of a file in chunks.""" 

201 with abspath.open("rb") as file: 

202 if start is not None: 

203 file.seek(start) 

204 remaining: int | None = ( 

205 (end - (start or 0)) if end is not None else None 

206 ) 

207 

208 while True: # pylint: disable=while-used 

209 chunk_size = 64 * 1024 

210 if remaining is not None and remaining < chunk_size: 

211 chunk_size = remaining 

212 chunk = file.read(chunk_size) 

213 if chunk: 

214 if remaining is not None: 

215 remaining -= len(chunk) 

216 yield chunk 

217 else: 

218 assert not remaining 

219 return 

220 

@classmethod
def get_content_type(
    cls, path: str, absolute_path: Traversable
) -> str | None:
    """Determine the content type by delegating to ``content_type_from_path``."""
    detected: str | None = content_type_from_path(path, absolute_path)
    return detected

227 

def head(self, path: str) -> Awaitable[None]:
    """Answer HEAD requests by delegating to ``get`` with ``head=True``."""
    pending: Awaitable[None] = self.get(path, head=True)
    return pending

231 

232 def initialize(self, root: Traversable, hashes: Mapping[str, str]) -> None: 

233 """Initialize this handler with a root directory and file hashes.""" 

234 self.root = root 

235 self.file_hashes = hashes 

236 if not sys.flags.dev_mode: 

237 self.set_etag_header() 

238 

def replace_path_with_redirect(
    self, new_path: str, *, status: int = 307
) -> None:
    """Redirect to the current URL with its path swapped for *new_path*.

    Scheme, host and query string are kept; the fragment is dropped.
    """
    current = urlsplit(self.request.full_url())
    scheme, netloc, query = current[0], current[1], current[3]
    target = urlunsplit((scheme, netloc, new_path, query, ""))
    self.redirect(target, status=status)

256 

@override
def set_default_headers(self) -> None:
    """Set the default headers for this handler.

    Runs the parent implementation first. Outside of Python development
    mode, requests carrying a ``v`` argument get a ten-year immutable
    ``Cache-Control`` header; other requests get the ETag header.
    """
    super().set_default_headers()
    if not sys.flags.dev_mode:
        if "v" in self.request.arguments:
            self.set_header(  # never changes
                "Cache-Control",
                f"public, immutable, max-age={86400 * 365 * 10}",
            )
        else:
            # NOTE(review): this listing lost its indentation; the
            # ``else`` is assumed to belong to the ``"v" in arguments``
            # check -- confirm against the original file.
            self.set_etag_header()

269 

270 @override 

271 def write_error(self, status_code: int, **kwargs: Any) -> None: 

272 """Write an error response.""" 

273 self.set_header("Content-Type", "text/plain;charset=utf-8") 

274 self.write(str(status_code)) 

275 self.write(" ") 

276 self.write(self._reason)