Coverage for an_website/utils/logging.py: 36.275%

102 statements  

« prev     ^ index     » next       coverage.py v7.8.2, created at 2025-06-01 08:32 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""Logging stuff used by the website.""" 

15 

16from __future__ import annotations 

17 

18import asyncio 

19import logging 

20import traceback 

21from asyncio import AbstractEventLoop 

22from collections.abc import Coroutine, Iterable 

23from concurrent.futures import Future 

24from datetime import datetime, tzinfo 

25from logging import LogRecord 

26from pathlib import Path 

27from typing import Never 

28 

29import orjson as json 

30from tornado.httpclient import AsyncHTTPClient 

31 

32from an_website import DIR as AN_WEBSITE_DIR 

33 

34from .. import CA_BUNDLE_PATH 

35 

# Home directory of the current user as a POSIX path with any trailing slash
# removed; minify_filepath() abbreviates paths below it to "~/...".
# Note: a home directory of "/" yields the empty string here.
HOME: str = Path("~").expanduser().as_posix().rstrip("/")

37 

38 

def minify_filepath(path: str) -> str:
    """
    Make a filepath smaller by abbreviating the home directory as ``~``.

    Paths outside the home directory are returned unchanged. If the home
    directory is the filesystem root, ``HOME`` is the empty string and
    nothing is abbreviated. Otherwise every absolute path would match the
    ``"/"`` prefix and be turned into a misleading ``~/...`` path.
    """
    # An empty HOME (home directory "/") must not be treated as a prefix.
    if HOME and path.startswith(f"{HOME}/"):
        return "~" + path.removeprefix(HOME)
    return path

44 

45 

def get_minimal_traceback(
    record: LogRecord, prefix: str = "\n\n"
) -> Iterable[str]:
    """
    Get a minimal traceback from the log record.

    Yields nothing unless the record carries exception info with both an
    exception value and a traceback. Otherwise yields, in order:

    - ``prefix``;
    - the formatted exception (``traceback.format_exception`` with
      ``limit=0``, i.e. without stack frames);
    - the formatted stack frames.

    When ``AN_WEBSITE_DIR`` is a ``Path``, the frames start at the innermost
    frame whose file lies below that directory, if there is one. File paths
    are shortened with ``minify_filepath``.
    """
    exc_info = record.exc_info
    if not exc_info:
        return
    _, exc_value, exc_tb = exc_info
    if not exc_value or not exc_tb:
        return

    yield prefix
    yield from traceback.format_exception(exc_value, limit=0)

    frames = traceback.extract_tb(exc_tb)
    if isinstance(AN_WEBSITE_DIR, Path):
        project_prefix = str(AN_WEBSITE_DIR).rstrip("/") + "/"
        # Search from the innermost frame outwards for the first frame that
        # belongs to the project and drop everything above it.
        innermost_own = next(
            (
                index
                for index in range(len(frames) - 1, -1, -1)
                if frames[index].filename.startswith(project_prefix)
            ),
            None,
        )
        if innermost_own is not None:
            frames = traceback.StackSummary(frames[innermost_own:])

    for frame in frames:
        frame.filename = minify_filepath(frame.filename)

    yield from frames.format()

72 

73 

74class AsyncHandler(logging.Handler): 

75 """A logging handler that can handle log records asynchronously.""" 

76 

77 futures: set[Future[object]] 

78 loop: AbstractEventLoop 

79 

80 def __init__( 

81 self, 

82 level: int | str = logging.NOTSET, 

83 *, 

84 loop: AbstractEventLoop, 

85 ): 

86 """Initialize the handler.""" 

87 super().__init__(level=level) 

88 self.futures = set() 

89 self.loop = loop 

90 

91 def callback(self, future: Future[object]) -> None: 

92 """Remove the reference to the future from the handler.""" 

93 self.acquire() 

94 try: 

95 self.futures.discard(future) 

96 finally: 

97 self.release() 

98 

99 def emit( # type: ignore[override] 

100 self, record: LogRecord 

101 ) -> None | Coroutine[None, Never, object]: 

102 """ 

103 Do whatever it takes to actually log the specified logging record. 

104 

105 This version is intended to be implemented by subclasses and so 

106 raises a NotImplementedError. 

107 """ 

108 raise NotImplementedError( 

109 "emit must be implemented by AsyncHandler subclasses" 

110 ) 

111 

112 def handle( # type: ignore[override] 

113 self, record: LogRecord 

114 ) -> bool | LogRecord: 

115 """Handle incoming log records.""" 

116 rv = self.filter(record) 

117 if isinstance(rv, LogRecord): 

118 record = rv 

119 if rv and not self.loop.is_closed(): 

120 self.acquire() 

121 try: 

122 if awaitable := self.emit(record): 

123 future: Future[object] = asyncio.run_coroutine_threadsafe( 

124 awaitable, self.loop 

125 ) 

126 self.futures.add(future) 

127 future.add_done_callback(self.callback) 

128 finally: 

129 self.release() 

130 return rv 

131 

132 

133class DatetimeFormatter(logging.Formatter): 

134 """A logging formatter that formats the time using datetime.""" 

135 

136 timezone: None | tzinfo = None 

137 

138 def formatTime( # noqa: N802 

139 self, record: LogRecord, datefmt: None | str = None 

140 ) -> str: 

141 """Return the creation time of the LogRecord as formatted text.""" 

142 spam = datetime.fromtimestamp(record.created).astimezone(self.timezone) 

143 if datefmt: 

144 return spam.strftime(datefmt) 

145 return spam.isoformat() 

146 

147 

class WebhookFormatter(DatetimeFormatter):
    """
    A logging formatter optimized for logging to a webhook.

    The message is cut to ``max_message_length`` characters (if set). Then as
    much of a minimal traceback as fits is appended, with ``"..."`` marking
    an omitted remainder when there is room for it. With ``escape_message``
    set, the result is JSON-escaped without the enclosing quotes, so it can
    be placed inside a JSON string in the format.
    """

    # Whether to JSON-escape the message before it is put into the format.
    escape_message = False
    # Maximum length of the (unescaped) message; None means unlimited.
    max_message_length: int | None = None

    def format(self, record: LogRecord) -> str:
        """
        Format the specified record as text.

        Like ``logging.Formatter.format``, this stores the resulting message
        (and ``asctime``, if the format uses the time) on the record itself.
        """
        record.message = record.getMessage()
        if self.usesTime():
            record.asctime = self.formatTime(record, self.datefmt)

        limit = self.max_message_length
        text = record.message
        if limit is not None:
            text = text[:limit]

        def fits(extra: str) -> bool:
            """Check whether ``extra`` can be appended within the limit."""
            return limit is None or len(text) + len(extra) <= limit

        for chunk in get_minimal_traceback(record):
            if not fits(chunk):
                # Mark the cut-off traceback, but only if the marker fits.
                if fits("..."):
                    text += "..."
                break
            text += chunk

        if self.escape_message:
            # orjson produces bytes; strip the surrounding double quotes.
            text = json.dumps(text).decode("UTF-8")[1:-1]
        record.message = text
        return self.formatMessage(record)

176 

177 

class WebhookHandler(AsyncHandler):
    """
    A logging handler that sends logs to a webhook.

    Each record is formatted with the handler's formatter. The text, stripped
    of surrounding whitespace, is POSTed as the request body to ``url``.
    """

    url: str  # endpoint that receives the POST requests
    content_type: str  # value sent in the Content-Type request header

    def __init__(
        self,
        level: int | str = logging.NOTSET,
        *,
        loop: AbstractEventLoop,
        url: str,
        content_type: str,
    ):
        """
        Initialize the handler.

        Args:
            level: The handler level, as accepted by ``logging.Handler``.
            loop: The event loop on which the webhook requests are run.
            url: The webhook URL to POST to.
            content_type: The Content-Type header of the requests.
        """
        super().__init__(level=level, loop=loop)
        self.url, self.content_type = url, content_type

    async def emit(self, record: LogRecord) -> None:  # type: ignore[override]
        """
        Send the request to the webhook.

        Exceptions raised while formatting or sending are passed to
        ``handleError`` instead of propagating.
        """
        # pylint: disable=invalid-overridden-method
        try:
            payload = self.format(record).strip()
            await AsyncHTTPClient().fetch(
                self.url,
                body=payload,
                method="POST",
                ca_certs=CA_BUNDLE_PATH,
                headers={"Content-Type": self.content_type},
            )
        except Exception:  # pylint: disable=broad-except
            self.handleError(record)