Coverage for an_website/utils/background_tasks.py: 45.283%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 19:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13"""Tasks running in the background.""" 

14 

15from __future__ import annotations 

16 

17import asyncio 

18import logging 

19import os 

20import time 

21from collections.abc import Iterable, Set 

22from functools import wraps 

23from typing import TYPE_CHECKING, Final, Protocol, assert_type, cast 

24 

25import typed_stream 

26from elasticsearch import AsyncElasticsearch 

27from redis.asyncio import Redis 

28from tornado.web import Application 

29 

30from .. import EVENT_ELASTICSEARCH, EVENT_REDIS, EVENT_SHUTDOWN 

31from .elasticsearch_setup import setup_elasticsearch_configs 

32 

33if TYPE_CHECKING: 

34 from .utils import ModuleInfo 

35 

# Logger named after this module; used by every task in this file.
LOGGER: Final = logging.getLogger(__name__)

# Timestamp (a time.monotonic() value) refreshed by heartbeat().
# 0 means "disabled": heartbeat() only loops while this is truthy, so some
# other code must set it to a non-zero value before heartbeat() starts.
# NOTE(review): that writer (and any reader) is not visible in this file —
# presumably a watchdog that checks how stale the value is; confirm.
HEARTBEAT: float = 0

39 

40 

class BackgroundTask(Protocol):
    """Structural type describing a coroutine function usable as a task.

    A conforming object can be called with the keyword-only arguments
    ``app`` and ``worker``, produces an awaitable resolving to ``None``
    when called, and exposes a ``__name__`` string.
    """

    async def __call__(self, *, app: Application, worker: int | None) -> None:
        """Run the background task for ``app`` on ``worker``."""
        ...

    @property
    def __name__(self) -> str:  # pylint: disable=bad-dunder-name
        """Return the name used to identify the task."""
        ...

50 

51 

async def check_elasticsearch(
    app: Application, worker: int | None
) -> None:  # pragma: no cover
    """Probe Elasticsearch until shutdown and track it in EVENT_ELASTICSEARCH.

    Every 20 seconds a ``HEAD /`` request is sent through the client stored
    in ``app.settings["ELASTICSEARCH"]``. If the request raises, the event is
    cleared and the error is logged. If it succeeds while the event is clear,
    ``setup_elasticsearch_configs`` is run with the client and
    ``app.settings["ELASTICSEARCH_PREFIX"]``; the event is set only when that
    setup finishes without raising (a failed setup is logged and retried on
    the next round).

    The loop ends once EVENT_SHUTDOWN is set. ``worker`` is only used in log
    messages.
    """
    while not EVENT_SHUTDOWN.is_set():  # pylint: disable=while-used
        # The client is looked up again every round instead of being cached.
        client = cast(AsyncElasticsearch, app.settings.get("ELASTICSEARCH"))

        reachable = True
        try:
            await client.transport.perform_request("HEAD", "/")
        except Exception:  # pylint: disable=broad-except
            reachable = False
            EVENT_ELASTICSEARCH.clear()
            LOGGER.exception(
                "Connecting to Elasticsearch failed on worker: %s", worker
            )

        # Only (re)configure after a successful probe while not yet ready.
        if reachable and not EVENT_ELASTICSEARCH.is_set():
            try:
                await setup_elasticsearch_configs(
                    client, app.settings["ELASTICSEARCH_PREFIX"]
                )
            except Exception:  # pylint: disable=broad-except
                LOGGER.exception(
                    "An exception occured while configuring Elasticsearch on worker: %s",  # noqa: B950
                    worker,
                )
            else:
                EVENT_ELASTICSEARCH.set()

        await asyncio.sleep(20)

81 

82 

async def check_if_ppid_changed(ppid: int) -> None:
    """Request shutdown once the parent process id no longer equals ``ppid``.

    ``os.getppid()`` is polled once per second. When it differs from
    ``ppid`` (e.g. because the original parent exited and this process was
    re-parented), EVENT_SHUTDOWN is set and the coroutine returns. It also
    returns without doing anything once EVENT_SHUTDOWN is already set.
    """
    while not EVENT_SHUTDOWN.is_set():  # pylint: disable=while-used
        if os.getppid() == ppid:
            # Parent unchanged: check again in a second.
            await asyncio.sleep(1)
            continue
        EVENT_SHUTDOWN.set()
        return

90 

91 

async def check_redis(
    app: Application, worker: int | None
) -> None:  # pragma: no cover
    """Ping Redis every 20 seconds until shutdown and mirror the result.

    The client is read from ``app.settings["REDIS"]`` on every round. A
    successful ``ping()`` sets EVENT_REDIS; any exception clears it and is
    logged together with ``worker``. The loop ends once EVENT_SHUTDOWN is
    set.
    """
    while not EVENT_SHUTDOWN.is_set():  # pylint: disable=while-used
        client = cast("Redis[str]", app.settings.get("REDIS"))
        try:
            await client.ping()
        except Exception:  # pylint: disable=broad-except
            # Mark Redis as unavailable before reporting the failure.
            EVENT_REDIS.clear()
            LOGGER.exception("Connecting to Redis failed on worker %s", worker)
        else:
            EVENT_REDIS.set()

        await asyncio.sleep(20)

106 

107 

async def heartbeat() -> None:
    """Refresh the module-level HEARTBEAT timestamp about every 50 ms.

    While HEARTBEAT holds a truthy value, it is overwritten with
    ``time.monotonic()`` and the coroutine sleeps for 0.05 seconds. As soon as
    HEARTBEAT is falsy at a check (its module default is 0), the coroutine
    returns — so whatever enables the heartbeat must set HEARTBEAT to a
    non-zero value before this runs, and can stop it by resetting it to 0.
    """
    global HEARTBEAT  # pylint: disable=global-statement
    while True:  # pylint: disable=while-used
        if not HEARTBEAT:
            return
        HEARTBEAT = time.monotonic()
        await asyncio.sleep(0.05)

114 

115 

async def wait_for_shutdown() -> None:  # pragma: no cover
    """Stop the running event loop once EVENT_SHUTDOWN is set.

    EVENT_SHUTDOWN is polled every 50 milliseconds; when it is found set,
    ``stop()`` is called on the loop this coroutine is running on.
    """
    while not EVENT_SHUTDOWN.is_set():  # pylint: disable=while-used
        await asyncio.sleep(0.05)
    asyncio.get_running_loop().stop()

122 

123 

def start_background_tasks(  # pylint: disable=too-many-arguments
    *,
    app: Application,
    processes: int,
    module_infos: Iterable[ModuleInfo],
    loop: asyncio.AbstractEventLoop,
    main_pid: int,
    elasticsearch_is_enabled: bool,
    redis_is_enabled: bool,
    worker: int | None,
) -> Set[asyncio.Task[None]]:
    """Schedule every background task this process needs on ``loop``.

    Tasks are gathered in this order:
    - each module's ``required_background_tasks``;
    - ``heartbeat`` and ``wait_for_shutdown`` (wrapped to ignore keyword
      arguments);
    - ``check_if_ppid_changed(main_pid)``, only when ``processes`` is truthy;
    - ``check_elasticsearch``, only when ``elasticsearch_is_enabled``;
    - ``check_redis``, only when ``redis_is_enabled``.

    Duplicate entries are removed with ``Stream.distinct()`` before
    scheduling. Every task is called as ``task(app=app, worker=worker)``.

    Returns the set of created asyncio tasks. Each task discards itself from
    that set when it is done, so the set shrinks as tasks finish.
    """

    async def run_guarded(job: BackgroundTask, /) -> None:
        """Await ``job`` and log how it ended.

        Cancellation is swallowed silently. Any other exception is logged;
        a BaseException that is not an Exception is re-raised after logging.
        """
        try:
            await job(app=app, worker=worker)
        except asyncio.exceptions.CancelledError:
            return
        except BaseException as exc:  # pylint: disable=broad-exception-caught
            LOGGER.exception(
                "A %s exception occured while executing background task %s.%s",
                exc.__class__.__name__,
                job.__module__,
                job.__name__,
            )
            if isinstance(exc, Exception):
                return
            raise
        LOGGER.debug(
            "Background task %s.%s finished executing",
            job.__module__,
            job.__name__,
        )

    live_tasks: set[asyncio.Task[None]] = set()

    def schedule(job: BackgroundTask, /) -> asyncio.Task[None]:
        """Create an asyncio.Task named ``<module>.<name>`` running ``job``."""
        task_name = f"{job.__module__}.{job.__name__}"
        if not worker:  # log only once (worker is None or 0)
            LOGGER.info("starting %s background task", task_name)
        scheduled = loop.create_task(run_guarded(job), name=task_name)
        # Drop finished tasks so the returned set only holds pending ones.
        scheduled.add_done_callback(live_tasks.discard)
        return scheduled

    optional_checks: list[BackgroundTask] = []
    if processes:
        optional_checks.append(
            wraps(check_if_ppid_changed)(
                lambda **_: check_if_ppid_changed(main_pid)
            )
        )
    if elasticsearch_is_enabled:
        optional_checks.append(check_elasticsearch)
    if redis_is_enabled:
        optional_checks.append(check_redis)

    # The stream is lazy: tasks are only created while live_tasks consumes it.
    jobs = (
        typed_stream.Stream(module_infos)
        .flat_map(lambda info: info.required_background_tasks)
        .chain(
            typed_stream.Stream((heartbeat, wait_for_shutdown)).map(
                lambda fun: wraps(fun)(lambda **_: fun())
            )
        )
        .chain(optional_checks)
        .distinct()
    )

    created: typed_stream.Stream[asyncio.Task[None]] = assert_type(
        jobs.map(schedule),
        typed_stream.Stream[asyncio.Task[None]],
    )
    live_tasks.update(created)

    return live_tasks