Coverage for an_website/emoji_chat/chat.py: 47.853%

163 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-10 18:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""A 🆒 chat.""" 

15 

16import asyncio 

17import logging 

18import random 

19import sys 

20import time 

21from collections.abc import Awaitable, Iterable, Mapping 

22from typing import Any, Final, Literal 

23 

24import orjson as json 

25from emoji import EMOJI_DATA, demojize, emoji_list, emojize, purely_emoji 

26from redis.asyncio import Redis 

27from tornado.web import Application, HTTPError 

28from tornado.websocket import WebSocketHandler 

29 

30from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS 

31from ..utils.base_request_handler import BaseRequestHandler 

32from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler 

33from ..utils.utils import Permission, ratelimit 

34from .pub_sub_provider import PubSubProvider 

35 

LOGGER: Final = logging.getLogger(__name__)

# Emojis used to build random user names. Flags are excluded because a
# location flag is appended to every user name separately.
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    emoji
    for emoji in EMOJI_DATA
    # regional-indicator sequences, i.e. country flags such as 🇩🇪
    if ord(emoji[0]) not in range(0x1F1E6, 0x1F200)
    # tag sequences (U+E0020..U+E007F), i.e. subdivision flags such as the
    # flags of England, Scotland and Wales, which look like location flags too
    and not any(0xE0020 <= ord(char) <= 0xE007F for char in emoji)
)

# Number of messages kept in the Redis message list.
MAX_MESSAGE_SAVE_COUNT: Final = 200
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Redis pub/sub channel used to distribute new messages between workers.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"

47 

48 

def get_ms_timestamp() -> int:
    """Return the current Unix time in milliseconds, offset by ``EPOCH_MS``."""
    unix_ms = time.time_ns() // 1_000_000
    return unix_ms - EPOCH_MS

52 

53 

54async def subscribe_to_redis_channel( 

55 app: Application, worker: int | None 

56) -> None: 

57 """Subscribe to the Redis channel and handle incoming messages.""" 

58 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker) 

59 del app 

60 

61 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used 

62 ps = await get_pubsub() 

63 try: 

64 message = await ps.get_message(timeout=5.0) 

65 except Exception as exc: # pylint: disable=broad-exception-caught 

66 if str(exc) == "Connection closed by server.": 

67 continue 

68 LOGGER.exception("Failed to get message on worker %s", worker) 

69 await asyncio.sleep(0) 

70 continue 

71 

72 match message: 

73 case None: 

74 pass 

75 case { 

76 "type": "message", 

77 "data": str() as data, 

78 "channel": channel, 

79 } if ( 

80 channel == REDIS_CHANNEL 

81 ): 

82 await asyncio.gather( 

83 *[conn.write_message(data) for conn in OPEN_CONNECTIONS] 

84 ) 

85 case { 

86 "type": "subscribe", 

87 "data": 1, 

88 "channel": channel, 

89 } if ( 

90 channel == REDIS_CHANNEL 

91 ): 

92 logging.info( 

93 "Subscribed to Redis channel %r on worker %s", 

94 channel, 

95 worker, 

96 ) 

97 case _: 

98 logging.error( 

99 "Got unexpected message %s on worker %s", 

100 message, 

101 worker, 

102 ) 

103 await asyncio.sleep(0) 

104 

105 

async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Persist a chat message and announce it on ``REDIS_CHANNEL``.

    The entry is stored with ``author`` and ``content`` split into lists of
    emojis plus a ``timestamp`` from ``get_ms_timestamp``. It is appended to
    the ``{redis_prefix}:emoji-chat:message-list`` list, and the list is
    trimmed to the newest ``MAX_MESSAGE_SAVE_COUNT`` entries. The entry is
    then published as ``{"type": "message", "message": <entry>}``.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    entry = {
        "author": [found["emoji"] for found in emoji_list(author)],
        "content": [found["emoji"] for found in emoji_list(message)],
        "timestamp": get_ms_timestamp(),
    }

    await redis.rpush(list_key, json.dumps(entry, option=ORJSON_OPTIONS))
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)
    LOGGER.info("GOT new message %s", entry)

    announcement = {"type": "message", "message": entry}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(announcement, option=ORJSON_OPTIONS)
    )

136 

137 

async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load stored chat messages from Redis.

    ``start`` and ``stop`` are passed to ``LRANGE`` on the message list.
    ``start`` defaults to ``-MAX_MESSAGE_SAVE_COUNT``. Each stored entry is
    decoded from JSON.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    raw_entries = await redis.lrange(
        f"{redis_prefix}:emoji-chat:message-list", start, stop
    )
    return list(map(json.loads, raw_entries))

150 

151 

152def check_message_invalid(message: str) -> Literal[False] | str: 

153 """Check if a message is an invalid message.""" 

154 if not message: 

155 return "Empty message not allowed." 

156 

157 if not purely_emoji(message): 

158 return "Message can only contain emojis." 

159 

160 if len(emoji_list(message)) > MAX_MESSAGE_LENGTH: 

161 return f"Message longer than {MAX_MESSAGE_LENGTH} emojis." 

162 

163 return False 

164 

165 

def emojize_user_input(string: str) -> str:
    """Turn emoji shortcodes in user input into emojis.

    German shortcodes are replaced first, then English ones, then aliases.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string

172 

173 

def normalize_emojis(string: str) -> str:
    """Normalize emojis by round-tripping through shortcodes.

    Every emoji is converted to its shortcode with ``demojize`` and then
    re-emitted in the form ``emojize`` produces for that shortcode.
    """
    shortcodes = demojize(string)
    return emojize(shortcodes)

177 

178 

def get_random_name() -> str:
    """Return a random name made of five distinct emojis from EMOJIS_NO_FLAGS."""
    picks = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picks))

184 

185 

class ChatHandler(BaseRequestHandler):
    """The request handler for the emoji chat.

    Subclasses decide how the message list is presented by overriding
    ``render_chat``. The ``RATELIMIT_*`` attributes are presumably read by
    the base handler's rate limiting (not visible here).
    """

    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Respond with the stored messages; HEAD requests get no body.

        Raises HTTPError(503) while Redis is not available.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)
        if not head:
            stored = await get_messages(self.redis, self.redis_prefix)
            await self.render_chat(stored)

    async def get_name(self) -> str:
        """Return the user's emoji name with a location flag appended.

        The name is read from the signed ``emoji-chat-name`` cookie (accepted
        up to 90 days old) or generated randomly. The cookie is written again
        if it is missing or older than 30 days, which resets its expiry. The
        flag comes from GeoIP data, is a pirate flag for ``.onion`` hosts, or
        is ``❔`` otherwise.
        """
        stored_cookie = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )
        if stored_cookie:
            name = stored_cookie.decode("UTF-8")
        else:
            name = get_random_name()

        fresh_cookie = self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        )
        if not fresh_cookie:
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        location = await self.geoip() or {}
        if "country_flag" in location:
            flag = location["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            flag = "🏴‍☠"
        else:
            flag = "❔"

        return normalize_emojis(name + flag)

    async def get_name_as_list(self) -> list[str]:
        """Return the user's name split into its individual emojis."""
        full_name = await self.get_name()
        return [found["emoji"] for found in emoji_list(full_name)]

    async def post(self) -> None:
        """Store a submitted message, then respond with the updated chat.

        Raises HTTPError(503) while Redis is not available and HTTPError(400)
        with the validation error when the message is rejected.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        raw_text = self.get_argument("message")
        text = emojize_user_input(normalize_emojis(raw_text))

        problem = check_message_invalid(text)
        if problem:
            raise HTTPError(400, reason=problem)

        author = await self.get_name()
        await save_new_message(
            author,
            text,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        stored = await get_messages(self.redis, self.redis_prefix)
        await self.render_chat(stored)

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; must be implemented by subclasses."""
        raise NotImplementedError

277 

278 

class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Serve the emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render ``pages/emoji_chat.html`` with the messages and user name."""
        name_emojis = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=name_emojis,
        )

289 

290 

class APIChatHandler(ChatHandler, APIRequestHandler):
    """Serve the emoji chat through the API."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the response with the current user and the messages."""
        current_user = await self.get_name_as_list()
        payload = {"current_user": current_user, "messages": messages}
        await self.finish(payload)

302 

303 

# WebSocket connections of this worker that completed ChatWebSocketHandler.open;
# entries are removed again in ChatWebSocketHandler.on_close.
# The class is defined further down, so the annotation is a quoted forward
# reference (valid even where annotations are evaluated eagerly).
OPEN_CONNECTIONS: list["ChatWebSocketHandler"] = []

305 

306 

307class ChatWebSocketHandler(WebSocketHandler, ChatHandler): 

308 """The handler for the chat WebSocket.""" 

309 

310 name: str 

311 connection_time: int 

312 

313 def on_close(self) -> None: # noqa: D102 

314 LOGGER.info("WebSocket closed") 

315 OPEN_CONNECTIONS.remove(self) 

316 for conn in OPEN_CONNECTIONS: 

317 conn.send_users() 

318 

319 def on_message(self, message: str | bytes) -> Awaitable[None] | None: 

320 """Respond to an incoming message.""" 

321 if not message: 

322 return None 

323 message2: dict[str, Any] = json.loads(message) 

324 if message2["type"] == "message": 

325 if "message" not in message2: 

326 return self.write_message( 

327 { 

328 "type": "error", 

329 "error": "Message needs message key with the message.", 

330 } 

331 ) 

332 return self.save_new_message(message2["message"]) 

333 

334 return self.write_message( 

335 {"type": "error", "error": f"Unknown type {message2['type']}."} 

336 ) 

337 

338 async def open(self, *args: str, **kwargs: str) -> None: 

339 # pylint: disable=invalid-overridden-method 

340 """Handle an opened connection.""" 

341 LOGGER.info("WebSocket opened") 

342 await self.write_message( 

343 { 

344 "type": "init", 

345 "current_user": [ 

346 emoji["emoji"] for emoji in emoji_list(self.name) 

347 ], 

348 } 

349 ) 

350 

351 self.connection_time = get_ms_timestamp() 

352 OPEN_CONNECTIONS.append(self) 

353 for conn in OPEN_CONNECTIONS: 

354 conn.send_users() 

355 

356 await self.send_messages() 

357 

358 async def prepare(self) -> None: # noqa: D102 

359 self.now = await self.get_time() 

360 

361 if not EVENT_REDIS.is_set(): 

362 raise HTTPError(503) 

363 

364 self.name = await self.get_name() 

365 

366 if not await self.ratelimit(True): 

367 await self.ratelimit() 

368 

369 async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None: 

370 """Render the chat.""" 

371 raise NotImplementedError 

372 

373 async def save_new_message(self, msg_text: str) -> None: 

374 """Save a new message.""" 

375 msg_text = emojize_user_input(normalize_emojis(msg_text).strip()) 

376 if err := check_message_invalid(msg_text): 

377 return await self.write_message({"type": "error", "error": err}) 

378 

379 if self.settings.get("RATELIMITS") and not self.is_authorized( 

380 Permission.RATELIMITS 

381 ): 

382 if not EVENT_REDIS.is_set(): 

383 return await self.write_message({"type": "ratelimit"}) 

384 

385 ratelimited, headers = await ratelimit( 

386 self.redis, 

387 self.redis_prefix, 

388 str(self.request.remote_ip), 

389 bucket=self.RATELIMIT_POST_BUCKET, 

390 max_burst=self.RATELIMIT_POST_LIMIT - 1, 

391 count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD, 

392 period=self.RATELIMIT_POST_PERIOD, 

393 tokens=1, 

394 ) 

395 

396 if ratelimited: 

397 return await self.write_message( 

398 {"type": "ratelimit", "retry_after": headers["Retry-After"]} 

399 ) 

400 

401 return await save_new_message( 

402 self.name, msg_text, self.redis, self.redis_prefix 

403 ) 

404 

405 async def send_messages(self) -> None: 

406 """Send this WebSocket all current messages.""" 

407 return await self.write_message( 

408 { 

409 "type": "messages", 

410 "messages": await get_messages(self.redis, self.redis_prefix), 

411 }, 

412 ) 

413 

414 def send_users(self) -> None: 

415 """Send this WebSocket all current users.""" 

416 if sys.flags.dev_mode: 

417 self.write_message( # type: ignore[unused-awaitable] 

418 { 

419 "type": "users", 

420 "users": [ 

421 { 

422 "name": conn.name, 

423 "joined_at": conn.connection_time, 

424 } 

425 for conn in OPEN_CONNECTIONS 

426 ], 

427 } 

428 )