Coverage for an_website/emoji_chat/chat.py: 49.102%

167 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-07 13:44 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""A 🆒 chat.""" 

15 

16from __future__ import annotations 

17 

18import asyncio 

19import logging 

20import random 

21import sys 

22import time 

23from collections.abc import Awaitable, Iterable, Mapping 

24from typing import Any, Final, Literal 

25 

26import orjson as json 

27from emoji import purely_emoji # type: ignore[attr-defined] 

28from emoji import EMOJI_DATA, demojize, emoji_list, emojize 

29from redis.asyncio import Redis 

30from tornado.web import Application, HTTPError 

31from tornado.websocket import WebSocketHandler 

32 

33from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS 

34from ..utils.base_request_handler import BaseRequestHandler 

35from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler 

36from ..utils.utils import Permission, ratelimit 

37from .pub_sub_provider import PubSubProvider 

38 

LOGGER: Final = logging.getLogger(__name__)

# All emojis from EMOJI_DATA except those whose first code point is a
# regional indicator symbol (U+1F1E6..U+1F1FF). Pairs of those symbols form
# the two-letter country flags.
# NOTE(review): flags built differently (e.g. black-flag tag sequences or
# "🏳️‍🌈") start with other code points and are NOT excluded here.
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not 0x1F1E6 <= ord(candidate[0]) < 0x1F200
)

# How many messages are kept in the Redis message list.
MAX_MESSAGE_SAVE_COUNT: Final = 100
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Pub/sub channel used to fan new messages out to all workers.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"

50 

51 

def get_ms_timestamp() -> int:
    """Return the current time in milliseconds, relative to ``EPOCH_MS``."""
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS

55 

56 

57async def subscribe_to_redis_channel( 

58 app: Application, worker: int | None 

59) -> None: 

60 """Subscribe to the Redis channel and handle incoming messages.""" 

61 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker) 

62 del app 

63 

64 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used 

65 ps = await get_pubsub() 

66 try: 

67 message = await ps.get_message(timeout=5.0) 

68 except Exception as exc: # pylint: disable=broad-exception-caught 

69 if str(exc) == "Connection closed by server.": 

70 continue 

71 LOGGER.exception("Failed to get message on worker %s", worker) 

72 await asyncio.sleep(0) 

73 continue 

74 

75 match message: 

76 case None: 

77 pass 

78 case { 

79 "type": "message", 

80 "data": str(data), 

81 "channel": channel, 

82 } if channel == REDIS_CHANNEL: 

83 await asyncio.gather( 

84 *[conn.write_message(data) for conn in OPEN_CONNECTIONS] 

85 ) 

86 case { 

87 "type": "subscribe", 

88 "data": 1, 

89 "channel": channel, 

90 } if channel == REDIS_CHANNEL: 

91 logging.info( 

92 "Subscribed to Redis channel %r on worker %s", 

93 channel, 

94 worker, 

95 ) 

96 case _: 

97 logging.error( 

98 "Got unexpected message %s on worker %s", 

99 message, 

100 worker, 

101 ) 

102 await asyncio.sleep(0) 

103 

104 

async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a chat message in Redis and publish it to all workers.

    The message is appended to the message list, which is then trimmed to
    the newest MAX_MESSAGE_SAVE_COUNT entries. Afterwards it is published on
    REDIS_CHANNEL wrapped as ``{"type": "message", "message": ...}``.

    Args:
        author: The author's name as a string of emojis.
        message: The message text as a string of emojis.
        redis: The Redis client.
        redis_prefix: Prefix for the Redis keys.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"

    def as_emojis(text: str) -> list[str]:
        return [entry["emoji"] for entry in emoji_list(text)]

    message_dict = {
        "author": as_emojis(author),
        "content": as_emojis(message),
        "timestamp": get_ms_timestamp(),
    }
    serialized = json.dumps(message_dict, option=ORJSON_OPTIONS)
    await redis.rpush(list_key, serialized)
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)
    LOGGER.info("GOT new message %s", message_dict)

    event = {"type": "message", "message": message_dict}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(event, option=ORJSON_OPTIONS)
    )

135 

136 

async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load saved messages from Redis.

    Args:
        redis: The Redis client.
        redis_prefix: Prefix for the Redis keys.
        start: First list index to read; ``None`` means the last
            MAX_MESSAGE_SAVE_COUNT messages.
        stop: Last list index to read (inclusive, LRANGE semantics).

    Returns:
        The decoded JSON message objects, oldest first.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    raw_messages = await redis.lrange(
        f"{redis_prefix}:emoji-chat:message-list", start, stop
    )
    return list(map(json.loads, raw_messages))

149 

150 

def check_message_invalid(message: str) -> Literal[False] | str:
    """Validate a chat message.

    Returns:
        ``False`` if the message is acceptable, otherwise a human-readable
        reason why it was rejected.
    """
    problem: str | None = None
    if not message:
        problem = "Empty message not allowed."
    elif not purely_emoji(message):
        problem = "Message can only contain emojis."
    elif len(emoji_list(message)) > MAX_MESSAGE_LENGTH:
        problem = f"Message longer than {MAX_MESSAGE_LENGTH} emojis."
    return False if problem is None else problem

163 

164 

def emojize_user_input(string: str) -> str:
    """Replace emoji shortcodes in user input with the emojis themselves.

    German shortcodes are tried first, then English ones, then aliases.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string

171 

172 

def normalize_emojis(string: str) -> str:
    """Normalize emojis by converting them to shortcodes and back."""
    as_shortcodes = demojize(string)
    return emojize(as_shortcodes)

176 

177 

def get_random_name() -> str:
    """Build a random user name from five distinct non-flag emojis."""
    # Not security relevant, so the standard PRNG is fine.
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))

183 

184 

class ChatHandler(BaseRequestHandler):
    """The request handler for the emoji chat.

    Shared base of the HTML, API and WebSocket chat handlers. Subclasses
    override :meth:`render_chat` to present the messages.
    """

    # Rate-limit settings; presumably read by the ratelimiting code of
    # BaseRequestHandler (the POST ones are also used by the WebSocket
    # handler directly).
    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    # Cache for get_name(). Tornado creates one handler instance per request
    # (or WebSocket connection), so this lives exactly that long.
    _chat_name: str | None = None

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Show the users the current messages.

        Raises:
            HTTPError: 503 if EVENT_REDIS is not set.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if head:
            return

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def get_name(self) -> str:
        """Get the name of the user: their emoji name followed by a flag.

        The emoji name comes from the signed ``emoji-chat-name`` cookie. A
        new random name is generated if the cookie is missing or older than
        90 days. The flag is the GeoIP country flag if known, a pirate flag
        for ``.onion`` hosts and ``❔`` otherwise.

        The result is cached on the handler. get_secure_cookie() only sees
        the cookies sent with the request, so without the cache a user
        without a valid cookie would get a different random name on every
        call within one request. For example, post() would save the message
        under one name while render_chat() shows and stores another.
        """
        if self._chat_name is not None:
            return self._chat_name

        cookie = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )

        name = cookie.decode("UTF-8") if cookie else get_random_name()

        # Save the name in the cookie, or reset the expiry date once the
        # cookie is older than 30 days.
        if not self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        ):
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            flag = "🏴☠"
        else:
            flag = "❔"

        full_name = normalize_emojis(name + flag)
        self._chat_name = full_name
        return full_name

    async def get_name_as_list(self) -> list[str]:
        """Return the name as list of emojis."""
        return [emoji["emoji"] for emoji in emoji_list(await self.get_name())]

    async def post(self) -> None:
        """Let users send messages and show the users the current messages.

        Raises:
            HTTPError: 503 if EVENT_REDIS is not set; 400 with the reason
                from check_message_invalid() for an invalid message.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        message = emojize_user_input(
            normalize_emojis(self.get_argument("message"))
        )

        if err := check_message_invalid(message):
            raise HTTPError(400, reason=err)

        await save_new_message(
            await self.get_name(),
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; must be overridden by subclasses."""
        raise NotImplementedError

276 

277 

class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Serve the emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render ``pages/emoji_chat.html`` with the messages and user name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )

288 

289 

class APIChatHandler(ChatHandler, APIRequestHandler):
    """Serve the emoji chat as an API response."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the request with the current user and the messages."""
        payload = {
            "current_user": await self.get_name_as_list(),
            "messages": messages,
        }
        await self.finish(payload)

301 

302 

# Chat WebSockets of this process. Handlers add themselves in open() and
# remove themselves in on_close().
OPEN_CONNECTIONS: list[ChatWebSocketHandler] = []


class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket."""

    name: str  # the user's name from get_name(), set in prepare()
    connection_time: int  # get_ms_timestamp() value, set in open()

    def on_close(self) -> None:
        """Forget this connection and send the updated user list to the rest."""
        LOGGER.info("WebSocket closed")
        # on_close() can run without open() having appended this handler,
        # e.g. when the client disconnects during prepare() or while the init
        # message is being written; list.remove() would raise ValueError then.
        if self in OPEN_CONNECTIONS:
            OPEN_CONNECTIONS.remove(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        Expects a JSON object like ``{"type": "message", "message": "..."}``.
        Malformed input is answered with a ``{"type": "error", ...}`` frame
        instead of raising inside the handler.
        """
        if not message:
            return None
        try:
            payload: Any = json.loads(message)
        except ValueError:  # orjson.JSONDecodeError subclasses ValueError
            return self.write_message(
                {"type": "error", "error": "Invalid JSON."}
            )
        if not isinstance(payload, dict):
            return self.write_message(
                {"type": "error", "error": "Message must be a JSON object."}
            )

        msg_type = payload.get("type")
        if msg_type == "message":
            if not isinstance(payload.get("message"), str):
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            return self.save_new_message(payload["message"])

        return self.write_message(
            {"type": "error", "error": f"Unknown type {msg_type}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Sends the ``init`` frame with the current user, registers the
        connection, sends the updated user list to all connections and
        finally sends this connection the saved messages.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Check Redis availability, resolve the user's name and ratelimit.

        Raises:
            HTTPError: 503 if EVENT_REDIS is not set.
        """
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # See the ratelimit() implementation of the base handler for the
        # meaning of the first call's argument.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Not supported for WebSockets."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, ratelimit and save a message sent over the WebSocket.

        Problems are reported to the client as ``error`` or ``ratelimit``
        frames instead of being raised.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            if not EVENT_REDIS.is_set():
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in dev mode)."""
        if sys.flags.dev_mode:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )