Coverage for an_website / utils / static_file_from_traversable.py: 94.595%
148 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-24 17:35 +0000
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-24 17:35 +0000
1# This file is based on the StaticFileHandler from Tornado
2# Source: https://github.com/tornadoweb/tornado/blob/b3f2a4bb6fb55f6b1b1e890cdd6332665cfe4a75/tornado/web.py # noqa: B950 # pylint: disable=line-too-long
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
15"""A static file handler for the Traversable abc."""
17import contextlib
18import logging
19import sys
20from collections.abc import Awaitable, Iterable, Mapping, Sequence
21from importlib.resources.abc import Traversable
22from types import MappingProxyType
23from typing import Any, Final, Literal, override
24from urllib.parse import urlsplit, urlunsplit
26from tornado import httputil, iostream
27from tornado.web import GZipContentEncoding, HTTPError
28from typed_stream import Stream
30from an_website.utils.utils import size_of_file
32from .base_request_handler import _RequestHandler
33from .static_file_handling import content_type_from_path
35type Encoding = Literal["gz", "zst"]
37LOGGER: Final = logging.getLogger(__name__)
39ENCODINGS: Final[Sequence[tuple[str, Encoding]]] = (
40 # better first
41 ("zstd", "zst"),
42 ("gzip", "gz"),
43)
44REVERSE_ENCODINGS_MAP: Mapping[Encoding, str] = {
45 value: key for key, value in ENCODINGS
46}
class TraversableStaticFileHandler(_RequestHandler):
    """A static file handler for the Traversable abc.

    Files are looked up below ``root``. If the client accepts one of the
    codings listed in ``ENCODINGS`` and a sibling file with the matching
    suffix exists (``<path>.zst`` / ``<path>.gz``), that pre-compressed
    file is sent instead together with a ``Content-Encoding`` header.
    """

    # Directory-like object that all request paths are resolved against.
    root: Traversable
    # Pre-computed ETags, looked up by request path in compute_etag().
    # Read-only default so the shared class-level mapping can't be mutated.
    file_hashes: Mapping[str, str] = MappingProxyType({})
    # Extra (name, value) response headers that are set on every response.
    headers: Iterable[tuple[str, str]] = ()

    @override
    def compute_etag(self) -> None | str:
        """Return the pre-computed ETag for the request path, if any."""
        return self.file_hashes.get(self.request.path)

    async def get(self, path: str, *, head: bool = False) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches, too-many-statements
        """Handle GET requests for files in the static file directory.

        ``path`` is the path of the file relative to ``root``. With
        ``head=True`` only the headers are sent (used by ``head()``).

        Raises ``HTTPError(404)`` if the path looks like a traversal
        attempt or no matching file exists.
        """
        # Paths with a trailing slash are redirected to the slash-less path.
        if self.request.path.endswith("/"):
            self.replace_path_with_redirect(self.request.path.rstrip("/"))
            return

        # Reject absolute paths, parent references and empty segments.
        if path.startswith("/") or ".." in path.split("/") or "//" in path:
            raise HTTPError(404)

        absolute_path, encoding = self.get_absolute_path_encoded(path)

        if not absolute_path.is_file():
            # Maybe only the casing is wrong; if the lower-cased path
            # exists, redirect there instead of answering with 404.
            if self.get_absolute_path(path.lower()).is_file():
                if self.request.path.endswith(path):
                    self.replace_path_with_redirect(
                        self.request.path.removesuffix(path) + path.lower()
                    )
                    return
                LOGGER.error(
                    "Failed to fix casing of %s", self.request.full_url()
                )
            raise HTTPError(404)

        self.set_header("Accept-Ranges", "bytes")

        if encoding:
            self.set_header("Content-Encoding", REVERSE_ENCODINGS_MAP[encoding])
        # The content type is derived from the uncompressed file, not from
        # the pre-compressed variant that may actually be sent.
        if content_type := self.get_content_type(
            path, self.get_absolute_path(path)
        ):
            self.set_header("Content-Type", content_type)
        del path

        request_range = None
        range_header = self.request.headers.get("Range")
        if range_header:
            # As per RFC 2616 14.16, if an invalid Range header is specified,
            # the request will be treated as if the header didn't exist.
            # pylint: disable-next=protected-access
            request_range = httputil._parse_request_range(range_header)

        size: int = size_of_file(absolute_path)

        if request_range:
            start, end = request_range
            if start is not None and start < 0:
                # A negative start counts from the end of the file.
                start += size
                start = max(start, 0)
            if (
                start is not None
                and (start >= size or (end is not None and start >= end))
            ) or end == 0:  # pylint: disable=use-implicit-booleaness-not-comparison-to-zero # noqa: B950
                # As per RFC 2616 14.35.1, a range is not satisfiable only: if
                # the first requested byte is equal to or greater than the
                # content, or when a suffix with length 0 is specified.
                # https://tools.ietf.org/html/rfc7233#section-2.1
                # A byte-range-spec is invalid if the last-byte-pos value is
                # present and less than the first-byte-pos.
                self.set_status(416)  # Range Not Satisfiable
                self.set_header("Content-Type", "text/plain")
                self.set_header("Content-Range", f"bytes */{size}")
                return
            if end is not None and end > size:
                # Clients sometimes blindly use a large range to limit their
                # download size; cap the endpoint at the actual file size.
                end = size
            # Note: only return HTTP 206 if less than the entire range has been
            # requested. Not only is this semantically correct, but Chrome
            # refuses to play audio if it gets an HTTP 206 in response to
            # ``Range: bytes=0-``.
            if size != (end or size) - (start or 0):
                self.set_status(206)  # Partial Content
                self.set_header(
                    "Content-Range",
                    # pylint: disable-next=protected-access
                    httputil._get_content_range(start, end, size),
                )
        else:
            start = end = None

        # Slicing a range object yields the number of bytes in [start, end)
        # without having to special-case None for either bound.
        content_length = len(range(size)[start:end])
        self.set_header("Content-Length", content_length)

        if head:
            assert self.request.method == "HEAD"
            await self.finish()
            return

        for chunk in self.get_content(absolute_path, start=start, end=end):
            self.write(chunk)
            try:
                await self.flush()
            except iostream.StreamClosedError:
                # The client went away; there is nobody left to send to.
                return
            except httputil.HTTPOutputError:
                LOGGER.exception(
                    "Connection %s; %s",
                    self.request.connection,
                    getattr(self.request.connection, "__dict__", None),
                )
                raise

        with contextlib.suppress(iostream.StreamClosedError):
            await self.finish()

    def get_absolute_path(self, path: str) -> Traversable:
        """Return the Traversable for ``path`` below ``root``."""
        return self.root / path

    @staticmethod
    def _accepted_coding(entry: str) -> str:
        """Return the content coding named by one Accept-Encoding entry.

        ``entry`` is a single comma-separated element of the header, e.g.
        ``"gzip"`` or ``" zstd;q=0.5"``. The lower-cased coding name is
        returned. If the entry has a weight (``q``) of zero, which means
        the client refuses the coding, ``""`` is returned instead.
        """
        coding, *params = entry.split(";")
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                continue
            try:
                refused = float(value) <= 0
            except ValueError:
                # Malformed weight: ignore it, like any other parameter.
                refused = False
            if refused:
                return ""
        return coding.strip().lower()

    def get_absolute_path_encoded(
        self, path: str
    ) -> tuple[Traversable, Encoding | None]:
        """Get the absolute path and the encoding.

        Returns the first pre-compressed variant of ``path`` (in the order
        of ``ENCODINGS``) that the client accepts and that exists as a
        file, together with its suffix. Otherwise the plain path and
        ``None`` are returned; the plain path is not checked for existence.
        """
        # The response may already be compressed, so Tornado's own gzip
        # transform is switched off for this request.
        for transform in self._transforms:
            if isinstance(transform, GZipContentEncoding):
                # pylint: disable=protected-access
                transform._gzipping = False

        accepted_encodings: frozenset[str] = (
            Stream(self.request.headers.get_list("Accept-Encoding"))
            .flat_map(str.split, ",")
            # drops codings the client refused with a weight of zero
            .map(self._accepted_coding)
            .collect(frozenset)
        )

        for key, encoding in ENCODINGS:
            if key not in accepted_encodings:
                continue
            compressed_path = self.get_absolute_path(f"{path}.{encoding}")
            if compressed_path.is_file():
                return compressed_path, encoding

        return self.get_absolute_path(path), None

    @classmethod
    def get_content(
        cls,
        abspath: Traversable,
        start: int | None = None,
        end: int | None = None,
    ) -> Iterable[bytes]:
        """Read the content of a file in chunks.

        Yields the bytes from offset ``start`` (default: beginning) up to,
        but not including, offset ``end`` (default: end of file) in chunks
        of at most 64 KiB.
        """
        with abspath.open("rb") as file:
            if start is not None:
                file.seek(start)
            # Number of bytes still to send; None means "until end of file".
            remaining: int | None = (
                (end - (start or 0)) if end is not None else None
            )

            while True:  # pylint: disable=while-used
                chunk_size = 64 * 1024
                if remaining is not None and remaining < chunk_size:
                    chunk_size = remaining
                chunk = file.read(chunk_size)
                if chunk:
                    if remaining is not None:
                        remaining -= len(chunk)
                    yield chunk
                else:
                    # Nothing more to read, so nothing may be left to send.
                    assert not remaining
                    return

    @classmethod
    def get_content_type(
        cls, path: str, absolute_path: Traversable
    ) -> str | None:
        """Get the content-type of a file, or None if there is none."""
        return content_type_from_path(path, absolute_path)

    def head(self, path: str) -> Awaitable[None]:
        """Handle HEAD requests for files in the static file directory."""
        return self.get(path, head=True)

    def initialize(
        self,
        root: Traversable,
        hashes: Mapping[str, str] = MappingProxyType({}),
        headers: Iterable[tuple[str, str]] = (),
    ) -> None:
        """Initialize this handler with a root directory and file hashes.

        ``root`` is the directory files are served from, ``hashes`` maps
        request paths to their ETag and ``headers`` are extra response
        headers. The headers are set right away and, outside of dev mode,
        so is the ETag header.
        """
        self.root = root
        self.file_hashes = hashes
        self.headers = headers
        for name, value in headers:
            self.set_header(name, value)
        if not sys.flags.dev_mode:
            self.set_etag_header()

    def replace_path_with_redirect(
        self, new_path: str, *, status: int = 307
    ) -> None:
        """Redirect to the current URL with its path replaced.

        Scheme, host and query string of the request URL are kept; the
        fragment is dropped.
        """
        scheme, netloc, _, query, _ = urlsplit(self.request.full_url())
        self.redirect(
            urlunsplit(
                (
                    scheme,
                    netloc,
                    new_path,
                    query,
                    "",
                )
            ),
            status=status,
        )

    @override
    def set_default_headers(self) -> None:
        """Set the default headers for this handler.

        Outside of dev mode, requests with a ``v`` argument are marked as
        immutable for ten years; all other requests get the ETag header.
        """
        super().set_default_headers()
        for name, value in self.headers:
            self.set_header(name, value)

        if not sys.flags.dev_mode:
            if "v" in self.request.arguments:
                self.set_header(  # never changes
                    "Cache-Control",
                    f"public,immutable,max-age={86400 * 365 * 10}",
                )
            else:
                self.set_etag_header()

    @override
    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Write a plain-text error response: status code and reason."""
        self.set_header("Content-Type", "text/plain;charset=utf-8")
        self.write(str(status_code))
        self.write(" ")
        self.write(self._reason)
        _ = self.finish("\n")