Coverage for an_website / emoji_chat / chat.py: 47.853%
163 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-15 14:36 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""A 🆒 chat."""
17import asyncio
18import logging
19import random
20import sys
21import time
22from collections.abc import Awaitable, Iterable, Mapping
23from typing import Any, Final, Literal
25import orjson as json
26from emoji import EMOJI_DATA, demojize, emoji_list, emojize, purely_emoji
27from redis.asyncio import Redis
28from tornado.web import Application, HTTPError
29from tornado.websocket import WebSocketHandler
31from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS
32from ..utils.base_request_handler import BaseRequestHandler
33from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler
34from ..utils.utils import Permission, ratelimit
35from .pub_sub_provider import PubSubProvider
LOGGER: Final = logging.getLogger(__name__)

# Every emoji known to the ``emoji`` package except those whose first code
# point lies in U+1F1E6..U+1F1FF (the Unicode regional-indicator block that
# country flags are composed of). Other flags, e.g. ones starting with 🏴,
# are not filtered out by this check.
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not 0x1F1E6 <= ord(candidate[0]) < 0x1F200
)

# How many messages are kept in the Redis message list.
MAX_MESSAGE_SAVE_COUNT: Final = 200
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Pub/Sub channel used to fan new messages out to all workers.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"
def get_ms_timestamp() -> int:
    """Return the current time in milliseconds relative to ``EPOCH_MS``.

    The wall-clock time is read in nanoseconds, truncated to whole
    milliseconds, and then shifted by the project epoch ``EPOCH_MS``.
    """
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS
55async def subscribe_to_redis_channel(
56 app: Application, worker: int | None
57) -> None:
58 """Subscribe to the Redis channel and handle incoming messages."""
59 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker)
60 del app
62 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used
63 ps = await get_pubsub()
64 try:
65 message = await ps.get_message(timeout=5.0)
66 except Exception as exc: # pylint: disable=broad-exception-caught
67 if str(exc) == "Connection closed by server.":
68 continue
69 LOGGER.exception("Failed to get message on worker %s", worker)
70 await asyncio.sleep(0)
71 continue
73 match message:
74 case None:
75 pass
76 case {
77 "type": "message",
78 "data": str() as data,
79 "channel": channel,
80 } if (
81 channel == REDIS_CHANNEL
82 ):
83 await asyncio.gather(
84 *[conn.write_message(data) for conn in OPEN_CONNECTIONS]
85 )
86 case {
87 "type": "subscribe",
88 "data": 1,
89 "channel": channel,
90 } if (
91 channel == REDIS_CHANNEL
92 ):
93 logging.info(
94 "Subscribed to Redis channel %r on worker %s",
95 channel,
96 worker,
97 )
98 case _:
99 logging.error(
100 "Got unexpected message %s on worker %s",
101 message,
102 worker,
103 )
104 await asyncio.sleep(0)
async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a chat message in Redis and publish it.

    Author and content are stored as lists of single emojis, together with a
    timestamp from ``get_ms_timestamp``. The message is appended to the
    message list, the list is trimmed to its last ``MAX_MESSAGE_SAVE_COUNT``
    entries, and the message is then published on ``REDIS_CHANNEL`` inside a
    ``{"type": "message", "message": ...}`` envelope.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    message_dict = {
        "author": [found["emoji"] for found in emoji_list(author)],
        "content": [found["emoji"] for found in emoji_list(message)],
        "timestamp": get_ms_timestamp(),
    }

    serialized_message = json.dumps(message_dict, option=ORJSON_OPTIONS)
    await redis.rpush(list_key, serialized_message)
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)

    LOGGER.info("GOT new message %s", message_dict)

    envelope = {
        "type": "message",
        "message": message_dict,
    }
    await redis.publish(
        REDIS_CHANNEL, json.dumps(envelope, option=ORJSON_OPTIONS)
    )
async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load stored chat messages from Redis.

    :param start: First list index to read; ``None`` selects
        ``-MAX_MESSAGE_SAVE_COUNT``.
    :param stop: Last list index to read (``LRANGE`` treats it as inclusive).
    :return: The JSON-decoded messages in list order.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    raw_messages = await redis.lrange(
        f"{redis_prefix}:emoji-chat:message-list", start, stop
    )
    return list(map(json.loads, raw_messages))
def check_message_invalid(message: str) -> Literal[False] | str:
    """Validate a chat message.

    :return: ``False`` if the message is acceptable, otherwise a
        human-readable reason for rejecting it.
    """
    problem: Literal[False] | str = False
    if not message:
        problem = "Empty message not allowed."
    elif not purely_emoji(message):
        problem = "Message can only contain emojis."
    elif len(emoji_list(message)) > MAX_MESSAGE_LENGTH:
        problem = f"Message longer than {MAX_MESSAGE_LENGTH} emojis."
    return problem
def emojize_user_input(string: str) -> str:
    """Replace emoji shortcodes typed by the user with emojis.

    Shortcodes are resolved in three passes: German names first, then
    English names, then aliases.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string
def normalize_emojis(string: str) -> str:
    """Normalize the emojis in ``string``.

    Emojis are turned into shortcodes with ``demojize`` and back with
    ``emojize``. Presumably this canonicalizes different spellings of the
    same emoji — confirm against the ``emoji`` package behavior.
    """
    as_shortcodes = demojize(string)
    return emojize(as_shortcodes)
def get_random_name() -> str:
    """Return a name made of five distinct random emojis from EMOJIS_NO_FLAGS."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))
class ChatHandler(BaseRequestHandler):
    """The request handler for the emoji chat.

    Subclasses decide how the chat is presented by implementing
    ``render_chat``.
    """

    # NOTE(review): these rate-limit settings are presumably consumed by
    # BaseRequestHandler; confirm there.
    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    # Cache for get_name(); Tornado creates a new handler per request, so
    # this holds the name for the current request only.
    _emoji_chat_name: str | None = None

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Show the users the current messages.

        :param head: If true (HEAD request), only the availability check
            is done and nothing is rendered.
        :raises HTTPError: 503 if Redis is not available.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if head:
            return

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def get_name(self) -> str:
        """Get the name of the user.

        The name is the random emoji name stored in the ``emoji-chat-name``
        secure cookie (or a newly generated one), followed by a flag:
        the country flag from GeoIP if available, a pirate flag for
        ``.onion`` hosts, and ``❔`` otherwise.

        The result is cached on the handler, so every call within one request
        returns the same name. Without the cache, a user without a cookie
        would get a different random name on each call, because
        ``get_secure_cookie`` only sees the cookies sent with the request.
        For example, ``post`` would save the message under one name and then
        render and set the cookie for another.
        """
        if self._emoji_chat_name is not None:
            return self._emoji_chat_name

        cookie = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )

        name = cookie.decode("UTF-8") if cookie else get_random_name()

        # save it in cookie or reset expiry date
        # (when the cookie is missing or older than 30 days)
        if not self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        ):
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            flag = "🏴☠"
        else:
            flag = "❔"

        self._emoji_chat_name = normalize_emojis(name + flag)
        return self._emoji_chat_name

    async def get_name_as_list(self) -> list[str]:
        """Return the name as list of emojis."""
        return [emoji["emoji"] for emoji in emoji_list(await self.get_name())]

    async def post(self) -> None:
        """Let users send messages and show the users the current messages.

        :raises HTTPError: 503 if Redis is not available, 400 if the message
            is invalid (see ``check_message_invalid``).
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        message = emojize_user_input(
            normalize_emojis(self.get_argument("message"))
        )

        if err := check_message_invalid(message):
            raise HTTPError(400, reason=err)

        await save_new_message(
            await self.get_name(),
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; must be implemented by subclasses."""
        raise NotImplementedError
class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Serve the emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render ``pages/emoji_chat.html`` with the messages and user name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )
class APIChatHandler(ChatHandler, APIRequestHandler):
    """Serve the emoji chat through the API."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the request with the current user's name and the messages."""
        payload = {
            "current_user": await self.get_name_as_list(),
            "messages": messages,
        }
        await self.finish(payload)
# WebSocket connections of this worker that completed ``open()``.
# The forward reference is quoted because ChatWebSocketHandler is defined
# below; module-level annotations are evaluated eagerly before Python 3.14.
OPEN_CONNECTIONS: list["ChatWebSocketHandler"] = []
class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket.

    On connect, the client gets an ``init`` message with its name and then
    a ``messages`` message with the stored messages. Clients post with
    ``{"type": "message", "message": ...}`` frames. New messages are pushed
    to every entry of ``OPEN_CONNECTIONS`` by ``subscribe_to_redis_channel``.
    """

    name: str  # the user's name, set in prepare()
    connection_time: int  # get_ms_timestamp() value, set in open()

    def on_close(self) -> None:
        """Unregister this connection and refresh the others' user lists."""
        LOGGER.info("WebSocket closed")
        # open() may not have registered this connection (e.g. if sending
        # the init message failed), so remove it only when present instead
        # of raising ValueError.
        if self in OPEN_CONNECTIONS:
            OPEN_CONNECTIONS.remove(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        The frame must be a JSON object. ``{"type": "message", "message":
        "<text>"}`` stores a new chat message. Anything else is answered
        with a ``{"type": "error"}`` message instead of raising out of the
        handler.
        """
        if not message:
            return None
        try:
            payload: Any = json.loads(message)
        except ValueError:  # orjson.JSONDecodeError subclasses ValueError
            return self.write_message(
                {"type": "error", "error": "Message is not valid JSON."}
            )
        if not isinstance(payload, dict):
            return self.write_message(
                {"type": "error", "error": "Message must be a JSON object."}
            )

        message_type = payload.get("type")
        if message_type == "message":
            text = payload.get("message")
            if not isinstance(text, str):
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            return self.save_new_message(text)

        return self.write_message(
            {"type": "error", "error": f"Unknown type {message_type}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Send the ``init`` message and register the connection in
        ``OPEN_CONNECTIONS``. Then ask every connection to resend its user
        list, and finally send this client the stored messages.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Check Redis, determine the user's name and apply rate limiting.

        :raises HTTPError: 503 if Redis is not available.
        """
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # NOTE(review): semantics of ratelimit(True) are defined in the base
        # request handler; confirm there.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; not supported over the WebSocket."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, rate-limit and store a message sent over this WebSocket.

        Validation errors and rate-limit hits are reported to the client as
        ``error`` / ``ratelimit`` messages instead of being raised.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            if not EVENT_REDIS.is_set():
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in Python dev mode)."""
        if sys.flags.dev_mode:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )