Coverage for an_website/utils/logging.py: 36.275%
102 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-01 08:32 +0000
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-01 08:32 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""Logging stuff used by the website."""
16from __future__ import annotations
18import asyncio
19import logging
20import traceback
21from asyncio import AbstractEventLoop
22from collections.abc import Coroutine, Iterable
23from concurrent.futures import Future
24from datetime import datetime, tzinfo
25from logging import LogRecord
26from pathlib import Path
27from typing import Never
29import orjson as json
30from tornado.httpclient import AsyncHTTPClient
32from an_website import DIR as AN_WEBSITE_DIR
34from .. import CA_BUNDLE_PATH
# The current user's home directory as a POSIX path without a trailing slash
# (an empty string if the home directory is "/").
HOME: str = Path.home().as_posix().rstrip("/")
def minify_filepath(path: str) -> str:
    """Make a filepath smaller.

    Paths located inside the home directory (i.e. starting with ``HOME``
    followed by a slash) get the home directory replaced by ``~``; any
    other path is returned unchanged.
    """
    home_dir_prefix = f"{HOME}/"
    if not path.startswith(home_dir_prefix):
        return path
    # Keep the slash that follows HOME so the result reads "~/...".
    return f"~{path[len(HOME):]}"
46def get_minimal_traceback(
47 record: LogRecord, prefix: str = "\n\n"
48) -> Iterable[str]:
49 """Get a minimal traceback from the log record."""
50 if not record.exc_info:
51 return
52 (_, value, tb) = record.exc_info
53 if not (value and tb):
54 return
56 yield prefix
57 yield from traceback.format_exception(value, limit=0)
59 summary = traceback.extract_tb(tb)
60 if isinstance(AN_WEBSITE_DIR, Path):
61 start_path = f"{str(AN_WEBSITE_DIR).rstrip('/')}/"
63 for i in reversed(range(len(summary))):
64 if summary[i].filename.startswith(start_path):
65 summary = traceback.StackSummary(summary[i:])
66 break
68 for frame in summary:
69 frame.filename = minify_filepath(frame.filename)
71 yield from summary.format()
74class AsyncHandler(logging.Handler):
75 """A logging handler that can handle log records asynchronously."""
77 futures: set[Future[object]]
78 loop: AbstractEventLoop
80 def __init__(
81 self,
82 level: int | str = logging.NOTSET,
83 *,
84 loop: AbstractEventLoop,
85 ):
86 """Initialize the handler."""
87 super().__init__(level=level)
88 self.futures = set()
89 self.loop = loop
91 def callback(self, future: Future[object]) -> None:
92 """Remove the reference to the future from the handler."""
93 self.acquire()
94 try:
95 self.futures.discard(future)
96 finally:
97 self.release()
99 def emit( # type: ignore[override]
100 self, record: LogRecord
101 ) -> None | Coroutine[None, Never, object]:
102 """
103 Do whatever it takes to actually log the specified logging record.
105 This version is intended to be implemented by subclasses and so
106 raises a NotImplementedError.
107 """
108 raise NotImplementedError(
109 "emit must be implemented by AsyncHandler subclasses"
110 )
112 def handle( # type: ignore[override]
113 self, record: LogRecord
114 ) -> bool | LogRecord:
115 """Handle incoming log records."""
116 rv = self.filter(record)
117 if isinstance(rv, LogRecord):
118 record = rv
119 if rv and not self.loop.is_closed():
120 self.acquire()
121 try:
122 if awaitable := self.emit(record):
123 future: Future[object] = asyncio.run_coroutine_threadsafe(
124 awaitable, self.loop
125 )
126 self.futures.add(future)
127 future.add_done_callback(self.callback)
128 finally:
129 self.release()
130 return rv
133class DatetimeFormatter(logging.Formatter):
134 """A logging formatter that formats the time using datetime."""
136 timezone: None | tzinfo = None
138 def formatTime( # noqa: N802
139 self, record: LogRecord, datefmt: None | str = None
140 ) -> str:
141 """Return the creation time of the LogRecord as formatted text."""
142 spam = datetime.fromtimestamp(record.created).astimezone(self.timezone)
143 if datefmt:
144 return spam.strftime(datefmt)
145 return spam.isoformat()
class WebhookFormatter(DatetimeFormatter):
    """A logging formatter optimized for logging to a webhook.

    Before the format string is applied, ``record.message`` is set to the
    log message followed by a minimal traceback (if the record carries an
    exception). Two class attributes tune this:

    * ``max_message_length``: if not ``None``, the message is cut to this
      length, and traceback lines that would overflow it are dropped (with
      ``"..."`` appended when that still fits).
    * ``escape_message``: if true, the message is JSON-escaped (without
      the surrounding quotes) so it can be embedded in a JSON string.
    """

    escape_message = False
    max_message_length: int | None = None

    def format(self, record: LogRecord) -> str:
        """Format the specified record as text."""
        record.message = record.getMessage()
        if self.usesTime():
            record.asctime = self.formatTime(record, self.datefmt)

        limit = self.max_message_length
        text = record.message
        if limit is None:
            text += "".join(get_minimal_traceback(record))
        else:
            text = text[:limit]
            for chunk in get_minimal_traceback(record):
                if len(text) + len(chunk) <= limit:
                    text += chunk
                    continue
                # The next traceback line doesn't fit; mark the cut if the
                # ellipsis itself still fits, then stop.
                if len(text) + len("...") <= limit:
                    text += "..."
                break

        if self.escape_message:
            # Serialize as a JSON string and strip the enclosing quotes.
            text = json.dumps(text).decode("UTF-8")[1:-1]
        record.message = text
        return self.formatMessage(record)
class WebhookHandler(AsyncHandler):
    """A logging handler that sends logs to a webhook.

    Each record is formatted with the handler's formatter and POSTed, with
    surrounding whitespace stripped, to ``url`` via Tornado's
    ``AsyncHTTPClient``. Any exception is reported through ``handleError``.
    """

    url: str  # the webhook endpoint records are POSTed to
    content_type: str  # value sent as the Content-Type request header

    def __init__(
        self,
        level: int | str = logging.NOTSET,
        *,
        loop: AbstractEventLoop,
        url: str,
        content_type: str,
    ):
        """Initialize the handler.

        :param level: The minimum level of records this handler handles.
        :param loop: Passed on to ``AsyncHandler``.
        :param url: The webhook URL to POST to.
        :param content_type: The Content-Type header of the requests.
        """
        super().__init__(level=level, loop=loop)
        self.url, self.content_type = url, content_type

    async def emit(self, record: LogRecord) -> None:  # type: ignore[override]
        """Send the request to the webhook."""
        # pylint: disable=invalid-overridden-method
        try:
            payload = self.format(record).strip()
            client = AsyncHTTPClient()
            await client.fetch(
                self.url,
                method="POST",
                headers={"Content-Type": self.content_type},
                body=payload,
                ca_certs=CA_BUNDLE_PATH,
            )
        except Exception:  # pylint: disable=broad-except
            self.handleError(record)