Coverage for an_website / emoji_chat / chat.py: 47.853%

163 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-03-15 14:36 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""A 🆒 chat.""" 

15 

16 

17import asyncio 

18import logging 

19import random 

20import sys 

21import time 

22from collections.abc import Awaitable, Iterable, Mapping 

23from typing import Any, Final, Literal 

24 

25import orjson as json 

26from emoji import EMOJI_DATA, demojize, emoji_list, emojize, purely_emoji 

27from redis.asyncio import Redis 

28from tornado.web import Application, HTTPError 

29from tornado.websocket import WebSocketHandler 

30 

31from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS 

32from ..utils.base_request_handler import BaseRequestHandler 

33from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler 

34from ..utils.utils import Permission, ratelimit 

35from .pub_sub_provider import PubSubProvider 

36 

LOGGER: Final = logging.getLogger(__name__)

# All emojis known to the ``emoji`` package, minus those whose first code
# point is a regional indicator symbol (U+1F1E6 to U+1F1FF), i.e. the
# two-letter country flags. Other flag emojis are not filtered out here.
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not 0x1F1E6 <= ord(candidate[0]) < 0x1F200
)

# Number of messages kept in the Redis message list.
MAX_MESSAGE_SAVE_COUNT: Final = 200
# Maximum number of emojis allowed in one message.
MAX_MESSAGE_LENGTH: Final = 20
# Redis pub/sub channel that new messages are published on.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"

48 

49 

def get_ms_timestamp() -> int:
    """Get the current time in ms, counted from ``EPOCH_MS``.

    The Unix time in milliseconds is shifted by ``EPOCH_MS`` (imported from
    the package), so this is not a plain Unix timestamp unless that is 0.
    """
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS

53 

54 

55async def subscribe_to_redis_channel( 

56 app: Application, worker: int | None 

57) -> None: 

58 """Subscribe to the Redis channel and handle incoming messages.""" 

59 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker) 

60 del app 

61 

62 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used 

63 ps = await get_pubsub() 

64 try: 

65 message = await ps.get_message(timeout=5.0) 

66 except Exception as exc: # pylint: disable=broad-exception-caught 

67 if str(exc) == "Connection closed by server.": 

68 continue 

69 LOGGER.exception("Failed to get message on worker %s", worker) 

70 await asyncio.sleep(0) 

71 continue 

72 

73 match message: 

74 case None: 

75 pass 

76 case { 

77 "type": "message", 

78 "data": str() as data, 

79 "channel": channel, 

80 } if ( 

81 channel == REDIS_CHANNEL 

82 ): 

83 await asyncio.gather( 

84 *[conn.write_message(data) for conn in OPEN_CONNECTIONS] 

85 ) 

86 case { 

87 "type": "subscribe", 

88 "data": 1, 

89 "channel": channel, 

90 } if ( 

91 channel == REDIS_CHANNEL 

92 ): 

93 logging.info( 

94 "Subscribed to Redis channel %r on worker %s", 

95 channel, 

96 worker, 

97 ) 

98 case _: 

99 logging.error( 

100 "Got unexpected message %s on worker %s", 

101 message, 

102 worker, 

103 ) 

104 await asyncio.sleep(0) 

105 

106 

async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a chat message in Redis and publish it.

    The message is appended to the message list, and the list is trimmed to
    the newest ``MAX_MESSAGE_SAVE_COUNT`` entries. A ``{"type": "message"}``
    envelope is then published on ``REDIS_CHANNEL``. Author and content are
    stored as lists of the emojis found by ``emoji_list``.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    entry = {
        "author": [found["emoji"] for found in emoji_list(author)],
        "content": [found["emoji"] for found in emoji_list(message)],
        "timestamp": get_ms_timestamp(),
    }
    serialized_entry = json.dumps(entry, option=ORJSON_OPTIONS)

    await redis.rpush(list_key, serialized_entry)
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)
    LOGGER.info("GOT new message %s", entry)

    envelope = {"type": "message", "message": entry}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(envelope, option=ORJSON_OPTIONS)
    )

137 

138 

async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Get the stored messages, decoded from JSON.

    ``start`` and ``stop`` are passed straight to Redis ``LRANGE``. If
    ``start`` is omitted, ``-MAX_MESSAGE_SAVE_COUNT`` is used.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    raw_messages = await redis.lrange(
        f"{redis_prefix}:emoji-chat:message-list", start, stop
    )
    return list(map(json.loads, raw_messages))

151 

152 

153def check_message_invalid(message: str) -> Literal[False] | str: 

154 """Check if a message is an invalid message.""" 

155 if not message: 

156 return "Empty message not allowed." 

157 

158 if not purely_emoji(message): 

159 return "Message can only contain emojis." 

160 

161 if len(emoji_list(message)) > MAX_MESSAGE_LENGTH: 

162 return f"Message longer than {MAX_MESSAGE_LENGTH} emojis." 

163 

164 return False 

165 

166 

def emojize_user_input(string: str) -> str:
    """Emojize user input.

    Shortcodes are replaced in three passes: German names first, then
    English names, then aliases. Each pass works on the previous result.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string

173 

174 

def normalize_emojis(string: str) -> str:
    """Normalize emojis in a string.

    Every emoji is turned into its shortcode with ``demojize`` and back with
    ``emojize``. Presumably this maps different spellings of the same emoji
    to one form; the exact mapping is up to the ``emoji`` package.
    """
    shortcodes = demojize(string)
    return emojize(shortcodes)

178 

179 

def get_random_name() -> str:
    """Generate a random name from five distinct emojis of EMOJIS_NO_FLAGS."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))

185 

186 

class ChatHandler(BaseRequestHandler):
    """The request handler for the emoji chat.

    Contains the shared logic for reading and posting messages and for
    determining the user's emoji name. Subclasses present the messages by
    overriding :meth:`render_chat`.
    """

    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    # Result of get_name(), cached on the handler instance. get_secure_cookie
    # only sees the cookies sent with the request, so without this cache a
    # client without a name cookie would get a new random name on every
    # get_name() call. post() would then save the message under one name and
    # render (and store in the cookie) a different one.
    _emoji_chat_name: str | None = None

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Show the users the current messages.

        Raises HTTPError(503) if Redis is not available. For HEAD requests
        nothing is rendered.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if head:
            return

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def get_name(self) -> str:
        """Get the name of the user.

        The base name comes from the ``emoji-chat-name`` secure cookie
        (accepted for up to 90 days), or a random name is generated. The
        cookie is (re)set when it is missing or older than 30 days. A
        location emoji is appended:

        - the GeoIP country flag if known,
        - a pirate flag for ``.onion`` hosts,
        - ``❔`` otherwise.

        The result is cached for the lifetime of this handler.
        """
        if self._emoji_chat_name is not None:
            return self._emoji_chat_name

        cookie = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )

        name = cookie.decode("UTF-8") if cookie else get_random_name()

        # save it in cookie or reset expiry date
        if not self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        ):
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            flag = "🏴‍☠"
        else:
            flag = "❔"

        self._emoji_chat_name = normalize_emojis(name + flag)
        return self._emoji_chat_name

    async def get_name_as_list(self) -> list[str]:
        """Return the name as list of emojis."""
        return [emoji["emoji"] for emoji in emoji_list(await self.get_name())]

    async def post(self) -> None:
        """Let users send messages and show the users the current messages.

        Raises HTTPError(503) if Redis is not available, and HTTPError(400)
        with the validation error as reason if the message is invalid.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        message = emojize_user_input(
            normalize_emojis(self.get_argument("message"))
        )

        if err := check_message_invalid(message):
            raise HTTPError(400, reason=err)

        await save_new_message(
            await self.get_name(),
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat. Must be overridden by subclasses."""
        raise NotImplementedError

278 

279 

class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """The HTML request handler for the emoji chat."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render ``pages/emoji_chat.html`` with the messages and user name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )

290 

291 

class APIChatHandler(ChatHandler, APIRequestHandler):
    """The API request handler for the emoji chat."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the response with the current user and the messages."""
        current_user = await self.get_name_as_list()
        response_body = {"current_user": current_user, "messages": messages}
        await self.finish(response_body)

303 

304 

# Chat WebSockets open in this process. ChatWebSocketHandler.open() appends
# to it and on_close() removes from it. The annotation is quoted because
# ChatWebSocketHandler is defined further down. An unquoted forward
# reference in a module-level annotation raises NameError at import time
# unless annotations are evaluated lazily (Python 3.14+).
OPEN_CONNECTIONS: "list[ChatWebSocketHandler]" = []

306 

307 

class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket."""

    # The user's emoji name, set in prepare().
    name: str
    # get_ms_timestamp() at the moment open() registered this connection.
    connection_time: int

    def _chat_socket_closed(self) -> bool:
        """Return True if this WebSocket can no longer be written to.

        This is the condition under which Tornado's
        ``WebSocketHandler.write_message`` raises ``WebSocketClosedError``.
        """
        return self.ws_connection is None or self.ws_connection.is_closing()

    def on_close(self) -> None:
        """Unregister this connection and send the user list to the rest.

        The connection may not be registered yet if the client disconnected
        while open() was still sending the init message. Calling
        ``list.remove`` unconditionally would then raise ``ValueError``.
        """
        LOGGER.info("WebSocket closed")
        if self in OPEN_CONNECTIONS:
            OPEN_CONNECTIONS.remove(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        Expects a JSON object like ``{"type": "message", "message": "…"}``.
        Malformed input is answered with a ``{"type": "error"}`` message
        instead of raising. Tornado would otherwise log the uncaught
        exception and abort the connection.
        """
        if not message:
            return None
        try:
            parsed: object = json.loads(message)
        except ValueError:  # orjson.JSONDecodeError subclasses ValueError
            return self.write_message(
                {"type": "error", "error": "Message needs to be valid JSON."}
            )
        if not isinstance(parsed, dict):
            return self.write_message(
                {
                    "type": "error",
                    "error": "Message needs to be a JSON object.",
                }
            )

        message_type = parsed.get("type")
        if message_type == "message":
            if "message" not in parsed:
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            if not isinstance(parsed["message"], str):
                return self.write_message(
                    {
                        "type": "error",
                        "error": "The message needs to be a string.",
                    }
                )
            return self.save_new_message(parsed["message"])

        return self.write_message(
            {"type": "error", "error": f"Unknown type {message_type}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Sends the ``init`` message with the user's name, then registers the
        connection. Every connection is asked to send the user list, which
        only happens in dev mode. Finally the stored messages are sent.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        if self._chat_socket_closed():
            # The client went away while the init message was being sent.
            # Registering now would leave a dead handler in OPEN_CONNECTIONS
            # that later broadcasts would try to write to.
            return

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Set up the handler for this request.

        Stores the current time, and rejects the request with 503 if Redis
        is unavailable. Then determines the user's name and applies rate
        limiting via ``self.ratelimit``, which is defined in a base class
        not shown here.
        """
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat (not supported for WebSockets)."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, rate limit and save a message sent over this WebSocket.

        Validation errors and rate limiting are reported back to this client
        as ``error`` and ``ratelimit`` messages.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            if not EVENT_REDIS.is_set():
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in dev mode).

        Skipped if this connection is already closing, because
        ``write_message`` would raise ``WebSocketClosedError``. That would
        interrupt the caller's loop over ``OPEN_CONNECTIONS``.
        """
        if not sys.flags.dev_mode or self._chat_socket_closed():
            return
        self.write_message(  # type: ignore[unused-awaitable]
            {
                "type": "users",
                "users": [
                    {
                        "name": conn.name,
                        "joined_at": conn.connection_time,
                    }
                    for conn in OPEN_CONNECTIONS
                ],
            }
        )