Coverage for an_website/utils/static_file_from_traversable.py: 80.576%
139 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
1# This file is based on the StaticFileHandler from Tornado
2# Source: https://github.com/tornadoweb/tornado/blob/b3f2a4bb6fb55f6b1b1e890cdd6332665cfe4a75/tornado/web.py # noqa: B950 # pylint: disable=line-too-long
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
15"""A static file handler for the Traversable abc."""
16from __future__ import annotations
18import contextlib
19import logging
20import sys
21from collections.abc import Awaitable, Iterable, Mapping, Sequence
22from importlib.resources.abc import Traversable
23from typing import Any, Final, Literal, override
24from urllib.parse import urlsplit, urlunsplit
26from tornado import httputil, iostream
27from tornado.web import GZipContentEncoding, HTTPError
28from typed_stream import Stream
30from an_website.utils.utils import size_of_file
32from .base_request_handler import _RequestHandler
33from .static_file_handling import content_type_from_path
35type Encoding = Literal["gz", "zst"]
37LOGGER: Final = logging.getLogger(__name__)
39ENCODINGS: Final[Sequence[tuple[str, Encoding]]] = (
40 # better first
41 ("zstd", "zst"),
42 ("gzip", "gz"),
43)
44REVERSE_ENCODINGS_MAP: Mapping[Encoding, str] = {
45 value: key for key, value in ENCODINGS
46}
class TraversableStaticFileHandler(_RequestHandler):
    """Serve static files that live below a ``Traversable`` root.

    For clients that accept them, a pre-compressed sibling of the requested
    file (``<path>.zst`` or ``<path>.gz``) is sent instead of the file
    itself.  ``Range`` requests are answered and ETags are taken from a
    pre-computed mapping instead of being calculated per request.
    """

    # Directory that request paths are resolved against; set by initialize().
    root: Traversable
    # Pre-computed ETag values keyed by request path; set by initialize().
    file_hashes: Mapping[str, str]

    @override
    def compute_etag(self) -> None | str:
        """Return the pre-computed ETag of the requested path.

        Returns:
            The value stored in ``file_hashes`` for ``self.request.path``,
            or ``None`` if there is none or ``file_hashes`` is not set yet.
        """
        # getattr() tolerates a missing attribute: set_default_headers() asks
        # for the ETag too and may run before initialize() -- TODO confirm.
        return getattr(self, "file_hashes", {}).get(self.request.path)

    async def get(self, path: str, *, head: bool = False) -> None:  # noqa: C901
        # pylint: disable=too-complex, too-many-branches
        """Handle GET requests for files in the static file directory.

        Args:
            path: Path of the requested file, relative to ``root``.
            head: Send only the headers; used by ``head()``.

        Raises:
            HTTPError: 404 if ``path`` is not acceptable (absolute, contains
                a ``..`` segment or an empty segment) or names no file.
        """
        if self.request.path.endswith("/"):
            # Files have no trailing slash; redirect to the path without it.
            self.replace_path_with_redirect(self.request.path.rstrip("/"))
            return

        if path.startswith("/") or ".." in path.split("/") or "//" in path:
            raise HTTPError(404)

        absolute_path, encoding = self.get_absolute_path_encoded(path)

        if not absolute_path.is_file():
            if self.get_absolute_path(path.lower()).is_file():
                # The file exists under its lower-cased name: redirect there
                # if the request path can be rewritten by a suffix swap.
                if self.request.path.endswith(path):
                    self.replace_path_with_redirect(
                        self.request.path.removesuffix(path) + path.lower()
                    )
                    return
                LOGGER.error(
                    "Failed to fix casing of %s", self.request.full_url()
                )
            raise HTTPError(404)

        self.set_header("Accept-Ranges", "bytes")

        if encoding:
            self.set_header("Content-Encoding", REVERSE_ENCODINGS_MAP[encoding])
        # The content type is derived from the uncompressed file, even when
        # a pre-compressed variant is the one being sent.
        if content_type := self.get_content_type(
            path, self.get_absolute_path(path)
        ):
            self.set_header("Content-Type", content_type)
        # From here on only absolute_path (possibly the compressed variant)
        # may be used.
        del path

        request_range = None
        range_header = self.request.headers.get("Range")
        if range_header:
            # As per RFC 2616 14.16, if an invalid Range header is specified,
            # the request will be treated as if the header didn't exist.
            # pylint: disable-next=protected-access
            request_range = httputil._parse_request_range(range_header)

        size: int = size_of_file(absolute_path)

        if request_range:
            start, end = request_range
            if start is not None and start < 0:
                # A negative start counts from the end of the file.
                start += size
                start = max(start, 0)
            if (
                start is not None
                and (start >= size or (end is not None and start >= end))
            ) or end == 0:  # pylint: disable=use-implicit-booleaness-not-comparison-to-zero # noqa: B950
                # As per RFC 2616 14.35.1, a range is not satisfiable only: if
                # the first requested byte is equal to or greater than the
                # content, or when a suffix with length 0 is specified.
                # https://tools.ietf.org/html/rfc7233#section-2.1
                # A byte-range-spec is invalid if the last-byte-pos value is
                # present and less than the first-byte-pos.
                self.set_status(416)  # Range Not Satisfiable
                # No file content is written for this response, so the
                # Content-Encoding set above for a pre-compressed variant
                # must not be sent along with it.
                self.clear_header("Content-Encoding")
                self.set_header("Content-Type", "text/plain")
                self.set_header("Content-Range", f"bytes */{size}")
                return
            if end is not None and end > size:
                # Clients sometimes blindly use a large range to limit their
                # download size; cap the endpoint at the actual file size.
                end = size
            # Note: only return HTTP 206 if less than the entire range has been
            # requested. Not only is this semantically correct, but Chrome
            # refuses to play audio if it gets an HTTP 206 in response to
            # ``Range: bytes=0-``.
            if size != (end or size) - (start or 0):
                self.set_status(206)  # Partial Content
                self.set_header(
                    "Content-Range",
                    # pylint: disable-next=protected-access
                    httputil._get_content_range(start, end, size),
                )
        else:
            start = end = None

        # Slicing a range object yields the number of selected bytes without
        # spelling out every start/end combination.
        content_length = len(range(size)[start:end])
        self.set_header("Content-Length", content_length)

        if head:
            assert self.request.method == "HEAD"
            await self.finish()
            return

        for chunk in self.get_content(absolute_path, start=start, end=end):
            self.write(chunk)
            try:
                await self.flush()
            except iostream.StreamClosedError:
                # The client went away; there is nobody left to send to.
                return

        with contextlib.suppress(iostream.StreamClosedError):
            await self.finish()

    def get_absolute_path(self, path: str) -> Traversable:
        """Return the ``Traversable`` for ``path`` resolved against ``root``."""
        return self.root / path

    @staticmethod
    def _acceptable_coding(item: str) -> str:
        """Return the content coding named by one ``Accept-Encoding`` item.

        Args:
            item: One comma-separated element of the header value, for
                example ``"gzip"`` or ``" zstd;q=0.5"``.

        Returns:
            The lower-cased coding name, or ``""`` if the item carries a
            quality value of zero (or less), i.e. the client refuses it.
        """
        coding, *params = item.split(";")
        for param in params:
            name, _, value = param.partition("=")
            if name.strip().lower() != "q":
                continue
            try:
                weight = float(value)
            except ValueError:
                # Malformed weight: treat the coding as acceptable.
                continue
            if weight <= 0:
                return ""
        # Content coding names are case-insensitive.
        return coding.strip().lower()

    def get_absolute_path_encoded(
        self, path: str
    ) -> tuple[Traversable, Encoding | None]:
        """Get the absolute path and the encoding.

        Picks the first entry of ``ENCODINGS`` that the client accepts and
        for which a ``<path>.<suffix>`` file exists.

        Args:
            path: Path of the requested file, relative to ``root``.

        Returns:
            ``(compressed file, its suffix)`` if a usable pre-compressed
            variant was found, otherwise ``(uncompressed path, None)``.  The
            uncompressed path is returned without checking that it exists.
        """
        # Switch off on-the-fly gzip compression for this response; a
        # pre-compressed variant may be chosen below instead.
        for transform in self._transforms:
            if isinstance(transform, GZipContentEncoding):
                # pylint: disable=protected-access
                transform._gzipping = False

        accepted_encodings: frozenset[str] = (
            Stream(self.request.headers.get_list("Accept-Encoding"))
            .flat_map(str.split, ",")
            # refused codings (q=0) become "", which matches no known coding
            .map(self._acceptable_coding)
            .collect(frozenset)
        )

        absolute_path = self.get_absolute_path(path)
        encoding: Encoding | None = None

        for key, encoding in ENCODINGS:
            if key in accepted_encodings:
                compressed_path = self.get_absolute_path(f"{path}.{encoding}")
                if compressed_path.is_file():
                    absolute_path = compressed_path
                    break
            # This candidate is unusable; forget it so that None is returned
            # when the loop ends without a match.
            encoding = None  # pylint: disable=redefined-loop-name

        return absolute_path, encoding

    @classmethod
    def get_content(
        cls,
        abspath: Traversable,
        start: int | None = None,
        end: int | None = None,
    ) -> Iterable[bytes]:
        """Read the content of a file in chunks.

        Args:
            abspath: The file to read.
            start: Offset of the first byte to yield; ``None`` means the
                beginning of the file.
            end: Offset just past the last byte to yield; ``None`` means
                read until the end of the file.

        Yields:
            Chunks of at most 64 KiB.
        """
        with abspath.open("rb") as file:
            if start is not None:
                file.seek(start)
            # Number of bytes still to be yielded; None means "until EOF".
            remaining: int | None = (
                (end - (start or 0)) if end is not None else None
            )

            while True:  # pylint: disable=while-used
                chunk_size = 64 * 1024
                if remaining is not None and remaining < chunk_size:
                    chunk_size = remaining
                chunk = file.read(chunk_size)
                if chunk:
                    if remaining is not None:
                        remaining -= len(chunk)
                    yield chunk
                else:
                    # Nothing more was read: either the requested amount has
                    # been delivered or the file ended; in the latter case no
                    # bytes may be outstanding.
                    assert not remaining
                    return

    @classmethod
    def get_content_type(
        cls, path: str, absolute_path: Traversable
    ) -> str | None:
        """Get the content-type of a file.

        Delegates to ``content_type_from_path``.
        """
        return content_type_from_path(path, absolute_path)

    def head(self, path: str) -> Awaitable[None]:
        """Handle HEAD requests for files in the static file directory.

        Runs ``get()`` with ``head=True`` so that only the headers are sent.
        """
        return self.get(path, head=True)

    def initialize(self, root: Traversable, hashes: Mapping[str, str]) -> None:
        """Initialize this handler with a root directory and file hashes.

        Args:
            root: Directory that request paths are resolved against.
            hashes: Pre-computed ETag values keyed by request path.
        """
        self.root = root
        self.file_hashes = hashes
        if not sys.flags.dev_mode:
            # The hashes are available from here on, so the ETag header can
            # be set; this is skipped in Python's development mode.
            self.set_etag_header()

    def replace_path_with_redirect(
        self, new_path: str, *, status: int = 307
    ) -> None:
        """Redirect to the current URL with its path replaced.

        Scheme, host and query string of the request URL are kept; the
        fragment is dropped.

        Args:
            new_path: The path to redirect to.
            status: The HTTP status code of the redirect.
        """
        scheme, netloc, _, query, _ = urlsplit(self.request.full_url())
        self.redirect(
            urlunsplit(
                (
                    scheme,
                    netloc,
                    new_path,
                    query,
                    "",
                )
            ),
            status=status,
        )

    @override
    def set_default_headers(self) -> None:
        """Set the default headers for this handler.

        Outside of Python's development mode, requests carrying a ``v``
        argument get a long-lived immutable ``Cache-Control`` header; all
        other requests get the ETag header.
        """
        super().set_default_headers()
        if not sys.flags.dev_mode:
            if "v" in self.request.arguments:
                self.set_header(  # never changes
                    "Cache-Control",
                    f"public, immutable, max-age={86400 * 365 * 10}",
                )
            else:
                self.set_etag_header()

    @override
    def write_error(self, status_code: int, **kwargs: Any) -> None:
        """Write a plain-text error response: ``<status code> <reason>``.

        Args:
            status_code: The HTTP status code of the error.
            **kwargs: Accepted for compatibility with the overridden method;
                not used.
        """
        self.set_header("Content-Type", "text/plain;charset=utf-8")
        self.write(str(status_code))
        self.write(" ")
        self.write(self._reason)