Coverage for an_website/emoji_chat/chat.py: 49.405%
168 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
14"""A 🆒 chat."""
16from __future__ import annotations
18import asyncio
19import logging
20import random
21import sys
22import time
23from collections.abc import Awaitable, Iterable, Mapping
24from typing import Any, Final, Literal
26import orjson as json
27from emoji import purely_emoji # type: ignore[attr-defined]
28from emoji import EMOJI_DATA, demojize, emoji_list, emojize
29from redis.asyncio import Redis
30from tornado.web import Application, HTTPError
31from tornado.websocket import WebSocketHandler
33from .. import EPOCH_MS, EVENT_REDIS, EVENT_SHUTDOWN, NAME, ORJSON_OPTIONS
34from ..utils.base_request_handler import BaseRequestHandler
35from ..utils.request_handler import APIRequestHandler, HTMLRequestHandler
36from ..utils.utils import Permission, ratelimit
37from .pub_sub_provider import PubSubProvider
LOGGER: Final = logging.getLogger(__name__)

# The keys of the emoji package's EMOJI_DATA table.
EMOJIS: Final[tuple[str, ...]] = tuple(EMOJI_DATA)
# The same emojis, minus those whose first code point lies in
# U+1F1E6..U+1F1FF (the regional-indicator letters used for country flags).
EMOJIS_NO_FLAGS: Final[tuple[str, ...]] = tuple(
    symbol for symbol in EMOJIS if not 0x1F1E6 <= ord(symbol[0]) < 0x1F200
)

# How many messages are kept in the Redis list after trimming.
MAX_MESSAGE_SAVE_COUNT: Final = 100
# Maximum number of emojis allowed in a single message.
MAX_MESSAGE_LENGTH: Final = 20
# Redis pub/sub channel on which new messages are published.
REDIS_CHANNEL: Final = f"{NAME}:emoji_chat_channel"
def get_ms_timestamp() -> int:
    """Return the current time in milliseconds, relative to ``EPOCH_MS``."""
    now_ms = time.time_ns() // 1_000_000
    return now_ms - EPOCH_MS
56async def subscribe_to_redis_channel(
57 app: Application, worker: int | None
58) -> None:
59 """Subscribe to the Redis channel and handle incoming messages."""
60 get_pubsub = PubSubProvider((REDIS_CHANNEL,), app.settings, worker)
61 del app
63 while not EVENT_SHUTDOWN.is_set(): # pylint: disable=while-used
64 ps = await get_pubsub()
65 try:
66 message = await ps.get_message(timeout=5.0)
67 except Exception as exc: # pylint: disable=broad-exception-caught
68 if str(exc) == "Connection closed by server.":
69 continue
70 LOGGER.exception("Failed to get message on worker %s", worker)
71 await asyncio.sleep(0)
72 continue
74 match message:
75 case None:
76 pass
77 case {
78 "type": "message",
79 "data": str(data),
80 "channel": channel,
81 } if channel == REDIS_CHANNEL:
82 await asyncio.gather(
83 *[conn.write_message(data) for conn in OPEN_CONNECTIONS]
84 )
85 case {
86 "type": "subscribe",
87 "data": 1,
88 "channel": channel,
89 } if channel == REDIS_CHANNEL:
90 logging.info(
91 "Subscribed to Redis channel %r on worker %s",
92 channel,
93 worker,
94 )
95 case _:
96 logging.error(
97 "Got unexpected message %s on worker %s",
98 message,
99 worker,
100 )
101 await asyncio.sleep(0)
async def save_new_message(
    author: str,
    message: str,
    redis: Redis[str],
    redis_prefix: str,
) -> None:
    """Store a new message in Redis and publish it on the chat channel.

    The author and the content are both stored as lists of single emojis,
    together with the current timestamp from ``get_ms_timestamp``.
    """

    def split_emojis(text: str) -> list[str]:
        """Return the emojis found in *text*, in order."""
        return [match["emoji"] for match in emoji_list(text)]

    list_key = f"{redis_prefix}:emoji-chat:message-list"
    message_dict = {
        "author": split_emojis(author),
        "content": split_emojis(message),
        "timestamp": get_ms_timestamp(),
    }

    # Append the message, then keep only the newest entries.
    await redis.rpush(list_key, json.dumps(message_dict, option=ORJSON_OPTIONS))
    await redis.ltrim(list_key, -MAX_MESSAGE_SAVE_COUNT, -1)
    LOGGER.info("GOT new message %s", message_dict)

    payload = json.dumps(
        {"type": "message", "message": message_dict},
        option=ORJSON_OPTIONS,
    )
    await redis.publish(REDIS_CHANNEL, payload)
async def get_messages(
    redis: Redis[str],
    redis_prefix: str,
    start: None | int = None,
    stop: int = -1,
) -> list[dict[str, Any]]:
    """Load stored messages from the Redis list and decode them.

    ``start`` and ``stop`` are passed to ``LRANGE``; when ``start`` is
    ``None`` the range begins ``MAX_MESSAGE_SAVE_COUNT`` entries from the end.
    """
    if start is None:
        start = -MAX_MESSAGE_SAVE_COUNT
    list_key = f"{redis_prefix}:emoji-chat:message-list"
    raw_messages = await redis.lrange(list_key, start, stop)
    return [json.loads(raw) for raw in raw_messages]
def check_message_invalid(message: str) -> Literal[False] | str:
    """Validate a message.

    Returns ``False`` when the message is acceptable, otherwise a
    human-readable description of the first problem found.
    """
    problem: Literal[False] | str = False
    if not message:
        problem = "Empty message not allowed."
    elif not purely_emoji(message):
        problem = "Message can only contain emojis."
    elif len(emoji_list(message)) > MAX_MESSAGE_LENGTH:
        problem = f"Message longer than {MAX_MESSAGE_LENGTH} emojis."
    return problem
def emojize_user_input(string: str) -> str:
    """Replace emoji short codes in user input with the actual emojis.

    German names are tried first, then English names, then aliases.
    """
    for language in ("de", "en", "alias"):
        string = emojize(string, language=language)
    return string
def normalize_emojis(string: str) -> str:
    """Normalize emojis by converting them to short codes and back."""
    short_codes = demojize(string)
    return emojize(short_codes)
def get_random_name() -> str:
    """Build a random name out of five distinct non-flag emojis."""
    picked = random.sample(EMOJIS_NO_FLAGS, 5)  # nosec: B311
    return normalize_emojis("".join(picked))
class ChatHandler(BaseRequestHandler):
    """The request handler for the emoji chat.

    Subclasses decide how the chat is presented by implementing
    ``render_chat``.
    """

    RATELIMIT_GET_BUCKET = "emoji-chat-get-messages"
    RATELIMIT_GET_LIMIT = 10
    RATELIMIT_GET_COUNT_PER_PERIOD = 10
    RATELIMIT_GET_PERIOD = 1

    RATELIMIT_POST_BUCKET = "emoji-chat-send-message"
    RATELIMIT_POST_LIMIT = 5
    RATELIMIT_POST_COUNT_PER_PERIOD = 5
    RATELIMIT_POST_PERIOD = 5

    async def get(
        self,
        *,
        head: bool = False,
    ) -> None:
        """Show the users the current messages.

        Raises ``HTTPError(503)`` while ``EVENT_REDIS`` is not set.
        Nothing is rendered when ``head`` is true.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        if head:
            return

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def get_name(self) -> str:
        """Get the name of the user.

        The name is read from the ``emoji-chat-name`` secure cookie, or
        randomly generated when no valid cookie exists.  A flag emoji taken
        from the GeoIP data (or a fallback) is appended to the result.
        """
        cookie = self.get_secure_cookie(
            "emoji-chat-name",
            max_age_days=90,
            min_version=2,
        )

        name = cookie.decode("UTF-8") if cookie else get_random_name()

        # save it in cookie or reset expiry date
        # (the cookie is rewritten once it is missing or older than 30 days)
        if not self.get_secure_cookie(
            "emoji-chat-name", max_age_days=30, min_version=2
        ):
            self.set_secure_cookie(
                "emoji-chat-name",
                name.encode("UTF-8"),
                expires_days=90,
                path="/",
                samesite="Strict",
            )

        geoip = await self.geoip() or {}
        if "country_flag" in geoip:
            flag = geoip["country_flag"]
        elif self.request.host_name.endswith(".onion"):
            # Pirate flag: black flag + ZERO WIDTH JOINER + skull and
            # crossbones.  Written with escapes because the joiner is
            # invisible and easily lost when the source is copied or edited.
            flag = "\U0001F3F4\u200D\u2620"
        else:
            flag = "❔"

        return normalize_emojis(name + flag)

    async def get_name_as_list(self) -> list[str]:
        """Return the name as list of emojis."""
        return [emoji["emoji"] for emoji in emoji_list(await self.get_name())]

    async def post(self) -> None:
        """Let users send messages and show the users the current messages.

        Raises ``HTTPError(503)`` while ``EVENT_REDIS`` is not set and
        ``HTTPError(400)`` when the message fails validation.
        """
        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        message = emojize_user_input(
            normalize_emojis(self.get_argument("message"))
        )

        if err := check_message_invalid(message):
            raise HTTPError(400, reason=err)

        await save_new_message(
            await self.get_name(),
            message,
            redis=self.redis,
            redis_prefix=self.redis_prefix,
        )

        await self.render_chat(
            await get_messages(self.redis, self.redis_prefix)
        )

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; must be implemented by subclasses."""
        raise NotImplementedError
class HTMLChatHandler(ChatHandler, HTMLRequestHandler):
    """Serve the emoji chat as an HTML page."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat template with the messages and the user's name."""
        user_name = await self.get_name_as_list()
        await self.render(
            "pages/emoji_chat.html",
            messages=messages,
            user_name=user_name,
        )
class APIChatHandler(ChatHandler, APIRequestHandler):
    """Serve the emoji chat through the API."""

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Finish the request with the user's name and the messages."""
        current_user = await self.get_name_as_list()
        response = {
            "current_user": current_user,
            "messages": messages,
        }
        await self.finish(response)
# WebSocket connections that completed ``open()`` and have not closed yet.
OPEN_CONNECTIONS: list[ChatWebSocketHandler] = []


class ChatWebSocketHandler(WebSocketHandler, ChatHandler):
    """The handler for the chat WebSocket."""

    # Name of the user; set in prepare().
    name: str
    # Timestamp (see get_ms_timestamp) at which open() registered the socket.
    connection_time: int

    def on_close(self) -> None:
        """Unregister this connection and update the remaining ones."""
        LOGGER.info("WebSocket closed")
        # open() awaits before registering the connection, so the socket
        # can close while it is not (yet) in the list.
        if self in OPEN_CONNECTIONS:
            OPEN_CONNECTIONS.remove(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

    def on_message(self, message: str | bytes) -> Awaitable[None] | None:
        """Respond to an incoming message.

        Empty frames are ignored.  Malformed frames are answered with an
        ``error`` frame instead of raising.
        """
        if not message:
            return None

        try:
            message2: Any = json.loads(message)
        except ValueError:  # orjson's JSONDecodeError is a ValueError
            return self.write_message(
                {"type": "error", "error": "Message is not valid JSON."}
            )

        if not isinstance(message2, dict) or "type" not in message2:
            return self.write_message(
                {
                    "type": "error",
                    "error": "Message needs to be an object with a type key.",
                }
            )

        if message2["type"] == "message":
            if "message" not in message2:
                return self.write_message(
                    {
                        "type": "error",
                        "error": "Message needs message key with the message.",
                    }
                )
            if not isinstance(message2["message"], str):
                return self.write_message(
                    {"type": "error", "error": "Message needs to be a string."}
                )
            return self.save_new_message(message2["message"])

        return self.write_message(
            {"type": "error", "error": f"Unknown type {message2['type']}."}
        )

    async def open(self, *args: str, **kwargs: str) -> None:
        # pylint: disable=invalid-overridden-method
        """Handle an opened connection.

        Sends the ``init`` frame, registers the connection, notifies all
        connections about the user list and sends the stored messages.
        """
        LOGGER.info("WebSocket opened")
        await self.write_message(
            {
                "type": "init",
                "current_user": [
                    emoji["emoji"] for emoji in emoji_list(self.name)
                ],
            }
        )

        self.connection_time = get_ms_timestamp()
        OPEN_CONNECTIONS.append(self)
        for conn in OPEN_CONNECTIONS:
            conn.send_users()

        await self.send_messages()

    async def prepare(self) -> None:
        """Check Redis availability, resolve the name and apply ratelimits.

        Raises ``HTTPError(503)`` while ``EVENT_REDIS`` is not set.
        """
        self.now = await self.get_time()

        if not EVENT_REDIS.is_set():
            raise HTTPError(503)

        self.name = await self.get_name()

        # NOTE(review): presumably the first call is a global check and the
        # second the per-method one -- confirm against the base handler.
        if not await self.ratelimit(True):
            await self.ratelimit()

    async def render_chat(self, messages: Iterable[Mapping[str, Any]]) -> None:
        """Render the chat; not supported on the WebSocket."""
        raise NotImplementedError

    async def save_new_message(self, msg_text: str) -> None:
        """Validate, ratelimit and save a new message.

        Validation failures and ratelimit hits are reported to the client
        as ``error`` / ``ratelimit`` frames.
        """
        msg_text = emojize_user_input(normalize_emojis(msg_text).strip())
        if err := check_message_invalid(msg_text):
            return await self.write_message({"type": "error", "error": err})

        if self.settings.get("RATELIMITS") and not self.is_authorized(
            Permission.RATELIMITS
        ):
            # Without Redis the ratelimit cannot be checked.
            if not EVENT_REDIS.is_set():
                return await self.write_message({"type": "ratelimit"})

            ratelimited, headers = await ratelimit(
                self.redis,
                self.redis_prefix,
                str(self.request.remote_ip),
                bucket=self.RATELIMIT_POST_BUCKET,
                max_burst=self.RATELIMIT_POST_LIMIT - 1,
                count_per_period=self.RATELIMIT_POST_COUNT_PER_PERIOD,
                period=self.RATELIMIT_POST_PERIOD,
                tokens=1,
            )

            if ratelimited:
                return await self.write_message(
                    {"type": "ratelimit", "retry_after": headers["Retry-After"]}
                )

        return await save_new_message(
            self.name, msg_text, self.redis, self.redis_prefix
        )

    async def send_messages(self) -> None:
        """Send this WebSocket all current messages."""
        return await self.write_message(
            {
                "type": "messages",
                "messages": await get_messages(self.redis, self.redis_prefix),
            },
        )

    def send_users(self) -> None:
        """Send this WebSocket all current users (only in dev mode)."""
        if sys.flags.dev_mode:
            self.write_message(  # type: ignore[unused-awaitable]
                {
                    "type": "users",
                    "users": [
                        {
                            "name": conn.name,
                            "joined_at": conn.connection_time,
                        }
                        for conn in OPEN_CONNECTIONS
                    ],
                }
            )