Coverage for an_website/emoji_chat/chat.py: 49.102%
167 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 13:44 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-07 13:44 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""A 🆒 chat."""
from __future__ import annotations

import asyncio
import logging
import random
import sys
import time
from collections.abc import Awaitable, Iterable, Mapping
from typing import Any, Final, Literal

import orjson as json
from emoji import purely_emoji  # type: ignore[attr-defined]
from emoji import EMOJI_DATA, demojize, emoji_list, emojize
from redis.asyncio import Redis
from tornado.web import Application, HTTPError
from tornado.websocket import WebSocketClosedError, WebSocketHandler

from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS
from ..utils.base_request_handler import BaseRequestHandler
from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler
from ..utils.utils import Permission, ratelimit
from .pub_sub_provider import PubSubProvider
LOGGER: Final = logging.getLogger(__name__)

# Every emoji known to the ``emoji`` package except those whose first code
# point is a regional indicator symbol (U+1F1E6..U+1F1FF), i.e. country flags.
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not 0x1F1E6 <= ord(candidate[0]) < 0x1F200
)

# How many messages are kept in the Redis message list.
MAX_MESSAGE_SAVE_COUNT: Final = 100
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Pub/sub channel used to broadcast new messages.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"
def get_ms_timestamp() -> int:
    """Return the current time in milliseconds, relative to ``EPOCH_MS``."""
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS
57async def subscribe_to_redis_channel(
58 app: Application, worker: int | None
59) -> None:
60 """Subscribe to the Redis channel and handle incoming messages."""
61 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker)
62 del app
64 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used
65 ps = await get_pubsub()
66 try:
67 message = await ps.get_message(timeout=5.0)
68 except Exception as exc: # pylint: disable=broad-exception-caught
69 if str(exc) == "Connection closed by server.":
70 continue
71 LOGGER.exception("Failed to get message on worker %s", worker)
72 await asyncio.sleep(0)
73 continue
75 match message:
76 case None:
77 pass
78 case {
79 "type": "message",
80 "data": str(data),
81 "channel": channel,
82 } if channel == REDIS_CHANNEL:
83 await asyncio.gather(
84 *[conn.write_message(data) for conn in OPEN_CONNECTIONS]
85 )
86 case {
87 "type": "subscribe",
88 "data": 1,
89 "channel": channel,
90 } if channel == REDIS_CHANNEL:
91 logging.info(
92 "Subscribed to Redis channel %r on worker %s",
93 channel,
94 worker,
95 )
96 case _:
97 logging.error(
98 "Got unexpected message %s on worker %s",
99 message,
100 worker,
101 )
102 await asyncio.sleep(0)
async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a chat message in Redis and publish it on ``REDIS_CHANNEL``.

    :param author: The author's name; only its emojis are stored.
    :param message: The message text; only its emojis are stored.
    :param redis: The Redis client.
    :param redis_prefix: Prefix for the Redis key of the message list.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    message_dict = {
        "author": [entry["emoji"] for entry in emoji_list(author)],
        "content": [entry["emoji"] for entry in emoji_list(message)],
        "timestamp": get_ms_timestamp(),
    }
    serialized = json.dumps(message_dict, option=ORJSON_OPTIONS)
    await redis.rpush(list_key, serialized)
    # Keep only the newest MAX_MESSAGE_SAVE_COUNT entries of the list.
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)
    LOGGER.info("GOT new message %s", message_dict)
    payload = {"type": "message", "message": message_dict}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(payload, option=ORJSON_OPTIONS)
    )
async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load the stored messages from Redis and decode them from JSON.

    :param start: First list index; defaults to ``-MAX_MESSAGE_SAVE_COUNT``.
    :param stop: Last list index (inclusive, as for Redis ``LRANGE``).
    :return: The decoded messages, oldest first.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    raw_messages = await redis.lrange(
        f"{redis_prefix}:emoji-chat:message-list", start, stop
    )
    return list(map(json.loads, raw_messages))
def check_message_invalid(message: str) -> Literal[False] | str:
    """Validate a chat message.

    :return: ``False`` if the message is valid, otherwise an error text.
    """
    error: str | None = None
    if not message:
        error = "Empty message not allowed."
    elif not purely_emoji(message):
        error = "Message can only contain emojis."
    elif len(emoji_list(message)) > MAX_MESSAGE_LENGTH:
        error = f"Message longer than {MAX_MESSAGE_LENGTH} emojis."
    return False if error is None else error
def emojize_user_input(string: str) -> str:
    """Replace ``:name:`` codes in user input with emojis.

    German names are tried first, then English names, then aliases.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string
def normalize_emojis(string: str) -> str:
    """Normalize the emojis in a string.

    Emojis are converted to their ``:name:`` text form and back, so the
    result uses the representation that ``emojize`` emits.
    """
    textual = demojize(string)
    return emojize(textual)
def get_random_name() -> str:
    """Return a random name made of five distinct non-flag emojis."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))
class ChatHandler(BaseRequestHandler):
    """The request handler for the emoji chat.

    Subclasses provide ``render_chat`` to present the messages.
    """

    # Rate limit settings for reading messages (presumably consumed by the
    # base handler's rate limiting -- confirm in BaseRequestHandler).
    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    # Rate limit settings for sending messages.
    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Render the current messages (nothing is rendered for HEAD).

        :raises HTTPError: 503 if Redis is not available.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if not head:
            messages = await get_messages(self.redis, self.redis_prefix)
            await self.render_chat(messages)

    async def get_name(self) -> str:
        """Return the user's emoji name followed by a location flag.

        The name is read from a signed cookie (or randomly generated), and
        the cookie is (re)written when missing or older than 30 days.
        """
        stored = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )
        name = get_random_name() if not stored else stored.decode("UTF-8")

        # Store the name, or renew the cookie's expiry date.
        recent = self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        )
        if not recent:
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        else:
            flag = "🏴☠" if self.request.host_name.endswith(".onion") else "❔"

        return normalize_emojis(name + flag)

    async def get_name_as_list(self) -> list[str]:
        """Return the user's name split into single emojis."""
        full_name = await self.get_name()
        return [entry["emoji"] for entry in emoji_list(full_name)]

    async def post(self) -> None:
        """Save a message sent by the user, then render the messages.

        :raises HTTPError: 503 if Redis is not available, 400 if the message
            is invalid.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        raw_message = self.get_argument("message")
        message = emojize_user_input(normalize_emojis(raw_message))

        error = check_message_invalid(message)
        if error:
            raise HTTPError(400, reason=error)

        author = await self.get_name()
        await save_new_message(
            author,
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        messages = await get_messages(self.redis, self.redis_prefix)
        await self.render_chat(messages)

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; must be implemented by subclasses."""
        raise NotImplementedError
class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Serve the emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render ``pages/emoji_chat.html`` with the messages and user name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )
class APIChatHandler(ChatHandler, APIRequestHandler):
    """Serve the emoji chat through the API."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the response with the current user and the messages."""
        current_user = await self.get_name_as_list()
        response = {"current_user": current_user, "messages": messages}
        await self.finish(response)
# WebSocket chat connections currently open in this process. Entries are
# added in ChatWebSocketHandler.open and removed in on_close; new Redis
# messages are forwarded to every entry.
OPEN_CONNECTIONS: list[ChatWebSocketHandler] = list()
class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket."""

    name: str  # the user's emoji name, set in prepare()
    connection_time: int  # get_ms_timestamp() value taken in open()

    def on_close(self) -> None:
        """Unregister this connection and update the remaining ones."""
        LOGGER.info("WebSocket closed")
        # open() only registers the connection after its first write
        # succeeded, so it may legitimately be missing here.
        if self in OPEN_CONNECTIONS:
            OPEN_CONNECTIONS.remove(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        Expects a JSON object ``{"type": "message", "message": "<text>"}``.
        Malformed input is answered with an ``error`` message instead of
        raising inside the WebSocket callback.
        """
        if not message:
            return None
        try:
            message2: Any = json.loads(message)
        except json.JSONDecodeError:
            return self.write_message(
                {"type": "error", "error": "Message needs to be valid JSON."}
            )
        if not isinstance(message2, dict):
            return self.write_message(
                {"type": "error", "error": "Message needs to be an object."}
            )

        msg_type = message2.get("type")
        if msg_type == "message":
            if "message" not in message2:
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            if not isinstance(message2["message"], str):
                return self.write_message(
                    {"type": "error", "error": "Message needs to be a string."}
                )
            return self.save_new_message(message2["message"])

        return self.write_message(
            {"type": "error", "error": f"Unknown type {msg_type}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection."""
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Check Redis, resolve the user's name and apply rate limiting.

        :raises HTTPError: 503 if Redis is not available.
        """
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # NOTE(review): semantics of ratelimit(True) are defined in the
        # base handler; this mirrors the original call sequence exactly.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat (not used for WebSockets)."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, rate-limit and save a message sent over the WebSocket.

        Errors and rate limiting are reported to the client as ``error`` /
        ``ratelimit`` messages.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            if not EVENT_REDIS.is_set():
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in dev mode)."""
        if not sys.flags.dev_mode:
            return
        try:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )
        except WebSocketClosedError:
            # Callers loop over all connections; one closing peer must not
            # abort notifying the others.
            LOGGER.debug("Skipped sending users to closed WebSocket")