Coverage for an_website/main.py: 81.224%
245 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-11-16 19:56 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
13# pylint: disable=import-private-name, too-many-lines
15"""
16The website of the AN.
18Loads config and modules and starts Tornado.
19"""
21from __future__ import annotations
23import asyncio
24import atexit
25import importlib
26import logging
27import os
28import platform
29import signal
30import ssl
31import sys
32import threading
33import time
34import types
35import uuid
36from asyncio import AbstractEventLoop
37from asyncio.runners import _cancel_all_tasks # type: ignore[attr-defined]
38from base64 import b64encode
39from collections.abc import Callable, Iterable, Mapping, MutableSequence
40from configparser import ConfigParser
41from functools import partial
42from hashlib import sha256
43from multiprocessing.process import _children # type: ignore[attr-defined]
44from pathlib import Path
45from socket import socket
46from typing import Any, Final, TypedDict, TypeGuard, cast
47from warnings import catch_warnings, simplefilter
48from zoneinfo import ZoneInfo
50import regex
51from Crypto.Hash import RIPEMD160
52from ecs_logging import StdlibFormatter
53from elastic_enterprise_search import AppSearch # type: ignore[import-untyped]
54from elasticapm.contrib.tornado import ElasticAPM
55from redis.asyncio import (
56 BlockingConnectionPool,
57 Redis,
58 SSLConnection,
59 UnixDomainSocketConnection,
60)
61from setproctitle import setproctitle
62from tornado.httpserver import HTTPServer
63from tornado.log import LogFormatter
64from tornado.netutil import bind_sockets, bind_unix_socket
65from tornado.process import fork_processes, task_id
66from tornado.web import Application, RedirectHandler
67from typed_stream import Stream
69from . import (
70 CA_BUNDLE_PATH,
71 DIR,
72 EVENT_SHUTDOWN,
73 NAME,
74 TEMPLATES_DIR,
75 UPTIME,
76 VERSION,
77)
78from .contact.contact import apply_contact_stuff_to_app
79from .utils import background_tasks, static_file_handling
80from .utils.base_request_handler import BaseRequestHandler, request_ctx_var
81from .utils.better_config_parser import BetterConfigParser
82from .utils.elasticsearch_setup import setup_elasticsearch
83from .utils.logging import WebhookFormatter, WebhookHandler
84from .utils.request_handler import NotFoundHandler
85from .utils.static_file_from_traversable import TraversableStaticFileHandler
86from .utils.template_loader import TemplateLoader
87from .utils.utils import (
88 ArgparseNamespace,
89 Handler,
90 ModuleInfo,
91 Permission,
92 Timer,
93 create_argument_parser,
94 geoip,
95 get_arguments_without_help,
96 time_function,
97)
99try:
100 import perf8 # type: ignore[import, unused-ignore]
101except ModuleNotFoundError:
102 perf8 = None # pylint: disable=invalid-name
# modules (and, via "<package>.*" entries, whole packages) skipped when
# searching for module infos; "example" is only loaded in dev mode
# NOTE: ignore_modules() extends this set from the config at startup
IGNORED_MODULES: Final[set[str]] = {
    "patches",
    "static",
    "templates",
} | (set() if sys.flags.dev_mode else {"example"})

# logger for this module
LOGGER: Final = logging.getLogger(__name__)
# add all the information from the packages to a list
# this calls the get_module_info function in every file
# files and dirs starting with '_' get ignored
def get_module_infos() -> str | tuple[ModuleInfo, ...]:
    """Import the modules and return the loaded module infos in a tuple.

    In dev mode the collected error messages get returned as a string
    instead, so the caller can exit with them.
    """
    module_infos: list[ModuleInfo] = []
    loaded_modules: list[str] = []
    errors: list[str] = []  # modified by get_module_infos_from_module

    for potential_module in DIR.iterdir():
        if (
            potential_module.name.startswith("_")
            or potential_module.name in IGNORED_MODULES
            or not potential_module.is_dir()
        ):
            continue

        _module_infos = get_module_infos_from_module(
            potential_module.name, errors, ignore_not_found=True
        )
        if _module_infos:
            module_infos.extend(_module_infos)
            loaded_modules.append(potential_module.name)
            LOGGER.debug(
                (
                    "Found module_infos in %s.__init__.py, "
                    "not searching in other modules in the package."
                ),
                potential_module,
            )
            continue

        # BUG FIX: compare the module *name* (not the full path) so this
        # matches the "<package>.*" entries that the error message in
        # get_module_infos_from_module tells users to add
        if f"{potential_module.name}.*" in IGNORED_MODULES:
            continue

        for potential_file in potential_module.iterdir():
            module_name = f"{potential_module.name}.{potential_file.name[:-3]}"
            if (
                not potential_file.name.endswith(".py")
                or module_name in IGNORED_MODULES
                or potential_file.name.startswith("_")
            ):
                continue
            _module_infos = get_module_infos_from_module(module_name, errors)
            if _module_infos:
                module_infos.extend(_module_infos)
                loaded_modules.append(module_name)

    if len(errors) > 0:
        if sys.flags.dev_mode:
            # exit to make sure it gets fixed
            return "\n".join(errors)
        # don't exit in production to keep stuff running
        LOGGER.error("\n".join(errors))

    LOGGER.info(
        "Loaded %d modules: '%s'",
        len(loaded_modules),
        "', '".join(loaded_modules),
    )

    LOGGER.info(
        "Ignored %d modules: '%s'",
        len(IGNORED_MODULES),
        "', '".join(IGNORED_MODULES),
    )

    sort_module_infos(module_infos)

    # make module_infos immutable so it never changes
    return tuple(module_infos)
def get_module_infos_from_module(
    module_name: str,
    errors: MutableSequence[str],  # gets modified
    ignore_not_found: bool = False,
) -> None | list[ModuleInfo]:
    """Get the module infos based on a module."""
    timer = Timer()
    module = importlib.import_module(
        f".{module_name}",
        package="an_website",
    )
    # slow imports delay every start of the website, so complain loudly
    if timer.stop() > 0.1:
        LOGGER.warning(
            "Import of %s took %ss. That's affecting the startup time.",
            module_name,
            timer.get(),
        )

    collected: list[ModuleInfo] = []

    get_info = getattr(module, "get_module_info", None)
    get_infos = getattr(module, "get_module_infos", None)

    if get_info is None and get_infos is None:
        if ignore_not_found:
            return None
        errors.append(
            f"{module_name} has no 'get_module_info' and no 'get_module_infos' "
            "method. Please add at least one of the methods or add "
            f"'{module_name.rsplit('.', 1)[0]}.*' or {module_name!r} to "
            "IGNORED_MODULES."
        )
        return None

    if get_info is not None:
        single_info = get_info()
        if isinstance(single_info, ModuleInfo):
            collected.append(single_info)
        else:
            errors.append(
                f"'get_module_info' in {module_name} does not return ModuleInfo. "
                "Please fix the returned value."
            )

    if get_infos is None:
        return collected or None

    multiple_infos = get_infos()

    if not isinstance(multiple_infos, Iterable):
        errors.append(
            f"'get_module_infos' in {module_name} does not return an Iterable. "
            "Please fix the returned value."
        )
        return collected or None

    for candidate in multiple_infos:
        if isinstance(candidate, ModuleInfo):
            collected.append(candidate)
        else:
            errors.append(
                f"'get_module_infos' in {module_name} did return an Iterable "
                f"with an element of type {type(candidate)}. "
                "Please fix the returned value."
            )

    return collected or None
def sort_module_infos(module_infos: list[ModuleInfo]) -> None:
    """Sort a list of module info and move the main page to the top."""
    # sort in place so the order makes sense
    module_infos.sort()

    # then put the main page ("/") first, keeping the rest sorted
    for index, module_info in enumerate(module_infos):
        if module_info.path == "/":
            module_infos.insert(0, module_infos.pop(index))
            break
def get_all_handlers(module_infos: Iterable[ModuleInfo]) -> list[Handler]:
    """
    Parse the module information and return the handlers in a tuple.

    If a handler has only 2 elements a dict with title and description
    gets added. This information is gotten from the module info.
    """
    handlers: list[Handler] = static_file_handling.get_handlers()

    # add all the normal handlers
    for module_info in module_infos:
        for spec in module_info.handlers:
            entry: list[Any] = list(spec)
            # our own request handlers (as opposed to built-ins like
            # StaticFileHandler & RedirectHandler) get module_info attached
            if issubclass(entry[1], BaseRequestHandler):
                if len(entry) == 2:
                    # setting "default_title" & "default_description" to
                    # False makes module_info.name & .description get used
                    entry.append(
                        {
                            "default_title": False,
                            "default_description": False,
                            "module_info": module_info,
                        }
                    )
                else:
                    entry[2]["module_info"] = module_info
            handlers.append(tuple(entry))

    # redirect handler, to make finding APIs easier
    handlers.append((r"/(.+)/api/*", RedirectHandler, {"url": "/api/{0}"}))

    handlers.append(
        (
            r"(?i)/\.well-known/(.*)",
            TraversableStaticFileHandler,
            {"root": Path(".well-known"), "hashes": {}},
        )
    )

    LOGGER.debug("Loaded %d handlers", len(handlers))

    return handlers
def ignore_modules(config: BetterConfigParser) -> None:
    """Read ignored modules from the config."""
    configured = config.getset(
        "GENERAL", "IGNORED_MODULES", fallback={"element_web_link"}
    )
    IGNORED_MODULES.update(configured)
def get_normed_paths_from_module_infos(
    module_infos: Iterable[ModuleInfo],
) -> dict[str, str]:
    """Get all paths from the module infos."""
    normed_paths: dict[str, str] = {}

    for info in module_infos:
        # candidate (path, target) pairs: the page itself, its aliases
        # and the paths of its sub-pages
        pairs: list[tuple[str | None, str | None]] = [(info.path, info.path)]
        if isinstance(info.aliases, Mapping):
            pairs.extend(info.aliases.items())
        else:
            pairs.extend((alias, info.path) for alias in info.aliases)
        pairs.extend(
            (sub_info.path, sub_info.path)
            for sub_info in info.sub_pages
            if sub_info.path
        )

        for path, target in pairs:
            if path is None or target is None:
                continue
            if not path.startswith("/"):
                continue
            normed = path.strip("/").lower()
            if normed:
                normed_paths[normed] = target

    return normed_paths
def make_app(config: ConfigParser) -> str | Application:
    """Create the Tornado application and return it."""
    module_infos, duration = time_function(get_module_infos)
    if isinstance(module_infos, str):
        # a string means loading the modules failed; pass it on
        return module_infos
    if duration > 1:
        LOGGER.warning(
            "Getting the module infos took %ss. That's probably too long.",
            duration,
        )
    handlers = get_all_handlers(module_infos)
    # show the hamburger menu iff there is a visible module with a path
    show_hamburger_menu = any(
        not info.hidden and info.path for info in module_infos
    )
    return Application(
        handlers,  # type: ignore[arg-type]
        MODULE_INFOS=module_infos,
        SHOW_HAMBURGER_MENU=show_hamburger_menu,
        NORMED_PATHS=get_normed_paths_from_module_infos(module_infos),
        HANDLERS=handlers,
        # General settings
        autoreload=False,
        debug=sys.flags.dev_mode,
        default_handler_class=NotFoundHandler,
        compress_response=config.getboolean(
            "GENERAL", "COMPRESS_RESPONSE", fallback=False
        ),
        websocket_ping_interval=10,
        # Template settings
        template_loader=TemplateLoader(
            root=TEMPLATES_DIR, whitespace="oneline"
        ),
    )
def apply_config_to_app(app: Application, config: BetterConfigParser) -> None:
    """Apply the config (from the config.ini file) to the application."""
    app.settings["CONFIG"] = config

    # NOTE(review): "xyzzy" is an insecure default; production deployments
    # should always set GENERAL/COOKIE_SECRET explicitly
    app.settings["cookie_secret"] = config.get(
        "GENERAL", "COOKIE_SECRET", fallback="xyzzy"
    )

    app.settings["CRAWLER_SECRET"] = config.get(
        "APP_SEARCH", "CRAWLER_SECRET", fallback=None
    )

    app.settings["DOMAIN"] = config.get("GENERAL", "DOMAIN", fallback=None)

    app.settings["ELASTICSEARCH_PREFIX"] = config.get(
        "ELASTICSEARCH", "PREFIX", fallback=NAME
    )

    app.settings["HSTS"] = config.getboolean("TLS", "HSTS", fallback=False)

    app.settings["NETCUP"] = config.getboolean(
        "GENERAL", "NETCUP", fallback=False
    )

    app.settings["COMMITMENT_URI"] = config.get(
        "GENERAL",
        "COMMITMENT_URI",
        fallback="https://github.asozial.org/an-website/commitment.txt",
    )

    # the protocol is whatever comes before "://" in the onion address
    onion_address = config.get("GENERAL", "ONION_ADDRESS", fallback=None)
    app.settings["ONION_ADDRESS"] = onion_address
    if onion_address is None:
        app.settings["ONION_PROTOCOL"] = None
    else:
        app.settings["ONION_PROTOCOL"] = onion_address.split("://")[0]

    # ratelimits default to enabled iff Redis is enabled
    app.settings["RATELIMITS"] = config.getboolean(
        "GENERAL",
        "RATELIMITS",
        fallback=config.getboolean("REDIS", "ENABLED", fallback=False),
    )

    app.settings["REDIS_PREFIX"] = config.get("REDIS", "PREFIX", fallback=NAME)

    app.settings["REPORTING"] = config.getboolean(
        "REPORTING", "ENABLED", fallback=True
    )

    app.settings["REPORTING_BUILTIN"] = config.getboolean(
        "REPORTING", "BUILTIN", fallback=sys.flags.dev_mode
    )

    # relative endpoint when using the builtin handler, absolute otherwise
    app.settings["REPORTING_ENDPOINT"] = config.get(
        "REPORTING",
        "ENDPOINT",
        fallback=(
            "/api/reports"
            if app.settings["REPORTING_BUILTIN"]
            else "https://asozial.org/api/reports"
        ),
    )

    # entries have the form "<secret>" or "<secret>=<permissions int>";
    # a bare secret gets all permissions, empty keys are skipped
    app.settings["TRUSTED_API_SECRETS"] = {
        key_perms[0]: Permission(
            int(key_perms[1])
            if len(key_perms) > 1
            else (1 << len(Permission)) - 1  # should be all permissions
        )
        for secret in config.getset(
            "GENERAL", "TRUSTED_API_SECRETS", fallback={"xyzzy"}
        )
        if (key_perms := [part.strip() for part in secret.split("=")])
        if key_perms[0]
    }

    app.settings["AUTH_TOKEN_SECRET"] = config.get(
        "GENERAL", "AUTH_TOKEN_SECRET", fallback=None
    )
    if not app.settings["AUTH_TOKEN_SECRET"]:
        # derive a machine-specific fallback secret from the MAC address
        # NOTE(review): "BRAILLE" looks like a codec registered elsewhere in
        # this project — confirm it is registered before this runs
        node = uuid.getnode().to_bytes(6, "big")
        secret = RIPEMD160.new(node).digest().decode("BRAILLE")
        LOGGER.warning(
            "AUTH_TOKEN_SECRET is unset, implicitly setting it to %r",
            secret,
        )
        app.settings["AUTH_TOKEN_SECRET"] = secret

    app.settings["UNDER_ATTACK"] = config.getboolean(
        "GENERAL", "UNDER_ATTACK", fallback=False
    )

    apply_contact_stuff_to_app(app, config)
def get_ssl_context(  # pragma: no cover
    config: ConfigParser,
) -> None | ssl.SSLContext:
    """Create SSL context and configure using the config."""
    if not config.getboolean("TLS", "ENABLED", fallback=False):
        # TLS is disabled → no context needed
        return None
    context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
    context.load_cert_chain(
        config.get("TLS", "CERTFILE"),
        config.get("TLS", "KEYFILE", fallback=None),
        config.get("TLS", "PASSWORD", fallback=None),
    )
    return context
def setup_logging(  # pragma: no cover
    config: ConfigParser,
    force: bool = False,
) -> None:
    """Setup logging.

    Configure the root logger with a stream handler and, if LOGGING/PATH
    is set, a timed rotating file handler. If the root logger already has
    handlers nothing happens, unless force is true, in which case the
    existing handlers get removed and closed first.
    """  # noqa: D401
    # BUG FIX: "logging.handlers" is a submodule that a plain
    # "import logging" does not guarantee to load; import it explicitly
    # before it gets used below instead of relying on transitive imports
    import logging.handlers  # pylint: disable=import-outside-toplevel

    root_logger = logging.getLogger()

    if root_logger.handlers:
        if not force:
            return
        # remove and close existing handlers before adding new ones
        for handler in root_logger.handlers[:]:
            root_logger.removeHandler(handler)
            handler.close()

    debug = config.getboolean("LOGGING", "DEBUG", fallback=sys.flags.dev_mode)

    logging.captureWarnings(True)

    root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
    # these loggers are too spammy on DEBUG
    logging.getLogger("tornado.curl_httpclient").setLevel(logging.INFO)
    logging.getLogger("elasticsearch").setLevel(logging.INFO)

    stream_handler = logging.StreamHandler()
    if sys.flags.dev_mode:
        # Tornado's log format without the colour placeholders
        plain_format = regex.sub(
            r"%\((end_)?color\)s", "", LogFormatter.DEFAULT_FORMAT
        )
        formatter: logging.Formatter = logging.Formatter(
            plain_format, LogFormatter.DEFAULT_DATE_FORMAT
        )
    else:
        formatter = LogFormatter()
    stream_handler.setFormatter(formatter)
    root_logger.addHandler(stream_handler)

    if path := config.get("LOGGING", "PATH", fallback=None):
        os.makedirs(path, 0o755, True)
        file_handler = logging.handlers.TimedRotatingFileHandler(
            os.path.join(path, f"{NAME}.log"),
            encoding="UTF-8",
            when="midnight",  # rotate daily (at midnight UTC)
            backupCount=30,
            utc=True,
        )
        file_handler.setFormatter(StdlibFormatter())
        root_logger.addHandler(file_handler)
class WebhookLoggingOptions:  # pylint: disable=too-few-public-methods
    """Webhook logging options."""

    __slots__ = (
        "url",
        "content_type",
        "body_format",
        "timestamp_format",
        "timestamp_timezone",
        "escape_message",
        "max_message_length",
    )

    # the webhook URL (logging to a webhook is disabled if None)
    url: str | None
    content_type: str
    body_format: str
    timestamp_format: str | None
    timestamp_timezone: str | None
    escape_message: bool
    max_message_length: int | None

    def __init__(self, config: ConfigParser) -> None:
        """Initialize Webhook logging options."""
        self.url = config.get("LOGGING", "WEBHOOK_URL", fallback=None)
        self.content_type = config.get(
            "LOGGING",
            "WEBHOOK_CONTENT_TYPE",
            fallback="application/json",
        )
        # default body: Tornado's log format (colour placeholders
        # stripped) wrapped in a minimal JSON object
        default_format = regex.sub(
            r"%\((end_)?color\)s", "", LogFormatter.DEFAULT_FORMAT
        )
        self.body_format = config.get(
            "LOGGING",
            "WEBHOOK_BODY_FORMAT",
            fallback='{"text":"' + default_format + '"}',
        )
        self.timestamp_format = config.get(
            "LOGGING",
            "WEBHOOK_TIMESTAMP_FORMAT",
            fallback=None,
        )
        self.timestamp_timezone = config.get(
            "LOGGING", "WEBHOOK_TIMESTAMP_TIMEZONE", fallback=None
        )
        self.escape_message = config.getboolean(
            "LOGGING",
            "WEBHOOK_ESCAPE_MESSAGE",
            fallback=True,
        )
        self.max_message_length = config.getint(
            "LOGGING", "WEBHOOK_MAX_MESSAGE_LENGTH", fallback=None
        )
def setup_webhook_logging(  # pragma: no cover
    options: WebhookLoggingOptions,
    loop: asyncio.AbstractEventLoop,
) -> None:
    """Setup Webhook logging.

    Attach a WebhookHandler for ERROR records to the root logger,
    configured from the given options. Does nothing if no URL is set.
    """  # noqa: D401
    if not options.url:
        return

    LOGGER.info("Setting up Webhook logging")

    root_logger = logging.getLogger()

    webhook_handler = WebhookHandler(
        logging.ERROR,
        loop=loop,
        url=options.url,
        content_type=options.content_type,
    )
    formatter = WebhookFormatter(
        options.body_format,
        options.timestamp_format,
    )
    # BUG FIX: the timezone comes from timestamp_timezone; previously
    # timestamp_format was checked and (wrongly) passed to ZoneInfo
    formatter.timezone = (
        None
        if options.timestamp_timezone is None
        else ZoneInfo(options.timestamp_timezone)
    )
    formatter.escape_message = options.escape_message
    formatter.max_message_length = options.max_message_length
    webhook_handler.setFormatter(formatter)
    root_logger.addHandler(webhook_handler)
def setup_apm(app: Application) -> None:  # pragma: no cover
    """Setup APM.""" # noqa: D401
    config: BetterConfigParser = app.settings["CONFIG"]
    # settings for the Elastic APM agent (mostly read from the config)
    app.settings["ELASTIC_APM"] = {
        "ENABLED": config.getboolean("ELASTIC_APM", "ENABLED", fallback=False),
        "SERVER_URL": config.get(
            "ELASTIC_APM", "SERVER_URL", fallback="http://localhost:8200"
        ),
        "SECRET_TOKEN": config.get(
            "ELASTIC_APM", "SECRET_TOKEN", fallback=None
        ),
        "API_KEY": config.get("ELASTIC_APM", "API_KEY", fallback=None),
        "VERIFY_SERVER_CERT": config.getboolean(
            "ELASTIC_APM", "VERIFY_SERVER_CERT", fallback=True
        ),
        "USE_CERTIFI": True,  # doesn't actually use certifi
        "SERVICE_NAME": NAME.removesuffix("-dev"),
        "SERVICE_VERSION": VERSION,
        "ENVIRONMENT": (
            "production" if not sys.flags.dev_mode else "development"
        ),
        "DEBUG": True,
        "CAPTURE_BODY": "errors",
        "TRANSACTION_IGNORE_URLS": [
            "/api/ping",
            "/static/*",
            "/favicon.png",
        ],
        "TRANSACTIONS_IGNORE_PATTERNS": ["^OPTIONS "],
        # the first processor is project-local, the rest ship with elasticapm
        "PROCESSORS": [
            "an_website.utils.utils.apm_anonymization_processor",
            "elasticapm.processors.sanitize_stacktrace_locals",
            "elasticapm.processors.sanitize_http_request_cookies",
            "elasticapm.processors.sanitize_http_headers",
            "elasticapm.processors.sanitize_http_wsgi_env",
            "elasticapm.processors.sanitize_http_request_body",
        ],
        "RUM_SERVER_URL": config.get(
            "ELASTIC_APM", "RUM_SERVER_URL", fallback=None
        ),
        "RUM_SERVER_URL_PREFIX": config.get(
            "ELASTIC_APM", "RUM_SERVER_URL_PREFIX", fallback=None
        ),
    }

    # options that get baked into the inline RUM agent init script
    script_options = [
        f"serviceName:{app.settings['ELASTIC_APM']['SERVICE_NAME']!r}",
        f"serviceVersion:{app.settings['ELASTIC_APM']['SERVICE_VERSION']!r}",
        f"environment:{app.settings['ELASTIC_APM']['ENVIRONMENT']!r}",
    ]

    rum_server_url = app.settings["ELASTIC_APM"]["RUM_SERVER_URL"]

    # None → fall back to the APM server URL; "" → use the page's own origin
    if rum_server_url is None:
        script_options.append(
            f"serverUrl:{app.settings['ELASTIC_APM']['SERVER_URL']!r}"
        )
    elif rum_server_url:
        script_options.append(f"serverUrl:{rum_server_url!r}")
    else:
        script_options.append("serverUrl:window.location.origin")

    if app.settings["ELASTIC_APM"]["RUM_SERVER_URL_PREFIX"]:
        script_options.append(
            f"serverUrlPrefix:{app.settings['ELASTIC_APM']['RUM_SERVER_URL_PREFIX']!r}"
        )

    app.settings["ELASTIC_APM"]["INLINE_SCRIPT"] = (
        "elasticApm.init({" + ",".join(script_options) + "})"
    )

    # SHA-256 hash of the inline script, base64-encoded — presumably used
    # for a Content-Security-Policy script hash; verify against consumers
    app.settings["ELASTIC_APM"]["INLINE_SCRIPT_HASH"] = b64encode(
        sha256(
            app.settings["ELASTIC_APM"]["INLINE_SCRIPT"].encode("ASCII")
        ).digest()
    ).decode("ASCII")

    if app.settings["ELASTIC_APM"]["ENABLED"]:
        app.settings["ELASTIC_APM"]["CLIENT"] = ElasticAPM(app).client
def setup_app_search(app: Application) -> None:  # pragma: no cover
    """Setup Elastic App Search.""" # noqa: D401
    config: BetterConfigParser = app.settings["CONFIG"]
    host = config.get("APP_SEARCH", "HOST", fallback=None)
    key = config.get("APP_SEARCH", "SEARCH_KEY", fallback=None)
    verify_certs = config.getboolean(
        "APP_SEARCH", "VERIFY_CERTS", fallback=True
    )
    # only create a client when a host is configured
    if host:
        app.settings["APP_SEARCH"] = AppSearch(
            host,
            bearer_auth=key,
            verify_certs=verify_certs,
            ca_certs=CA_BUNDLE_PATH,
        )
    else:
        app.settings["APP_SEARCH"] = None
    app.settings["APP_SEARCH_HOST"] = host
    app.settings["APP_SEARCH_KEY"] = key
    app.settings["APP_SEARCH_ENGINE"] = config.get(
        "APP_SEARCH", "ENGINE_NAME", fallback=NAME.removesuffix("-dev")
    )
def setup_redis(app: Application) -> None | Redis[str]:
    """Setup Redis.""" # noqa: D401
    config: BetterConfigParser = app.settings["CONFIG"]

    class Kwargs(TypedDict, total=False):
        """Kwargs of BlockingConnectionPool constructor."""

        # all keys optional (total=False); which ones get used depends on
        # whether a UNIX socket, plain TCP or TLS over TCP is configured
        db: int
        username: None | str
        password: None | str
        retry_on_timeout: bool
        connection_class: type[UnixDomainSocketConnection] | type[SSLConnection]
        path: str
        host: str
        port: int
        ssl_ca_certs: str
        ssl_keyfile: None | str
        ssl_certfile: None | str
        ssl_check_hostname: bool
        ssl_cert_reqs: str

    # options shared by all connection types
    kwargs: Kwargs = {
        "db": config.getint("REDIS", "DB", fallback=0),
        "username": config.get("REDIS", "USERNAME", fallback=None),
        "password": config.get("REDIS", "PASSWORD", fallback=None),
        "retry_on_timeout": config.getboolean(
            "REDIS", "RETRY_ON_TIMEOUT", fallback=False
        ),
    }
    # extra options used only for TLS over TCP
    redis_ssl_kwargs: Kwargs = {
        "connection_class": SSLConnection,
        "ssl_ca_certs": CA_BUNDLE_PATH,
        "ssl_keyfile": config.get("REDIS", "SSL_KEYFILE", fallback=None),
        "ssl_certfile": config.get("REDIS", "SSL_CERTFILE", fallback=None),
        "ssl_cert_reqs": config.get(
            "REDIS", "SSL_CERT_REQS", fallback="required"
        ),
        "ssl_check_hostname": config.getboolean(
            "REDIS", "SSL_CHECK_HOSTNAME", fallback=False
        ),
    }
    # extra options used only for TCP connections
    redis_host_port_kwargs: Kwargs = {
        "host": config.get("REDIS", "HOST", fallback="localhost"),
        "port": config.getint("REDIS", "PORT", fallback=6379),
    }
    redis_use_ssl = config.getboolean("REDIS", "SSL", fallback=False)
    redis_unix_socket_path = config.get(
        "REDIS", "UNIX_SOCKET_PATH", fallback=None
    )

    if redis_unix_socket_path is not None:
        # a UNIX socket takes precedence; warn about conflicting options
        if redis_use_ssl:
            LOGGER.warning(
                "SSL is enabled for Redis, but a UNIX socket is used"
            )
        if config.has_option("REDIS", "HOST"):
            LOGGER.warning(
                "A host is configured for Redis, but a UNIX socket is used"
            )
        if config.has_option("REDIS", "PORT"):
            LOGGER.warning(
                "A port is configured for Redis, but a UNIX socket is used"
            )
        kwargs.update(
            {
                "connection_class": UnixDomainSocketConnection,
                "path": redis_unix_socket_path,
            }
        )
    else:
        kwargs.update(redis_host_port_kwargs)
        if redis_use_ssl:
            kwargs.update(redis_ssl_kwargs)

    # the options above get read even when Redis is disabled, so that all
    # config options are read before forking (see main)
    if not config.getboolean("REDIS", "ENABLED", fallback=False):
        app.settings["REDIS"] = None
        return None
    connection_pool = BlockingConnectionPool(
        client_name=NAME,
        decode_responses=True,
        **kwargs,
    )
    redis = cast("Redis[str]", Redis(connection_pool=connection_pool))
    app.settings["REDIS"] = redis
    return redis
def signal_handler(  # noqa: D103 # pragma: no cover
    signalnum: int, frame: None | types.FrameType
) -> None:
    # pylint: disable=unused-argument, missing-function-docstring
    # SIGINT, SIGTERM and (where available) SIGHUP all request a shutdown;
    # getattr returns None on platforms without SIGHUP, which never matches
    if signalnum in {
        signal.SIGINT,
        signal.SIGTERM,
        getattr(signal, "SIGHUP", None),
    }:
        EVENT_SHUTDOWN.set()
def install_signal_handler() -> None:  # pragma: no cover
    """Install the signal handler."""
    handled = [signal.SIGINT, signal.SIGTERM]
    if hasattr(signal, "SIGHUP"):  # not available on Windows
        handled.append(signal.SIGHUP)
    for signalnum in handled:
        signal.signal(signalnum, signal_handler)
def supervise(loop: AbstractEventLoop) -> None:
    """Supervise."""
    # run until the heartbeat gets reset to 0 (falsy) on shutdown
    # pylint: disable-next=while-used
    while heartbeat := background_tasks.HEARTBEAT:
        if time.monotonic() - heartbeat >= 10:
            # the heartbeat is stale → the worker is stuck; log and abort
            worker = task_id()
            pid = os.getpid()

            task = asyncio.current_task(loop)
            request = task.get_context().get(request_ctx_var) if task else None

            LOGGER.fatal(
                "Heartbeat timed out for worker %s (pid %d), "
                "current request: %s, current task: %s",
                worker,
                pid,
                request,
                task,
            )
            atexit._run_exitfuncs()  # pylint: disable=protected-access
            os.abort()
        time.sleep(1)
def main(  # noqa: C901 # pragma: no cover
    config: BetterConfigParser | None = None,
) -> int | str:
    """
    Start everything.

    This is the main function that is called when running this program.
    Returns an error message (string) on failure, an exit code otherwise.
    """
    # pylint: disable=too-complex, too-many-branches
    # pylint: disable=too-many-locals, too-many-statements
    setproctitle(NAME)

    install_signal_handler()

    parser = create_argument_parser()
    args, _ = parser.parse_known_args(
        get_arguments_without_help(), ArgparseNamespace()
    )

    config = config or BetterConfigParser.from_path(*args.config)
    assert config is not None
    config.add_override_argument_parser(parser)

    setup_logging(config)

    LOGGER.info("Starting %s %s", NAME, VERSION)

    if platform.system() == "Windows":
        LOGGER.warning(
            "Running %s on Windows is not officially supported",
            NAME.removesuffix("-dev"),
        )

    ignore_modules(config)
    app = make_app(config)
    if isinstance(app, str):
        # a string means loading the modules failed → return it as an error
        return app

    apply_config_to_app(app, config)
    setup_elasticsearch(app)
    setup_app_search(app)
    setup_redis(app)
    setup_apm(app)

    behind_proxy = config.getboolean("GENERAL", "BEHIND_PROXY", fallback=False)

    server = HTTPServer(
        app,
        body_timeout=3600,
        decompress_request=True,
        max_body_size=1_000_000_000,
        ssl_options=get_ssl_context(config),
        xheaders=behind_proxy,
    )

    # sockets get created lazily (after --help handling, before forking)
    socket_factories: list[Callable[[], Iterable[socket]]] = []

    port = config.getint("GENERAL", "PORT", fallback=None)

    if port:
        socket_factories.append(
            partial(
                bind_sockets,
                port,
                "localhost" if behind_proxy else "",
            )
        )

    unix_socket_path = config.get(
        "GENERAL",
        "UNIX_SOCKET_PATH",
        fallback=None,
    )

    if unix_socket_path:
        os.makedirs(unix_socket_path, 0o755, True)
        socket_factories.append(
            lambda: (
                bind_unix_socket(
                    os.path.join(unix_socket_path, f"{NAME}.sock"),
                    mode=0o666,
                ),
            )
        )

    # 0 → no forking; negative → use the CPU count
    processes = config.getint(
        "GENERAL",
        "PROCESSES",
        fallback=hasattr(os, "fork") * (2 if sys.flags.dev_mode else -1),
    )

    if processes < 0:
        processes = (
            os.process_cpu_count()  # type: ignore[attr-defined]
            if sys.version_info >= (3, 13)
            else os.cpu_count()
        ) or 0

    worker: None | int = None

    run_supervisor_thread = config.getboolean(
        "GENERAL", "SUPERVISE", fallback=False
    )
    elasticsearch_is_enabled = config.getboolean(
        "ELASTICSEARCH", "ENABLED", fallback=False
    )
    redis_is_enabled = config.getboolean("REDIS", "ENABLED", fallback=False)
    webhook_logging_options = WebhookLoggingOptions(config)
    # all config options should be read before forking
    if args.save_config_to:
        with open(args.save_config_to, "w", encoding="UTF-8") as file:
            config.write(file)
    config.set_all_options_should_be_parsed()
    del config
    # show help message if --help is given (after reading config, before forking)
    parser.parse_args()

    if not socket_factories:
        LOGGER.warning("No sockets configured")
        return 0

    # create sockets after checking for --help
    sockets: list[socket] = (
        Stream(socket_factories).flat_map(lambda fun: fun()).collect(list)
    )

    UPTIME.reset()
    main_pid = os.getpid()

    if processes:
        setproctitle(f"{NAME} - Master")

        # fork_processes only returns in the worker processes
        worker = fork_processes(processes)

        setproctitle(f"{NAME} - Worker {worker}")

        # yeet all children (there should be none, but do it regardless, just in case)
        _children.clear()

        if "an_website.quotes" in sys.modules:
            from .quotes.utils import (  # pylint: disable=import-outside-toplevel
                AUTHORS_CACHE,
                QUOTES_CACHE,
                WRONG_QUOTES_CACHE,
            )

            del AUTHORS_CACHE.control.created_by_ultra  # type: ignore[attr-defined]
            del QUOTES_CACHE.control.created_by_ultra  # type: ignore[attr-defined]
            del WRONG_QUOTES_CACHE.control.created_by_ultra  # type: ignore[attr-defined]
        del geoip.__kwdefaults__["caches"].control.created_by_ultra

        if unix_socket_path:
            # each worker additionally gets its own UNIX socket
            sockets.append(
                bind_unix_socket(
                    os.path.join(unix_socket_path, f"{NAME}.{worker}.sock"),
                    mode=0o666,
                )
            )

    # get loop after forking
    # if not forking allow loop to be set in advance by external code
    loop: None | asyncio.AbstractEventLoop
    try:
        with catch_warnings():  # TODO: remove after dropping support for 3.13
            simplefilter("ignore", DeprecationWarning)
            loop = asyncio.get_event_loop()
        if loop.is_closed():
            loop = None
    except RuntimeError:
        loop = None

    if loop is None:
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

    if sys.version_info >= (3, 13) and not loop.get_task_factory():
        loop.set_task_factory(asyncio.eager_task_factory)

    if perf8 and "PERF8" in os.environ:
        loop.run_until_complete(perf8.enable())

    setup_webhook_logging(webhook_logging_options, loop)

    server.add_sockets(sockets)

    tasks = background_tasks.start_background_tasks(  # noqa: F841
        module_infos=app.settings["MODULE_INFOS"],
        loop=loop,
        main_pid=main_pid,
        app=app,
        processes=processes,
        elasticsearch_is_enabled=elasticsearch_is_enabled,
        redis_is_enabled=redis_is_enabled,
        worker=worker,
    )

    if run_supervisor_thread:
        background_tasks.HEARTBEAT = time.monotonic()
        threading.Thread(
            target=supervise, args=(loop,), name="supervisor", daemon=True
        ).start()

    try:
        loop.run_forever()
        EVENT_SHUTDOWN.set()
    finally:
        # graceful shutdown: stop accepting, drain, close clients, then
        # cancel remaining tasks and close the loop
        try:  # pylint: disable=too-many-try-statements
            server.stop()
            loop.run_until_complete(asyncio.sleep(1))
            loop.run_until_complete(server.close_all_connections())
            if perf8 and "PERF8" in os.environ:
                loop.run_until_complete(perf8.disable())
            if redis := app.settings.get("REDIS"):
                loop.run_until_complete(
                    redis.aclose(close_connection_pool=True)
                )
            if elasticsearch := app.settings.get("ELASTICSEARCH"):
                loop.run_until_complete(elasticsearch.close())
        finally:
            try:
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
                loop.run_until_complete(loop.shutdown_default_executor())
            finally:
                loop.close()
                background_tasks.HEARTBEAT = 0

    return len(tasks)