Coverage for an_website / emoji_chat / chat.py: 48.795%
166 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 19:37 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-11 19:37 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""A 🆒 chat."""
16from __future__ import annotations
18import asyncio
19import logging
20import random
21import sys
22import time
23from collections.abc import Awaitable, Iterable, Mapping
24from typing import Any, Final, Literal
26import orjson as json
27from emoji import EMOJI_DATA, demojize, emoji_list, emojize, purely_emoji
28from redis.asyncio import Redis
29from tornado.web import Application, HTTPError
30from tornado.websocket import WebSocketHandler
32from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS
33from ..utils.base_request_handler import BaseRequestHandler
34from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler
35from ..utils.utils import Permission, ratelimit
36from .pub_sub_provider import PubSubProvider
LOGGER: Final = logging.getLogger(__name__)

# Regional indicator symbols (U+1F1E6..U+1F1FF) combine into country flags;
# every emoji whose first code point is one of them is left out, so that
# randomly generated names contain no country flags.
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    candidate
    for candidate in EMOJI_DATA
    if not 0x1F1E6 <= ord(candidate[0]) <= 0x1F1FF
)

# How many messages are kept in the Redis list (older ones are trimmed).
MAX_MESSAGE_SAVE_COUNT: Final = 200
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Redis Pub/Sub channel on which new chat messages are published and received.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"
def get_ms_timestamp() -> int:
    """Return the current time in milliseconds, relative to ``EPOCH_MS``."""
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS
56async def subscribe_to_redis_channel(
57 app: Application, worker: int | None
58) -> None:
59 """Subscribe to the Redis channel and handle incoming messages."""
60 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker)
61 del app
63 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used
64 ps = await get_pubsub()
65 try:
66 message = await ps.get_message(timeout=5.0)
67 except Exception as exc: # pylint: disable=broad-exception-caught
68 if str(exc) == "Connection closed by server.":
69 continue
70 LOGGER.exception("Failed to get message on worker %s", worker)
71 await asyncio.sleep(0)
72 continue
74 match message:
75 case None:
76 pass
77 case {
78 "type": "message",
79 "data": str(data),
80 "channel": channel,
81 } if (
82 channel == REDIS_CHANNEL
83 ):
84 await asyncio.gather(
85 *[conn.write_message(data) for conn in OPEN_CONNECTIONS]
86 )
87 case {
88 "type": "subscribe",
89 "data": 1,
90 "channel": channel,
91 } if (
92 channel == REDIS_CHANNEL
93 ):
94 logging.info(
95 "Subscribed to Redis channel %r on worker %s",
96 channel,
97 worker,
98 )
99 case _:
100 logging.error(
101 "Got unexpected message %s on worker %s",
102 message,
103 worker,
104 )
105 await asyncio.sleep(0)
async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Persist a chat message and publish it on ``REDIS_CHANNEL``.

    The message is appended to the Redis message list. The list is then
    trimmed to the newest ``MAX_MESSAGE_SAVE_COUNT`` entries, and the message
    is published wrapped in a ``{"type": "message", ...}`` envelope.

    Args:
        author: The author's name; stored as a list of its emojis.
        message: The message text; stored as a list of its emojis.
        redis: The Redis client to use.
        redis_prefix: Prefix of the Redis key holding the message list.
    """
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    message_dict = {
        "author": [entry["emoji"] for entry in emoji_list(author)],
        "content": [entry["emoji"] for entry in emoji_list(message)],
        "timestamp": get_ms_timestamp(),
    }

    await redis.rpush(list_key, json.dumps(message_dict, option=ORJSON_OPTIONS))
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)
    LOGGER.info("GOT new message %s", message_dict)

    envelope = {"type": "message", "message": message_dict}
    await redis.publish(
        REDIS_CHANNEL, json.dumps(envelope, option=ORJSON_OPTIONS)
    )
async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load stored chat messages from Redis.

    Args:
        redis: The Redis client to use.
        redis_prefix: Prefix of the Redis key holding the message list.
        start: First list index to read; ``None`` means
            ``-MAX_MESSAGE_SAVE_COUNT`` (the newest saved messages).
        stop: Last list index to read (inclusive, as with ``LRANGE``).

    Returns:
        The decoded message dicts, in list order (oldest first, since new
        messages are appended with ``RPUSH``).
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    raw_messages = await redis.lrange(
        f"{redis_prefix}:emoji-chat:message-list", start, stop
    )
    return list(map(json.loads, raw_messages))
def check_message_invalid(message: str) -> Literal[False] | str:
    """Validate a chat message.

    Returns:
        A human-readable reason if the message is invalid, otherwise False.
    """
    problem: Literal[False] | str = False
    if not message:
        problem = "Empty message not allowed."
    elif not purely_emoji(message):
        problem = "Message can only contain emojis."
    elif len(emoji_list(message)) > MAX_MESSAGE_LENGTH:
        problem = f"Message longer than {MAX_MESSAGE_LENGTH} emojis."
    return problem
def emojize_user_input(string: str) -> str:
    """Replace emoji shortcodes in user input with emojis.

    German names are resolved first, then English names, then aliases.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string
def normalize_emojis(string: str) -> str:
    """Canonicalize emojis by round-tripping through ``demojize``/``emojize``."""
    demojized = demojize(string)
    return emojize(demojized)
def get_random_name() -> str:
    """Return a random name made of five distinct emojis from EMOJIS_NO_FLAGS."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))
class ChatHandler(BaseRequestHandler):
    """Shared request handling for the emoji chat.

    Subclasses choose the output format by overriding ``render_chat``.
    """

    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Render the stored messages (nothing is rendered for HEAD).

        Raises:
            HTTPError: 503 if ``EVENT_REDIS`` is not set.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if not head:
            messages = await get_messages(self.redis, self.redis_prefix)
            await self.render_chat(messages)

    async def get_name(self) -> str:
        """Return the user's emoji name with a location flag appended.

        The base name comes from the signed ``emoji-chat-name`` cookie (valid
        for 90 days) or is generated randomly. The cookie is written again
        whenever it is missing or older than 30 days, which stores a new
        name or extends the expiry of the existing one.
        """
        stored = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )
        name = get_random_name() if not stored else stored.decode("UTF-8")

        recent = self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        )
        if not recent:
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            flag = "🏴☠"
        else:
            flag = "❔"

        return normalize_emojis(name + flag)

    async def get_name_as_list(self) -> list[str]:
        """Return the user's name split into its individual emojis."""
        name = await self.get_name()
        return [item["emoji"] for item in emoji_list(name)]

    async def post(self) -> None:
        """Store a message sent by the user, then render the chat.

        Raises:
            HTTPError: 503 if ``EVENT_REDIS`` is not set; 400 if the message
                is invalid (the reason is the validation error).
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        raw_text = self.get_argument("message")
        message = emojize_user_input(normalize_emojis(raw_text))

        error = check_message_invalid(message)
        if error:
            raise HTTPError(400, reason=error)

        author = await self.get_name()
        await save_new_message(
            author,
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        messages = await get_messages(self.redis, self.redis_prefix)
        await self.render_chat(messages)

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Present the messages; must be implemented by subclasses."""
        raise NotImplementedError
class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Serve the emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat template with the messages and the user's name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )
class APIChatHandler(ChatHandler, APIRequestHandler):
    """The API request handler for the emoji chat."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the request with the user's name and the messages.

        ``messages`` is materialized into a list of plain dicts. That way any
        ``Iterable[Mapping]`` accepted by the signature (e.g. a generator)
        can be serialized into the response, not just a ``list[dict]``.
        """
        await self.finish(
            {
                "current_user": await self.get_name_as_list(),
                "messages": [dict(message) for message in messages],
            }
        )
# WebSocket connections of this process that completed ``open``; messages
# received from Redis are forwarded to every entry.
OPEN_CONNECTIONS: list[ChatWebSocketHandler] = []


class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket.

    Client frames are JSON objects like
    ``{"type": "message", "message": "<emojis>"}``. The server sends frames
    of type ``init``, ``messages``, ``users`` (dev mode only), ``error`` and
    ``ratelimit``.
    """

    name: str  # the user's emoji name, resolved in prepare()
    connection_time: int  # get_ms_timestamp() when open() registered us

    def on_close(self) -> None:
        """Deregister this connection and update the remaining clients."""
        LOGGER.info("WebSocket closed")
        # on_close also runs when open() failed or was interrupted before
        # registering this connection; list.remove would raise ValueError.
        if self in OPEN_CONNECTIONS:
            OPEN_CONNECTIONS.remove(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        Frames must be JSON objects with a ``type`` key. The only supported
        type is ``"message"``, which needs a string ``message`` key.
        Malformed frames are answered with an ``error`` frame instead of
        raising, because Tornado aborts the connection on uncaught
        exceptions in this callback.
        """
        if not message:
            return None
        try:
            payload: Any = json.loads(message)
        except json.JSONDecodeError:
            return self.write_message(
                {"type": "error", "error": "Message needs to be valid JSON."}
            )
        if not isinstance(payload, dict) or "type" not in payload:
            return self.write_message(
                {"type": "error", "error": "Message needs type key."}
            )
        if payload["type"] == "message":
            if not isinstance(payload.get("message"), str):
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            return self.save_new_message(payload["message"])

        return self.write_message(
            {"type": "error", "error": f"Unknown type {payload['type']}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Sends the ``init`` frame and registers the connection in
        ``OPEN_CONNECTIONS``. Then it updates every client's user list and
        sends this client the stored messages.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        if self.ws_connection is None or self.ws_connection.is_closing():
            # The client went away while the init frame was being written.
            # on_close may already have run, so registering now would leave
            # a dead entry in OPEN_CONNECTIONS that is never removed.
            return

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Resolve the user's name and run the rate-limit checks.

        Raises:
            HTTPError: 503 if ``EVENT_REDIS`` is not set.
        """
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # NOTE(review): the meaning of ratelimit(True) vs ratelimit() is
        # defined by the base handler (presumably a check followed by the
        # real rate-limit handling) — confirm there.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, rate-limit and store a message sent over this WebSocket.

        Validation failures and rate limiting are reported to this client as
        ``error`` / ``ratelimit`` frames instead of being raised.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            if not EVENT_REDIS.is_set():
                # The rate limit cannot be checked; refuse the message.
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        # This is the module-level save_new_message; class attributes (like
        # this method) are not in scope inside a method body.
        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in Python dev mode)."""
        if not sys.flags.dev_mode:
            return
        # pylint: disable-next=import-outside-toplevel
        from tornado.websocket import WebSocketClosedError

        try:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )
        except WebSocketClosedError:
            # This connection is already closing; skip it so callers looping
            # over OPEN_CONNECTIONS still reach the remaining clients.
            pass