Coverage for an_website / utils / logging.py: 32.673%

101 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-19 18:33 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""Logging stuff used by the website.""" 

15 

16import asyncio 

17import logging 

18import traceback 

19from asyncio import AbstractEventLoop 

20from collections.abc import Callable, Coroutine, Iterable 

21from concurrent.futures import Future 

22from datetime import datetime, tzinfo 

23from logging import LogRecord 

24from pathlib import Path 

25from typing import Never 

26 

27import orjson as json 

28from tornado.httpclient import AsyncHTTPClient 

29 

30from an_website import DIR as AN_WEBSITE_DIR 

31 

32from .. import CA_BUNDLE_PATH 

33 

34HOME: str = Path("~/").expanduser().as_posix().rstrip("/") 

35 

36 

37def minify_filepath(path: str) -> str: 

38 """Make a filepath smaller.""" 

39 if path.startswith(f"{HOME}/"): 

40 return "~" + path.removeprefix(HOME) 

41 return path 

42 

43 

def get_minimal_traceback(
    record: LogRecord, prefix: str = "\n\n"
) -> Iterable[str]:
    """Yield a compact traceback for the exception attached to ``record``.

    Nothing is yielded when the record carries no exception info, or when
    the exception value or traceback is missing. Otherwise ``prefix`` is
    yielded first, then the exception line(s) without any stack entries,
    then the formatted stack frames. When ``AN_WEBSITE_DIR`` is a Path, the
    stack starts at the innermost frame whose file lies inside that
    directory. File names are shortened with ``minify_filepath``.
    """
    exc_info = record.exc_info
    if not exc_info:
        return
    _, exc_value, exc_tb = exc_info
    if not exc_value or not exc_tb:
        return

    yield prefix
    # limit=0: only the "Type: message" part(s), no stack entries.
    yield from traceback.format_exception(exc_value, limit=0)

    frames = traceback.extract_tb(exc_tb)
    if isinstance(AN_WEBSITE_DIR, Path):
        package_prefix = f"{str(AN_WEBSITE_DIR).rstrip('/')}/"
        # Index of the innermost frame from inside the package, if any.
        innermost = next(
            (
                index
                for index in range(len(frames) - 1, -1, -1)
                if frames[index].filename.startswith(package_prefix)
            ),
            None,
        )
        if innermost is not None:
            # Keep only that frame and the frames it called.
            frames = traceback.StackSummary(frames[innermost:])

    for frame in frames:
        frame.filename = minify_filepath(frame.filename)

    yield from frames.format()

70 

71 

72class AsyncHandler(logging.Handler): 

73 """A logging handler that can handle log records asynchronously.""" 

74 

75 futures: set[Future[object]] 

76 loop: AbstractEventLoop 

77 

78 def __init__( 

79 self, 

80 level: int | str = logging.NOTSET, 

81 *, 

82 loop: AbstractEventLoop, 

83 ): 

84 """Initialize the handler.""" 

85 super().__init__(level=level) 

86 self.futures = set() 

87 self.loop = loop 

88 

89 def callback(self, future: Future[object]) -> None: 

90 """Remove the reference to the future from the handler.""" 

91 self.acquire() 

92 try: 

93 self.futures.discard(future) 

94 finally: 

95 self.release() 

96 

97 def emit( # type: ignore[override] 

98 self, record: LogRecord 

99 ) -> None | Coroutine[None, Never, object]: 

100 """ 

101 Do whatever it takes to actually log the specified logging record. 

102 

103 This version is intended to be implemented by subclasses and so 

104 raises a NotImplementedError. 

105 """ 

106 raise NotImplementedError( 

107 "emit must be implemented by AsyncHandler subclasses" 

108 ) 

109 

110 def handle( # type: ignore[override] 

111 self, record: LogRecord 

112 ) -> bool | LogRecord: 

113 """Handle incoming log records.""" 

114 rv = self.filter(record) 

115 if isinstance(rv, LogRecord): 

116 record = rv 

117 if rv and not self.loop.is_closed(): 

118 self.acquire() 

119 try: 

120 if awaitable := self.emit(record): 

121 future: Future[object] = asyncio.run_coroutine_threadsafe( 

122 awaitable, self.loop 

123 ) 

124 self.futures.add(future) 

125 future.add_done_callback(self.callback) 

126 finally: 

127 self.release() 

128 return rv 

129 

130 

131class DatetimeFormatter(logging.Formatter): 

132 """A logging formatter that formats the time using datetime.""" 

133 

134 timezone: None | tzinfo = None 

135 

136 def formatTime( # noqa: N802 

137 self, record: LogRecord, datefmt: None | str = None 

138 ) -> str: 

139 """Return the creation time of the LogRecord as formatted text.""" 

140 spam = datetime.fromtimestamp(record.created).astimezone(self.timezone) 

141 if datefmt: 

142 return spam.strftime(datefmt) 

143 return spam.isoformat() 

144 

145 

class WebhookFormatter(DatetimeFormatter):
    """A logging formatter optimized for logging to a webhook.

    The message consists of the record's message, a minimal traceback (if
    the record has exception info) and an optional context line. When
    ``max_message_length`` is set, all three are kept within that many
    characters.
    """

    # If true, JSON-escape the message (without the surrounding quotes).
    escape_message = False
    # Maximum length of the (unescaped) message; None means unlimited.
    max_message_length: int | None = None
    # Optional hook returning a line to append after the message.
    get_context_line: Callable[[LogRecord], str | None] | None = None

    def format(self, record: LogRecord) -> str:
        """Format the specified record as text.

        Sets ``record.message`` (and ``record.asctime`` if the format uses
        it), as :meth:`logging.Formatter.format` does. However, the usual
        exception text is replaced with a minimal traceback, and
        ``max_message_length`` is enforced.

        Traceback lines that don't fit are dropped, marked by an ellipsis
        if there is room for one. The context line is only added if it fits.
        """
        limit = self.max_message_length
        record.message = record.getMessage()
        if self.usesTime():
            record.asctime = self.formatTime(record, self.datefmt)
        if limit is not None and len(record.message) > limit:
            record.message = record.message[:limit]
        for line in get_minimal_traceback(record):
            if limit is not None and len(line) + len(record.message) > limit:
                ellipsis = "…"
                if len(ellipsis) + len(record.message) <= limit:
                    record.message += ellipsis
                break
            record.message += line
        if self.get_context_line and (
            context_line := self.get_context_line(record)
        ):
            # "+ 2" accounts for the "\n\n" separator. The limit is compared
            # against None (not by truthiness), so a limit of 0 is enforced
            # here too, consistent with the checks above.
            if (
                limit is None
                or len(record.message) + 2 + len(context_line) <= limit
            ):
                record.message += f"\n\n{context_line}"
        if self.escape_message:
            # Escaping (e.g. newlines -> "\\n") can make the message longer
            # than max_message_length; the limit applies before escaping.
            record.message = json.dumps(record.message).decode("UTF-8")[1:-1]
        return self.formatMessage(record)

190 

191 

class WebhookHandler(AsyncHandler):
    """Logging handler that POSTs each formatted record to a webhook URL."""

    url: str  # target of the POST request
    content_type: str  # sent as the request's Content-Type header

    def __init__(
        self,
        level: int | str = logging.NOTSET,
        *,
        loop: AbstractEventLoop,
        url: str,
        content_type: str,
    ):
        """Set up the handler.

        Args:
            level: Minimum level of records this handler processes.
            loop: Event loop on which the POST coroutines are scheduled.
            url: Webhook endpoint the records are sent to.
            content_type: Content-Type header value for the request body.
        """
        super().__init__(level=level, loop=loop)
        self.url, self.content_type = url, content_type

    async def emit(self, record: LogRecord) -> None:  # type: ignore[override]
        """POST the formatted, whitespace-stripped record to ``url``.

        Any exception raised while formatting or sending is reported
        through :meth:`logging.Handler.handleError` instead of propagating.
        """
        # pylint: disable=invalid-overridden-method
        try:
            payload = self.format(record).strip()
            await AsyncHTTPClient().fetch(
                self.url,
                method="POST",
                headers={"Content-Type": self.content_type},
                body=payload,
                ca_certs=CA_BUNDLE_PATH,
            )
        except Exception:  # pylint: disable=broad-except
            self.handleError(record)