Coverage for an_website/utils/logging.py: 35.644%

101 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 19:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""Logging stuff used by the website.""" 

15 

16from __future__ import annotations 

17 

18import asyncio 

19import logging 

20import traceback 

21from asyncio import AbstractEventLoop 

22from collections.abc import Awaitable, Iterable 

23from concurrent.futures import Future 

24from datetime import datetime, tzinfo 

25from logging import LogRecord 

26from pathlib import Path 

27 

28import orjson as json 

29from tornado.httpclient import AsyncHTTPClient 

30 

31from an_website import DIR as AN_WEBSITE_DIR 

32 

33from .. import CA_BUNDLE_PATH 

34 

# The current user's home directory in POSIX form, without a trailing slash.
HOME: str = Path("~/").expanduser().as_posix().rstrip("/")


def minify_filepath(path: str) -> str:
    """
    Shorten a file path by abbreviating the home directory as ``~``.

    Paths that are not located below ``HOME`` are returned unchanged.
    """
    home_with_slash = f"{HOME}/"
    if not path.startswith(home_with_slash):
        return path
    # Keep the slash that follows the home directory.
    return "~" + path[len(HOME) :]

43 

44 

def get_minimal_traceback(
    record: LogRecord, prefix: str = "\n\n"
) -> Iterable[str]:
    """
    Yield the pieces of a shortened traceback for the log record.

    Nothing is yielded when the record carries no exception or traceback.
    Otherwise the prefix comes first, then the formatted exception without
    any stack entries, then the formatted stack frames.  When
    ``AN_WEBSITE_DIR`` is a ``Path``, the frames are cut down to start at
    the last frame whose file lives below that directory.  File names are
    shortened with ``minify_filepath``.
    """
    if not record.exc_info:
        return
    _exc_type, exception, trace = record.exc_info
    if not exception or not trace:
        return

    yield prefix
    yield from traceback.format_exception(exception, limit=0)

    frames = traceback.extract_tb(trace)
    if isinstance(AN_WEBSITE_DIR, Path):
        start_path = f"{str(AN_WEBSITE_DIR).rstrip('/')}/"

        # Search from the innermost frame outwards and stop at the first hit.
        first_kept = next(
            (
                index
                for index in range(len(frames) - 1, -1, -1)
                if frames[index].filename.startswith(start_path)
            ),
            None,
        )
        if first_kept is not None:
            frames = traceback.StackSummary(frames[first_kept:])

    for frame in frames:
        frame.filename = minify_filepath(frame.filename)

    yield from frames.format()

71 

72 

73class AsyncHandler(logging.Handler): 

74 """A logging handler that can handle log records asynchronously.""" 

75 

76 futures: set[Future[object]] 

77 loop: AbstractEventLoop 

78 

79 def __init__( 

80 self, 

81 level: int | str = logging.NOTSET, 

82 *, 

83 loop: AbstractEventLoop, 

84 ): 

85 """Initialize the handler.""" 

86 super().__init__(level=level) 

87 self.futures = set() 

88 self.loop = loop 

89 

90 def callback(self, future: Future[object]) -> None: 

91 """Remove the reference to the future from the handler.""" 

92 self.acquire() 

93 try: 

94 self.futures.discard(future) 

95 finally: 

96 self.release() 

97 

98 def emit( # type: ignore[override] 

99 self, record: LogRecord 

100 ) -> None | Awaitable[object]: 

101 """ 

102 Do whatever it takes to actually log the specified logging record. 

103 

104 This version is intended to be implemented by subclasses and so 

105 raises a NotImplementedError. 

106 """ 

107 raise NotImplementedError( 

108 "emit must be implemented by AsyncHandler subclasses" 

109 ) 

110 

111 def handle( # type: ignore[override] 

112 self, record: LogRecord 

113 ) -> bool | LogRecord: 

114 """Handle incoming log records.""" 

115 rv = self.filter(record) 

116 if isinstance(rv, LogRecord): 

117 record = rv 

118 if rv and not self.loop.is_closed(): 

119 self.acquire() 

120 try: 

121 if awaitable := self.emit(record): 

122 future = asyncio.run_coroutine_threadsafe( 

123 awaitable, self.loop 

124 ) 

125 self.futures.add(future) 

126 future.add_done_callback(self.callback) 

127 finally: 

128 self.release() 

129 return rv 

130 

131 

132class DatetimeFormatter(logging.Formatter): 

133 """A logging formatter that formats the time using datetime.""" 

134 

135 timezone: None | tzinfo = None 

136 

137 def formatTime( # noqa: N802 

138 self, record: LogRecord, datefmt: None | str = None 

139 ) -> str: 

140 """Return the creation time of the LogRecord as formatted text.""" 

141 spam = datetime.fromtimestamp(record.created).astimezone(self.timezone) 

142 if datefmt: 

143 return spam.strftime(datefmt) 

144 return spam.isoformat() 

145 

146 

class WebhookFormatter(DatetimeFormatter):
    """A logging formatter optimized for logging to a webhook."""

    # Whether the message gets escaped like the content of a JSON string.
    escape_message = False
    # Upper bound for the message length before escaping; None means no limit.
    max_message_length: int | None = None

    def format(self, record: LogRecord) -> str:
        """
        Format the specified record as text.

        The message is cut to max_message_length, then the pieces of the
        minimal traceback are appended for as long as they fit.  When a
        piece does not fit, "..." is appended instead if there is room and
        the remaining pieces are dropped.
        """
        limit = self.max_message_length
        record.message = record.getMessage()
        if self.usesTime():
            record.asctime = self.formatTime(record, self.datefmt)
        if limit is not None and len(record.message) > limit:
            record.message = record.message[:limit]

        pieces = [record.message]
        length = len(record.message)
        for piece in get_minimal_traceback(record):
            if limit is not None and length + len(piece) > limit:
                if length + len("...") <= limit:
                    pieces.append("...")
                break
            pieces.append(piece)
            length += len(piece)
        record.message = "".join(pieces)

        if self.escape_message:
            # Serialize as a JSON string and drop the surrounding quotes.
            escaped = json.dumps(record.message).decode("UTF-8")
            record.message = escaped[1:-1]
        return self.formatMessage(record)

175 

176 

class WebhookHandler(AsyncHandler):
    """A logging handler that sends logs to a webhook."""

    # Address the formatted records are posted to.
    url: str
    # Value of the Content-Type header sent with every request.
    content_type: str

    def __init__(
        self,
        level: int | str = logging.NOTSET,
        *,
        loop: AbstractEventLoop,
        url: str,
        content_type: str,
    ):
        """Initialize the handler."""
        super().__init__(level=level, loop=loop)
        self.content_type = content_type
        self.url = url

    async def emit(self, record: LogRecord) -> None:  # type: ignore[override]
        """
        Send the request to the webhook.

        The formatted record, stripped of surrounding whitespace, is sent
        as the body of a POST request.  Any exception is passed on to
        handleError instead of being raised.
        """
        # pylint: disable=invalid-overridden-method
        try:
            payload = self.format(record).strip()
            request_headers = {"Content-Type": self.content_type}
            client = AsyncHTTPClient()
            await client.fetch(
                self.url,
                method="POST",
                headers=request_headers,
                body=payload,
                ca_certs=CA_BUNDLE_PATH,
            )
        except Exception:  # pylint: disable=broad-except
            self.handleError(record)