Coverage for an_website/utils/static_file_from_traversable.py: 94.667%
150 statements
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 02:01 +0000
« prev ^ index » next coverage.py v7.10.6, created at 2025-09-01 02:01 +0000
1# This file is based on the StaticFileHandler from Tornado
2# Source: https://github.com/tornadoweb/tornado/blob/b3f2a4bb6fb55f6b1b1e890cdd6332665cfe4a75/tornado/web.py # noqa: B950 # pylint: disable=line-too-long
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
15"""A static file handler for the Traversable abc."""
16from __future__ import annotations
18import contextlib
19import logging
20import sys
21from collections.abc import Awaitable, Iterable, Mapping, Sequence
22from importlib.resources.abc import Traversable
23from types import MappingProxyType
24from typing import Any, Final, Literal, override
25from urllib.parse import urlsplit, urlunsplit
27from tornado import httputil, iostream
28from tornado.web import GZipContentEncoding, HTTPError
29from typed_stream import Stream
31from an_website.utils.utils import size_of_file
33from .base_request_handler import _RequestHandler
34from .static_file_handling import content_type_from_path
36type Encoding = Literal["gz", "zst"]
38LOGGER: Final = logging.getLogger(__name__)
40ENCODINGS: Final[Sequence[tuple[str, Encoding]]] = (
41 # better first
42 ("zstd", "zst"),
43 ("gzip", "gz"),
44)
45REVERSE_ENCODINGS_MAP: Mapping[Encoding, str] = {
46 value: key for key, value in ENCODINGS
47}
class TraversableStaticFileHandler(_RequestHandler):
    """A static file handler for the Traversable abc.

    Files are looked up below ``root``.  If the client accepts one of the
    codings listed in ``ENCODINGS`` and a sibling file with the matching
    suffix exists, that pre-compressed file is served instead.
    """

    # Directory the requested paths are resolved against.
    root: Traversable
    # Maps a request path to its pre-computed ETag (see compute_etag).
    # Immutable default, so instances cannot accidentally share state.
    file_hashes: Mapping[str, str] = MappingProxyType({})
    # Extra (name, value) headers added to every response.
    headers: Iterable[tuple[str, str]] = ()

    @override
    def compute_etag(self) -> None | str:
        """Return the pre-computed ETag of the request path, if known."""
        return self.file_hashes.get(self.request.path)

    async def get(self, path: str, *, head: bool = False) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches, too-many-statements
        """Handle GET requests for files in the static file directory.

        ``path`` is the file path relative to ``root``.  With ``head`` set
        only the headers are sent (used by the HEAD handler).

        Raises ``HTTPError(404)`` for suspicious paths and missing files.
        """
        if self.request.path.endswith("/"):
            self.replace_path_with_redirect(self.request.path.rstrip("/"))
            return

        # Reject absolute paths, parent-directory segments and empty segments.
        if path.startswith("/") or ".." in path.split("/") or "//" in path:
            raise HTTPError(404)

        absolute_path, encoding = self.get_absolute_path_encoded(path)

        if not absolute_path.is_file():
            # Try to repair wrong casing by redirecting to the lower-cased
            # path, but only if that file actually exists.
            if self.get_absolute_path(path.lower()).is_file():
                if self.request.path.endswith(path):
                    self.replace_path_with_redirect(
                        self.request.path.removesuffix(path) + path.lower()
                    )
                    return
                LOGGER.error(
                    "Failed to fix casing of %s", self.request.full_url()
                )
            raise HTTPError(404)

        self.set_header("Accept-Ranges", "bytes")

        if encoding:
            self.set_header("Content-Encoding", REVERSE_ENCODINGS_MAP[encoding])
        # The content type is derived from the uncompressed file's path,
        # even if a pre-compressed variant is being served.
        if content_type := self.get_content_type(
            path, self.get_absolute_path(path)
        ):
            self.set_header("Content-Type", content_type)
        del path

        request_range = None
        range_header = self.request.headers.get("Range")
        if range_header:
            # As per RFC 2616 14.16, if an invalid Range header is specified,
            # the request will be treated as if the header didn't exist.
            # pylint: disable-next=protected-access
            request_range = httputil._parse_request_range(range_header)

        size: int = size_of_file(absolute_path)

        if request_range:
            start, end = request_range
            if start is not None and start < 0:
                # A negative start counts from the end of the file.
                start += size
                start = max(start, 0)
            if (
                start is not None
                and (start >= size or (end is not None and start >= end))
            ) or end == 0:  # pylint: disable=use-implicit-booleaness-not-comparison-to-zero # noqa: B950
                # As per RFC 2616 14.35.1, a range is not satisfiable only: if
                # the first requested byte is equal to or greater than the
                # content, or when a suffix with length 0 is specified.
                # https://tools.ietf.org/html/rfc7233#section-2.1
                # A byte-range-spec is invalid if the last-byte-pos value is
                # present and less than the first-byte-pos.
                self.set_status(416)  # Range Not Satisfiable
                self.set_header("Content-Type", "text/plain")
                self.set_header("Content-Range", f"bytes */{size}")
                return
            if end is not None and end > size:
                # Clients sometimes blindly use a large range to limit their
                # download size; cap the endpoint at the actual file size.
                end = size
            # Note: only return HTTP 206 if less than the entire range has been
            # requested. Not only is this semantically correct, but Chrome
            # refuses to play audio if it gets an HTTP 206 in response to
            # ``Range: bytes=0-``.
            if size != (end or size) - (start or 0):
                self.set_status(206)  # Partial Content
                self.set_header(
                    "Content-Range",
                    # pylint: disable-next=protected-access
                    httputil._get_content_range(start, end, size),
                )
        else:
            start = end = None

        # Slicing a range object yields the number of bytes in [start, end)
        # without reading the file.
        content_length = len(range(size)[start:end])
        self.set_header("Content-Length", content_length)

        if head:
            assert self.request.method == "HEAD"
            await self.finish()
            return

        for chunk in self.get_content(absolute_path, start=start, end=end):
            self.write(chunk)
            try:
                await self.flush()
            except iostream.StreamClosedError:
                # The client went away; there is nobody left to send to.
                return
            except httputil.HTTPOutputError:
                LOGGER.exception(
                    "Connection %s; %s",
                    self.request.connection,
                    getattr(self.request.connection, "__dict__", None),
                )
                raise

        with contextlib.suppress(iostream.StreamClosedError):
            await self.finish()

    def get_absolute_path(self, path: str) -> Traversable:
        """Return the Traversable for ``path`` below the root directory."""
        return self.root / path

    @staticmethod
    def _acceptable_coding(item: str) -> str:
        """Return the lower-cased coding of one Accept-Encoding list item.

        Returns ``""`` if the client refuses the coding with a weight of
        zero (``q=0``), so that it never matches a supported coding.
        """
        coding, *params = item.split(";")
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                continue
            try:
                refused = float(value) == 0
            except ValueError:
                # Malformed weight: ignore it and treat the coding as accepted.
                continue
            if refused:
                return ""
        return coding.strip().lower()

    def get_absolute_path_encoded(
        self, path: str
    ) -> tuple[Traversable, Encoding | None]:
        """Get the absolute path and the encoding.

        Returns the best pre-compressed variant of ``path`` that both exists
        and is accepted by the client, together with its file-name suffix.
        Falls back to the plain file and ``None`` if there is no such variant.
        """
        # Switch off Tornado's on-the-fly gzip for this response --
        # presumably so that pre-compressed files aren't compressed again.
        for transform in self._transforms:
            if isinstance(transform, GZipContentEncoding):
                # pylint: disable=protected-access
                transform._gzipping = False

        accepted_encodings: frozenset[str] = (
            Stream(self.request.headers.get_list("Accept-Encoding"))
            .flat_map(str.split, ",")
            # normalises case and drops codings refused with "q=0"
            .map(self._acceptable_coding)
            .collect(frozenset)
        )

        # ENCODINGS is ordered by preference, so the first hit wins.
        for coding, suffix in ENCODINGS:
            if coding not in accepted_encodings:
                continue
            compressed_path = self.get_absolute_path(f"{path}.{suffix}")
            if compressed_path.is_file():
                return compressed_path, suffix

        return self.get_absolute_path(path), None

    @classmethod
    def get_content(
        cls,
        abspath: Traversable,
        start: int | None = None,
        end: int | None = None,
    ) -> Iterable[bytes]:
        """Read the content of a file in chunks.

        Yields the bytes from ``start`` (default: beginning of the file) up
        to, but not including, ``end`` (default: end of the file) in chunks
        of at most 64 KiB.
        """
        with abspath.open("rb") as file:
            if start is not None:
                file.seek(start)
            # Number of bytes still to be sent; None means "until EOF".
            remaining: int | None = (
                (end - (start or 0)) if end is not None else None
            )

            while True:  # pylint: disable=while-used
                chunk_size = 64 * 1024
                if remaining is not None and remaining < chunk_size:
                    chunk_size = remaining
                chunk = file.read(chunk_size)
                if chunk:
                    if remaining is not None:
                        remaining -= len(chunk)
                    yield chunk
                else:
                    # An empty read ends the loop; by then the requested
                    # range is expected to have been sent completely.
                    assert not remaining
                    return

    @classmethod
    def get_content_type(
        cls, path: str, absolute_path: Traversable
    ) -> str | None:
        """Get the content-type of a file."""
        return content_type_from_path(path, absolute_path)

    def head(self, path: str) -> Awaitable[None]:
        """Handle HEAD requests for files in the static file directory."""
        return self.get(path, head=True)

    def initialize(
        self,
        root: Traversable,
        hashes: Mapping[str, str] = MappingProxyType({}),
        headers: Iterable[tuple[str, str]] = (),
    ) -> None:
        """Initialize this handler with a root directory and file hashes.

        ``headers`` are stored for later use by set_default_headers and are
        also applied to the current response right away.
        """
        self.root = root
        self.file_hashes = hashes
        self.headers = headers
        for name, value in headers:
            self.set_header(name, value)
        if not sys.flags.dev_mode:
            self.set_etag_header()

    def replace_path_with_redirect(
        self, new_path: str, *, status: int = 307
    ) -> None:
        """Redirect to the current URL with its path replaced by ``new_path``.

        Scheme, host and query string are kept; the fragment is dropped.
        """
        scheme, netloc, _, query, _ = urlsplit(self.request.full_url())
        self.redirect(
            urlunsplit(
                (
                    scheme,
                    netloc,
                    new_path,
                    query,
                    "",
                )
            ),
            status=status,
        )

    @override
    def set_default_headers(self) -> None:
        """Set the default headers for this handler."""
        super().set_default_headers()
        for name, value in self.headers:
            self.set_header(name, value)

        if not sys.flags.dev_mode:
            if "v" in self.request.arguments:
                self.set_header(  # never changes
                    "Cache-Control",
                    f"public, immutable, max-age={86400 * 365 * 10}",
                )
            else:
                self.set_etag_header()

    @override
    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Write a plain-text error response: status code and reason."""
        self.set_header("Content-Type", "text/plain;charset=utf-8")
        self.write(str(status_code))
        self.write(" ")
        self.write(self._reason)
        self.finish("\n")