Coverage for an_website/emoji_chat/chat.py: 47.853%
163 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-10 18:56 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-10 18:56 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""A 🆒 chat."""
16import asyncio
17import logging
18import random
19import sys
20import time
21from collections.abc import Awaitable, Iterable, Mapping
22from typing import Any, Final, Literal
24import orjson as json
25from emoji import EMOJI_DATA, demojize, emoji_list, emojize, purely_emoji
26from redis.asyncio import Redis
27from tornado.web import Application, HTTPError
28from tornado.websocket import WebSocketHandler
30from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS
31from ..utils.base_request_handler import BaseRequestHandler
32from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler
33from ..utils.utils import Permission, ratelimit
34from .pub_sub_provider import PubSubProvider
# Logger shared by everything in this module.
LOGGER: Final = logging.getLogger(__name__)

# Every entry of EMOJI_DATA whose first code point lies outside
# U+1F1E6..U+1F1FF (the regional indicator symbols flags are built from).
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not (0x1F1E6 <= ord(candidate[0]) < 0x1F200)
)

# The stored message list is trimmed to this many entries.
MAX_MESSAGE_SAVE_COUNT: Final = 200
# Maximum number of emojis a single message may contain.
MAX_MESSAGE_LENGTH: Final = 20
# Redis pub/sub channel on which new chat messages are published.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"
def get_ms_timestamp() -> int:
    """Return the current time in ms, counted from ``EPOCH_MS``."""
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS
54async def subscribe_to_redis_channel(
55 app: Application, worker: int | None
56) -> None:
57 """Subscribe to the Redis channel and handle incoming messages."""
58 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker)
59 del app
61 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used
62 ps = await get_pubsub()
63 try:
64 message = await ps.get_message(timeout=5.0)
65 except Exception as exc: # pylint: disable=broad-exception-caught
66 if str(exc) == "Connection closed by server.":
67 continue
68 LOGGER.exception("Failed to get message on worker %s", worker)
69 await asyncio.sleep(0)
70 continue
72 match message:
73 case None:
74 pass
75 case {
76 "type": "message",
77 "data": str() as data,
78 "channel": channel,
79 } if (
80 channel == REDIS_CHANNEL
81 ):
82 await asyncio.gather(
83 *[conn.write_message(data) for conn in OPEN_CONNECTIONS]
84 )
85 case {
86 "type": "subscribe",
87 "data": 1,
88 "channel": channel,
89 } if (
90 channel == REDIS_CHANNEL
91 ):
92 logging.info(
93 "Subscribed to Redis channel %r on worker %s",
94 channel,
95 worker,
96 )
97 case _:
98 logging.error(
99 "Got unexpected message %s on worker %s",
100 message,
101 worker,
102 )
103 await asyncio.sleep(0)
async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a new message and announce it on the Redis channel.

    The message is appended to the message list, the list is trimmed to
    the newest ``MAX_MESSAGE_SAVE_COUNT`` entries and the message is then
    published on ``REDIS_CHANNEL``.
    """
    message_dict = {
        "author": [entry["emoji"] for entry in emoji_list(author)],
        "content": [entry["emoji"] for entry in emoji_list(message)],
        "timestamp": get_ms_timestamp(),
    }
    list_key = f"{redis_prefix}:emoji-chat:message-list"

    serialized = json.dumps(message_dict, option=ORJSON_OPTIONS)
    await redis.rpush(list_key, serialized)
    # Keep only the newest entries.
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)

    LOGGER.info("GOT new message %s", message_dict)

    event = {"type": "message", "message": message_dict}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(event, option=ORJSON_OPTIONS)
    )
async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load the stored messages in the index range ``start``..``stop``.

    Without ``start`` the range begins ``MAX_MESSAGE_SAVE_COUNT`` entries
    before the end of the list.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    raw_messages = await redis.lrange(list_key, start, stop)
    return [json.loads(raw) for raw in raw_messages]
152def check_message_invalid(message: str) -> Literal[False] | str:
153 """Check if a message is an invalid message."""
154 if not message:
155 return "Empty message not allowed."
157 if not purely_emoji(message):
158 return "Message can only contain emojis."
160 if len(emoji_list(message)) > MAX_MESSAGE_LENGTH:
161 return f"Message longer than {MAX_MESSAGE_LENGTH} emojis."
163 return False
def emojize_user_input(string: str) -> str:
    """Replace emoji names in user input, trying "de", "en" and "alias"."""
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string
def normalize_emojis(string: str) -> str:
    """Normalize emojis by round-tripping them through ``demojize``."""
    demojized = demojize(string)
    return emojize(demojized)
def get_random_name() -> str:
    """Build a random name out of five distinct non-flag emojis."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))
class ChatHandler(BaseRequestHandler):
    """Common logic of the emoji chat handlers.

    Subclasses decide how the chat is presented by implementing
    ``render_chat``.
    """

    # Rate limit settings for GET requests.
    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    # Rate limit settings for sending messages.
    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Show the current messages; with ``head`` nothing is rendered."""
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if not head:
            messages = await get_messages(self.redis, self.redis_prefix)
            await self.render_chat(messages)

    async def get_name(self) -> str:
        """Return the user's name followed by a flag emoji.

        The name comes from the "emoji-chat-name" cookie or is freshly
        generated if there is no valid cookie.
        """
        stored = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )
        if stored:
            name = stored.decode("UTF-8")
        else:
            name = get_random_name()

        # Store the name, or renew the cookie if the 30 day lookup fails.
        recent = self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        )
        if not recent:
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        flag = "❔"
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            flag = "🏴☠"

        return normalize_emojis(name + flag)

    async def get_name_as_list(self) -> list[str]:
        """Return the user's name split into its individual emojis."""
        name = await self.get_name()
        return [entry["emoji"] for entry in emoji_list(name)]

    async def post(self) -> None:
        """Save the sent message and show the current messages."""
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        raw_message = self.get_argument("message")
        message = emojize_user_input(normalize_emojis(raw_message))

        err = check_message_invalid(message)
        if err:
            raise HTTPError(400, reason=err)

        author = await self.get_name()
        await save_new_message(
            author,
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        messages = await get_messages(self.redis, self.redis_prefix)
        await self.render_chat(messages)

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; has to be implemented by subclasses."""
        raise NotImplementedError
class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """The emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat template with the messages and the user's name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )
class APIChatHandler(ChatHandler, APIRequestHandler):
    """The emoji chat as an API endpoint."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the request with the user's name and the messages."""
        response = {
            "current_user": await self.get_name_as_list(),
            "messages": messages,
        }
        await self.finish(response)
# The chat WebSockets that are currently open: ChatWebSocketHandler.open
# appends the connection, ChatWebSocketHandler.on_close removes it again.
# NOTE(review): the annotation names a class that is defined further down;
# this presumably relies on lazily evaluated annotations — confirm against
# the minimum supported Python version.
OPEN_CONNECTIONS: list[ChatWebSocketHandler] = []
class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket.

    Clients send JSON objects of the form
    ``{"type": "message", "message": "<emojis>"}``. The handler sends
    frames with the types "init", "messages", "users", "error" and
    "ratelimit".
    """

    # Name of the user, set in prepare().
    name: str
    # Value of get_ms_timestamp() at the time the connection was opened.
    connection_time: int

    def on_close(self) -> None:
        """Unregister the connection and update the other connections."""
        LOGGER.info("WebSocket closed")
        try:
            OPEN_CONNECTIONS.remove(self)
        except ValueError:
            # The connection was closed before open() registered it,
            # so no other connection knows about it.
            return
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        Empty frames are ignored. Malformed input is answered with an
        ``{"type": "error", ...}`` frame instead of raising.
        """
        if not message:
            return None

        try:
            payload = json.loads(message)
        except ValueError:  # orjson.JSONDecodeError is a ValueError
            return self.write_message(
                {"type": "error", "error": "Message is not valid JSON."}
            )

        if not isinstance(payload, dict) or "type" not in payload:
            return self.write_message(
                {
                    "type": "error",
                    "error": "Message needs to be an object with a type key.",
                }
            )

        if payload["type"] == "message":
            if "message" not in payload:
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            if not isinstance(payload["message"], str):
                return self.write_message(
                    {"type": "error", "error": "Message needs to be a string."}
                )
            return self.save_new_message(payload["message"])

        return self.write_message(
            {"type": "error", "error": f"Unknown type {payload['type']}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Sends the "init" frame, registers the connection, updates the
        user list of all connections and sends the stored messages.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Prepare the connection: require Redis and resolve the name."""
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # NOTE(review): presumably the first call only checks the limit
        # and the second one applies it — confirm against
        # BaseRequestHandler.ratelimit.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; not supported by the WebSocket handler."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, rate limit and save a new message.

        An invalid message is answered with an "error" frame, a rate
        limited one with a "ratelimit" frame.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            # Without Redis the rate limit cannot be checked.
            if not EVENT_REDIS.is_set():
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in dev mode)."""
        if sys.flags.dev_mode:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )