Coverage for an_website / emoji_chat / chat.py: 48.795%

166 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-01-11 19:37 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""A 🆒 chat.""" 

15 

16from __future__ import annotations 

17 

18import asyncio 

19import logging 

20import random 

21import sys 

22import time 

23from collections.abc import Awaitable, Iterable, Mapping 

24from typing import Any, Final, Literal 

25 

26import orjson as json 

27from emoji import EMOJI_DATA, demojize, emoji_list, emojize, purely_emoji 

28from redis.asyncio import Redis 

29from tornado.web import Application, HTTPError 

30from tornado.websocket import WebSocketHandler 

31 

32from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS 

33from ..utils.base_request_handler import BaseRequestHandler 

34from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler 

35from ..utils.utils import Permission, ratelimit 

36from .pub_sub_provider import PubSubProvider 

37 

# Module-level logger for the emoji chat.
LOGGER: Final = logging.getLogger(__name__)

# Every key of EMOJI_DATA whose first code point lies outside the
# half-open range [0x1F1E6, 0x1F200), i.e. everything that does not
# start with a regional indicator symbol (the building blocks of flags).
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not 0x1F1E6 <= ord(candidate[0]) < 0x1F200
)

# How many messages are kept in the Redis message list.
MAX_MESSAGE_SAVE_COUNT: Final = 200
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Redis pub/sub channel on which new messages are published.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"

49 

50 

def get_ms_timestamp() -> int:
    """Return the current time in milliseconds, offset by ``EPOCH_MS``."""
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS

54 

55 

56async def subscribe_to_redis_channel( 

57 app: Application, worker: int | None 

58) -> None: 

59 """Subscribe to the Redis channel and handle incoming messages.""" 

60 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker) 

61 del app 

62 

63 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used 

64 ps = await get_pubsub() 

65 try: 

66 message = await ps.get_message(timeout=5.0) 

67 except Exception as exc: # pylint: disable=broad-exception-caught 

68 if str(exc) == "Connection closed by server.": 

69 continue 

70 LOGGER.exception("Failed to get message on worker %s", worker) 

71 await asyncio.sleep(0) 

72 continue 

73 

74 match message: 

75 case None: 

76 pass 

77 case { 

78 "type": "message", 

79 "data": str(data), 

80 "channel": channel, 

81 } if ( 

82 channel == REDIS_CHANNEL 

83 ): 

84 await asyncio.gather( 

85 *[conn.write_message(data) for conn in OPEN_CONNECTIONS] 

86 ) 

87 case { 

88 "type": "subscribe", 

89 "data": 1, 

90 "channel": channel, 

91 } if ( 

92 channel == REDIS_CHANNEL 

93 ): 

94 logging.info( 

95 "Subscribed to Redis channel %r on worker %s", 

96 channel, 

97 worker, 

98 ) 

99 case _: 

100 logging.error( 

101 "Got unexpected message %s on worker %s", 

102 message, 

103 worker, 

104 ) 

105 await asyncio.sleep(0) 

106 

107 

async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a message in the Redis list and publish it on the channel.

    The author and the content are stored as lists of their emojis,
    together with the current timestamp.  The list is trimmed to the
    newest ``MAX_MESSAGE_SAVE_COUNT`` entries.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"

    author_emojis = [entry["emoji"] for entry in emoji_list(author)]
    content_emojis = [entry["emoji"] for entry in emoji_list(message)]
    message_dict = {
        "author": author_emojis,
        "content": content_emojis,
        "timestamp": get_ms_timestamp(),
    }

    serialized = json.dumps(message_dict, option=ORJSON_OPTIONS)
    await redis.rpush(list_key, serialized)
    # keep only the newest entries
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)

    LOGGER.info("GOT new message %s", message_dict)

    event = {"type": "message", "message": message_dict}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(event, option=ORJSON_OPTIONS)
    )

138 

139 

async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load the stored messages in the given range and decode them.

    When ``start`` is ``None`` the range begins at
    ``-MAX_MESSAGE_SAVE_COUNT``.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    raw_messages = await redis.lrange(list_key, start, stop)
    return list(map(json.loads, raw_messages))

152 

153 

def check_message_invalid(message: str) -> Literal[False] | str:
    """Return the reason why a message is invalid, or False if it is valid."""
    if message and purely_emoji(message):
        emoji_count = len(emoji_list(message))
        if emoji_count <= MAX_MESSAGE_LENGTH:
            return False
        return f"Message longer than {MAX_MESSAGE_LENGTH} emojis."

    # Either the message is empty or it contains something besides emojis.
    if message:
        return "Message can only contain emojis."
    return "Empty message not allowed."

166 

167 

def emojize_user_input(string: str) -> str:
    """Emojize user input using the "de", "en" and "alias" names in turn."""
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string

174 

175 

def normalize_emojis(string: str) -> str:
    """Normalize emojis by demojizing the string and emojizing it again."""
    demojized = demojize(string)
    return emojize(demojized)

179 

180 

def get_random_name() -> str:
    """Build a random name from five distinct entries of EMOJIS_NO_FLAGS."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))

186 

187 

class ChatHandler(BaseRequestHandler):
    """Base request handler for the emoji chat.

    Subclasses have to implement :meth:`render_chat`.
    """

    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Render the current messages (nothing is rendered for HEAD)."""
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if not head:
            messages = await get_messages(self.redis, self.redis_prefix)
            await self.render_chat(messages)

    async def get_name(self) -> str:
        """Return the user's name followed by a flag emoji.

        The name comes from the "emoji-chat-name" cookie; when the cookie
        is missing a random name is generated.
        """
        stored = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )

        if stored:
            name = stored.decode("UTF-8")
        else:
            name = get_random_name()

        # (Re)issue the cookie when no cookie younger than 30 days exists.
        recent = self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        )
        if not recent:
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        else:
            is_onion = self.request.host_name.endswith(".onion")
            flag = "🏴‍☠" if is_onion else "❔"

        return normalize_emojis(name + flag)

    async def get_name_as_list(self) -> list[str]:
        """Return the user's name split into a list of emojis."""
        name = await self.get_name()
        return [entry["emoji"] for entry in emoji_list(name)]

    async def post(self) -> None:
        """Save the submitted message, then render the current messages."""
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        raw_message = self.get_argument("message")
        message = emojize_user_input(normalize_emojis(raw_message))

        err = check_message_invalid(message)
        if err:
            raise HTTPError(400, reason=err)

        author = await self.get_name()
        await save_new_message(
            author,
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        messages = await get_messages(self.redis, self.redis_prefix)
        await self.render_chat(messages)

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; has to be implemented by subclasses."""
        raise NotImplementedError

279 

280 

class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Chat request handler that renders the HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat template with the messages and the user's name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )

291 

292 

class APIChatHandler(ChatHandler, APIRequestHandler):
    """Chat request handler that answers with a JSON-style mapping."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the request with the current user and the messages."""
        current_user = await self.get_name_as_list()
        response = {
            "current_user": current_user,
            "messages": messages,
        }
        await self.finish(response)

304 

305 

# All WebSocket connections of this process that finished open().
OPEN_CONNECTIONS: list[ChatWebSocketHandler] = []


class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket."""

    # name of the user, set in prepare()
    name: str
    # timestamp (see get_ms_timestamp) of when the connection was opened
    connection_time: int

    def on_close(self) -> None:
        """Unregister this connection and tell the others about it."""
        LOGGER.info("WebSocket closed")
        # The connection may close before open() registered it (e.g. while
        # the init message is being written); removing it unconditionally
        # would raise ValueError in that case.
        try:
            OPEN_CONNECTIONS.remove(self)
        except ValueError:
            return
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        The payload is expected to be a JSON object of the form
        ``{"type": "message", "message": "<text>"}``.  Anything else is
        answered with an ``{"type": "error", ...}`` message instead of
        raising, because the payload is client-controlled.
        """
        if not message:
            return None

        try:
            payload = json.loads(message)
        except ValueError:  # orjson's JSONDecodeError subclasses ValueError
            return self.write_message(
                {"type": "error", "error": "Message needs to be valid JSON."}
            )

        if not isinstance(payload, dict) or "type" not in payload:
            return self.write_message(
                {
                    "type": "error",
                    "error": "Message needs to be an object with a type key.",
                }
            )

        if payload["type"] != "message":
            return self.write_message(
                {"type": "error", "error": f"Unknown type {payload['type']}."}
            )

        if "message" not in payload:
            return self.write_message(
                {
                    "type": "error",
                    "error": "Message needs message key with the message.",
                }
            )

        text = payload["message"]
        if not isinstance(text, str):
            return self.write_message(
                {"type": "error", "error": "Message needs to be a string."}
            )

        return self.save_new_message(text)

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Sends the init message, registers the connection, notifies all
        connections about the users and sends the stored messages.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Prepare the request: require Redis and determine the user name."""
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # NOTE(review): semantics of ratelimit(True) are defined in a base
        # class that is not visible here.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat (not supported for WebSockets)."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, rate limit and save a new message.

        Validation errors and rate limiting are reported to the client as
        ``error`` / ``ratelimit`` messages.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            if not EVENT_REDIS.is_set():
                # Without Redis the rate limit cannot be checked.
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in dev mode)."""
        if sys.flags.dev_mode:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )