Coverage for an_website / main.py: 77.350%
234 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 18:33 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-19 18:33 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
13# pylint: disable=import-private-name, too-many-lines
15"""
16The website of the AN.
18Loads config and modules and starts Tornado.
19"""
21import asyncio
22import atexit
23import logging
24import os
25import platform
26import signal
27import ssl
28import sys
29import threading
30import time
31import types
32import uuid
33from asyncio import AbstractEventLoop
34from asyncio.runners import _cancel_all_tasks # type: ignore[attr-defined]
35from base64 import b64encode
36from collections.abc import Callable, Iterable, Mapping, MutableSequence
37from configparser import ConfigParser
38from contextlib import suppress
39from functools import partial
40from hashlib import sha256
41from importlib import import_module
42from multiprocessing.process import _children # type: ignore[attr-defined]
43from pathlib import Path
44from socket import socket
45from typing import Any, Final, TypedDict, TypeGuard, cast
46from warnings import catch_warnings, simplefilter
47from zoneinfo import ZoneInfo
49import regex
50from ecs_logging import StdlibFormatter
51from elasticapm.contrib.tornado import ElasticAPM
52from redis.asyncio import (
53 BlockingConnectionPool,
54 Redis,
55 SSLConnection,
56 UnixDomainSocketConnection,
57)
58from setproctitle import setproctitle
59from tornado.httpserver import HTTPServer
60from tornado.log import LogFormatter
61from tornado.netutil import bind_sockets, bind_unix_socket
62from tornado.process import fork_processes, task_id
63from tornado.web import Application, RedirectHandler
64from typed_stream import Stream
66from . import (
67 CA_BUNDLE_PATH,
68 DIR,
69 EVENT_SHUTDOWN,
70 NAME,
71 TEMPLATES_DIR,
72 UPTIME,
73 VERSION,
74 pytest_is_running,
75)
76from .contact.contact import apply_contact_stuff_to_app
77from .utils import background_tasks, static_file_handling
78from .utils.base_request_handler import BaseRequestHandler, request_ctx_var
79from .utils.better_config_parser import BetterConfigParser
80from .utils.elasticsearch_setup import setup_elasticsearch
81from .utils.logging import WebhookFormatter, WebhookHandler
82from .utils.request_handler import NotFoundHandler
83from .utils.static_file_from_traversable import TraversableStaticFileHandler
84from .utils.template_loader import TemplateLoader
85from .utils.utils import (
86 ArgparseNamespace,
87 Handler,
88 ModuleInfo,
89 Permission,
90 Timer,
91 create_argument_parser,
92 geoip,
93 get_arguments_without_help,
94 time_function,
95)
97try:
98 from test.support import has_fork_support # type: ignore[import-not-found]
99except ModuleNotFoundError:
100 has_fork_support = hasattr(os, "fork")
102try:
103 import perf8 # type: ignore[import, unused-ignore]
104except ModuleNotFoundError:
105 perf8 = None # pylint: disable=invalid-name
107IGNORED_MODULES: Final[set[str]] = {
108 "patches",
109 "static",
110 "templates",
111} | (set() if sys.flags.dev_mode or pytest_is_running() else {"example"})
113LOGGER: Final = logging.getLogger(__name__)
116# add all the information from the packages to a list
117# this calls the get_module_info function in every file
118# files and dirs starting with '_' get ignored
119def get_module_infos() -> str | tuple[ModuleInfo, ...]:
120 """Import the modules and return the loaded module infos in a tuple."""
121 module_infos: list[ModuleInfo] = []
122 loaded_modules: list[str] = []
123 errors: list[str] = []
125 for potential_module in DIR.iterdir():
126 if (
127 potential_module.name.startswith("_")
128 or potential_module.name in IGNORED_MODULES
129 or not potential_module.is_dir()
130 ):
131 continue
133 _module_infos = get_module_infos_from_module(
134 potential_module.name, errors, ignore_not_found=True
135 )
136 if _module_infos:
137 module_infos.extend(_module_infos)
138 loaded_modules.append(potential_module.name)
139 LOGGER.debug(
140 (
141 "Found module_infos in %s.__init__.py, "
142 "not searching in other modules in the package."
143 ),
144 potential_module,
145 )
146 continue
148 if f"{potential_module.name}.*" in IGNORED_MODULES:
149 continue
151 for potential_file in potential_module.iterdir():
152 module_name = f"{potential_module.name}.{potential_file.name[:-3]}"
153 if (
154 not potential_file.name.endswith(".py")
155 or module_name in IGNORED_MODULES
156 or potential_file.name.startswith("_")
157 ):
158 continue
159 _module_infos = get_module_infos_from_module(module_name, errors)
160 if _module_infos:
161 module_infos.extend(_module_infos)
162 loaded_modules.append(module_name)
164 if len(errors) > 0:
165 if sys.flags.dev_mode:
166 # exit to make sure it gets fixed
167 return "\n".join(errors)
168 # don't exit in production to keep stuff running
169 LOGGER.error("\n".join(errors))
171 LOGGER.info(
172 "Loaded %d modules: '%s'",
173 len(loaded_modules),
174 "', '".join(loaded_modules),
175 )
177 LOGGER.info(
178 "Ignored %d modules: '%s'",
179 len(IGNORED_MODULES),
180 "', '".join(IGNORED_MODULES),
181 )
183 sort_module_infos(module_infos)
185 # make module_infos immutable so it never changes
186 return tuple(module_infos)
189def get_module_infos_from_module(
190 module_name: str,
191 errors: MutableSequence[str], # gets modified
192 ignore_not_found: bool = False,
193) -> None | list[ModuleInfo]:
194 """Get the module infos based on a module."""
195 import_timer = Timer()
196 module = import_module(
197 f".{module_name}",
198 package="an_website",
199 )
200 if import_timer.stop() > 0.1:
201 LOGGER.warning(
202 "Import of %s took %ss. That's affecting the startup time.",
203 module_name,
204 import_timer.get(),
205 )
207 module_infos: list[ModuleInfo] = []
209 has_get_module_info = "get_module_info" in dir(module)
210 has_get_module_infos = "get_module_infos" in dir(module)
212 if not (has_get_module_info or has_get_module_infos):
213 if ignore_not_found:
214 return None
215 errors.append(
216 f"{module_name} has no 'get_module_info' and no 'get_module_infos' "
217 "method. Please add at least one of the methods or add "
218 f"'{module_name.rsplit('.', 1)[0]}.*' or {module_name!r} to "
219 "IGNORED_MODULES."
220 )
221 return None
223 if has_get_module_info and isinstance(
224 module_info := module.get_module_info(),
225 ModuleInfo,
226 ):
227 module_infos.append(module_info)
228 elif has_get_module_info:
229 errors.append(
230 f"'get_module_info' in {module_name} does not return ModuleInfo. "
231 "Please fix the returned value."
232 )
234 if not has_get_module_infos:
235 return module_infos or None
237 _module_infos = module.get_module_infos()
239 if not isinstance(_module_infos, Iterable):
240 errors.append(
241 f"'get_module_infos' in {module_name} does not return an Iterable. "
242 "Please fix the returned value."
243 )
244 return module_infos or None
246 for _module_info in _module_infos:
247 if isinstance(_module_info, ModuleInfo):
248 module_infos.append(_module_info)
249 else:
250 errors.append(
251 f"'get_module_infos' in {module_name} did return an Iterable "
252 f"with an element of type {type(_module_info)}. "
253 "Please fix the returned value."
254 )
256 return module_infos or None
259def sort_module_infos(module_infos: list[ModuleInfo]) -> None:
260 """Sort a list of module info and move the main page to the top."""
261 # sort it so the order makes sense
262 module_infos.sort()
264 # move the main page to the top
265 for i, info in enumerate(module_infos):
266 if info.path == "/":
267 module_infos.insert(0, module_infos.pop(i))
268 break
271def get_all_handlers(module_infos: Iterable[ModuleInfo]) -> list[Handler]:
272 """
273 Parse the module information and return the handlers in a tuple.
275 If a handler has only 2 elements a dict with title and description
276 gets added. This information is gotten from the module info.
277 """
278 handler: Handler | list[Any]
279 handlers: list[Handler] = static_file_handling.get_handlers()
281 # add all the normal handlers
282 for module_info in module_infos:
283 for handler in module_info.handlers:
284 handler = list(handler) # pylint: disable=redefined-loop-name
285 # if the handler is a request handler from us
286 # and not a built-in like StaticFileHandler & RedirectHandler
287 if issubclass(handler[1], BaseRequestHandler):
288 if len(handler) == 2:
289 # set "default_title" or "default_description" to False so
290 # that module_info.name & module_info.description get used
291 handler.append(
292 {
293 "default_title": False,
294 "default_description": False,
295 "module_info": module_info,
296 }
297 )
298 else:
299 handler[2]["module_info"] = module_info
300 handlers.append(tuple(handler))
302 # redirect handler, to make finding APIs easier
303 handlers.append((r"/(.+)/api/*", RedirectHandler, {"url": "/api/{0}"}))
305 handlers.append(
306 (
307 r"(?i)/\.well-known/(.*)",
308 TraversableStaticFileHandler,
309 {
310 "root": Path(".well-known"),
311 "headers": (("Access-Control-Allow-Origin", "*"),),
312 },
313 )
314 )
316 LOGGER.debug("Loaded %d handlers", len(handlers))
318 return handlers
321def ignore_modules(config: BetterConfigParser) -> None:
322 """Read ignored modules from the config."""
323 IGNORED_MODULES.update(
324 config.getset("GENERAL", "IGNORED_MODULES", fallback=set())
325 )
328def get_normed_paths_from_module_infos(
329 module_infos: Iterable[ModuleInfo],
330) -> dict[str, str]:
331 """Get all paths from the module infos."""
333 def tuple_has_no_none(
334 value: tuple[str | None, str | None],
335 ) -> TypeGuard[tuple[str, str]]:
336 return None not in value
338 def info_to_paths(info: ModuleInfo) -> Stream[tuple[str, str]]:
339 return (
340 Stream(((info.path, info.path),))
341 .chain(
342 info.aliases.items()
343 if isinstance(info.aliases, Mapping)
344 else ((alias, info.path) for alias in info.aliases)
345 )
346 .chain(
347 Stream(info.sub_pages)
348 .map(lambda sub_info: sub_info.path)
349 .filter()
350 .map(lambda path: (path, path))
351 )
352 .filter(tuple_has_no_none)
353 )
355 return (
356 Stream(module_infos)
357 .flat_map(info_to_paths)
358 .filter(lambda p: p[0].startswith("/"))
359 .map(lambda p: (p[0].strip("/").lower(), p[1]))
360 .filter(lambda p: p[0])
361 .collect(dict)
362 )
365def make_app(config: ConfigParser) -> str | Application:
366 """Create the Tornado application and return it."""
367 module_infos, duration = time_function(get_module_infos)
368 if isinstance(module_infos, str):
369 return module_infos
370 if duration > 1:
371 LOGGER.warning(
372 "Getting the module infos took %ss. That's probably too long.",
373 duration,
374 )
375 handlers = get_all_handlers(module_infos)
376 return Application(
377 handlers,
378 MODULE_INFOS=module_infos,
379 SHOW_HAMBURGER_MENU=not Stream(module_infos)
380 .exclude(lambda info: info.hidden)
381 .filter(lambda info: info.path)
382 .empty(),
383 NORMED_PATHS=get_normed_paths_from_module_infos(module_infos),
384 HANDLERS=handlers,
385 # General settings
386 autoreload=False,
387 debug=sys.flags.dev_mode,
388 default_handler_class=NotFoundHandler,
389 compress_response=config.getboolean(
390 "GENERAL", "COMPRESS_RESPONSE", fallback=False
391 ),
392 websocket_ping_interval=10,
393 # Template settings
394 template_loader=TemplateLoader(
395 root=TEMPLATES_DIR, whitespace="oneline"
396 ),
397 )
400def apply_config_to_app(app: Application, config: BetterConfigParser) -> None:
401 """Apply the config (from the config.ini file) to the application."""
402 app.settings["CONFIG"] = config
404 app.settings["cookie_secret"] = config.get(
405 "GENERAL", "COOKIE_SECRET", fallback="xyzzy"
406 )
408 app.settings["CRAWLER_SECRET"] = config.get(
409 "APP_SEARCH", "CRAWLER_SECRET", fallback=None
410 )
412 app.settings["DOMAIN"] = config.get("GENERAL", "DOMAIN", fallback=None)
414 app.settings["ELASTICSEARCH_PREFIX"] = config.get(
415 "ELASTICSEARCH", "PREFIX", fallback=NAME
416 )
418 app.settings["HSTS"] = config.getboolean("TLS", "HSTS", fallback=False)
420 app.settings["NETCUP"] = config.getboolean(
421 "GENERAL", "NETCUP", fallback=False
422 )
424 onion_address = config.get("GENERAL", "ONION_ADDRESS", fallback=None)
425 app.settings["ONION_ADDRESS"] = onion_address
426 if onion_address is None:
427 app.settings["ONION_PROTOCOL"] = None
428 else:
429 app.settings["ONION_PROTOCOL"] = onion_address.split("://")[0]
431 app.settings["RATELIMITS"] = config.getboolean(
432 "GENERAL",
433 "RATELIMITS",
434 fallback=config.getboolean("REDIS", "ENABLED", fallback=False),
435 )
437 app.settings["REDIS_PREFIX"] = config.get("REDIS", "PREFIX", fallback=NAME)
439 app.settings["REPORTING"] = config.getboolean(
440 "REPORTING", "ENABLED", fallback=True
441 )
443 app.settings["REPORTING_BUILTIN"] = config.getboolean(
444 "REPORTING", "BUILTIN", fallback=sys.flags.dev_mode
445 )
447 app.settings["REPORTING_ENDPOINT"] = config.get(
448 "REPORTING",
449 "ENDPOINT",
450 fallback=(
451 "/api/reports"
452 if app.settings["REPORTING_BUILTIN"]
453 else "https://asozial.org/api/reports"
454 ),
455 )
457 app.settings["TRUSTED_API_SECRETS"] = {
458 key_perms[0]: Permission(
459 int(key_perms[1])
460 if len(key_perms) > 1
461 else (1 << len(Permission)) - 1 # should be all permissions
462 )
463 for secret in config.getset(
464 "GENERAL", "TRUSTED_API_SECRETS", fallback={"xyzzy"}
465 )
466 if (key_perms := [part.strip() for part in secret.split("=")])
467 if key_perms[0]
468 }
470 app.settings["AUTH_TOKEN_SECRET"] = config.get(
471 "GENERAL", "AUTH_TOKEN_SECRET", fallback=None
472 )
473 if not app.settings["AUTH_TOKEN_SECRET"]:
474 # pylint: disable-next=import-outside-toplevel
475 from .version.version import hash_bytes
477 node = uuid.getnode().to_bytes(6, "big")
478 secret = hash_bytes(node)
479 LOGGER.warning(
480 "AUTH_TOKEN_SECRET is unset, implicitly setting it to %r",
481 secret,
482 )
483 app.settings["AUTH_TOKEN_SECRET"] = secret
485 app.settings["UNDER_ATTACK"] = config.getboolean(
486 "GENERAL", "UNDER_ATTACK", fallback=False
487 )
489 apply_contact_stuff_to_app(app, config)
492def get_ssl_context( # pragma: no cover
493 config: ConfigParser,
494) -> None | ssl.SSLContext:
495 """Create SSL context and configure using the config."""
496 if config.getboolean("TLS", "ENABLED", fallback=False):
497 ssl_ctx = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
498 ssl_ctx.load_cert_chain(
499 config.get("TLS", "CERTFILE"),
500 config.get("TLS", "KEYFILE", fallback=None),
501 config.get("TLS", "PASSWORD", fallback=None),
502 )
503 return ssl_ctx
504 return None
507def setup_logging( # pragma: no cover
508 config: ConfigParser,
509 force: bool = False,
510) -> None:
511 """Setup logging.""" # noqa: D401
512 root_logger = logging.getLogger()
514 if root_logger.handlers:
515 if not force:
516 return
517 for handler in root_logger.handlers[:]:
518 root_logger.removeHandler(handler)
519 handler.close()
521 debug = config.getboolean("LOGGING", "DEBUG", fallback=sys.flags.dev_mode)
523 logging.captureWarnings(True)
525 root_logger.setLevel(logging.DEBUG if debug else logging.INFO)
526 logging.getLogger("tornado.curl_httpclient").setLevel(logging.INFO)
527 logging.getLogger("elasticsearch").setLevel(logging.INFO)
529 stream_handler = logging.StreamHandler()
530 if sys.flags.dev_mode:
531 spam = regex.sub(r"%\((end_)?color\)s", "", LogFormatter.DEFAULT_FORMAT)
532 formatter = logging.Formatter(spam, LogFormatter.DEFAULT_DATE_FORMAT)
533 else:
534 formatter = LogFormatter()
535 stream_handler.setFormatter(formatter)
536 root_logger.addHandler(stream_handler)
538 if path := config.get("LOGGING", "PATH", fallback=None):
539 os.makedirs(path, 0o755, True)
540 file_handler = logging.handlers.TimedRotatingFileHandler(
541 os.path.join(path, f"{NAME}.log"),
542 encoding="UTF-8",
543 when="midnight",
544 backupCount=30,
545 utc=True,
546 )
547 file_handler.setFormatter(StdlibFormatter())
548 root_logger.addHandler(file_handler)
551class WebhookLoggingOptions: # pylint: disable=too-few-public-methods
552 """Webhook logging options."""
554 __slots__ = (
555 "url",
556 "content_type",
557 "body_format",
558 "timestamp_format",
559 "timestamp_timezone",
560 "escape_message",
561 "max_message_length",
562 )
564 url: str | None
565 content_type: str
566 body_format: str
567 timestamp_format: str | None
568 timestamp_timezone: str | None
569 escape_message: bool
570 max_message_length: int | None
572 def __init__(self, config: ConfigParser) -> None:
573 """Initialize Webhook logging options."""
574 self.url = config.get("LOGGING", "WEBHOOK_URL", fallback=None)
575 self.content_type = config.get(
576 "LOGGING",
577 "WEBHOOK_CONTENT_TYPE",
578 fallback="application/json",
579 )
580 spam = regex.sub(r"%\((end_)?color\)s", "", LogFormatter.DEFAULT_FORMAT)
581 self.body_format = config.get(
582 "LOGGING",
583 "WEBHOOK_BODY_FORMAT",
584 fallback='{"text":"' + spam + '"}',
585 )
586 self.timestamp_format = config.get(
587 "LOGGING",
588 "WEBHOOK_TIMESTAMP_FORMAT",
589 fallback=None,
590 )
591 self.timestamp_timezone = config.get(
592 "LOGGING", "WEBHOOK_TIMESTAMP_TIMEZONE", fallback=None
593 )
594 self.escape_message = config.getboolean(
595 "LOGGING",
596 "WEBHOOK_ESCAPE_MESSAGE",
597 fallback=True,
598 )
599 self.max_message_length = config.getint(
600 "LOGGING", "WEBHOOK_MAX_MESSAGE_LENGTH", fallback=None
601 )
604def setup_webhook_logging( # pragma: no cover
605 options: WebhookLoggingOptions,
606 loop: asyncio.AbstractEventLoop,
607) -> None:
608 """Setup Webhook logging.""" # noqa: D401
609 if not options.url:
610 return
612 LOGGER.info("Setting up Webhook logging")
614 root_logger = logging.getLogger()
616 webhook_content_type = options.content_type
617 webhook_handler = WebhookHandler(
618 logging.ERROR,
619 loop=loop,
620 url=options.url,
621 content_type=webhook_content_type,
622 )
623 formatter = WebhookFormatter(
624 options.body_format,
625 options.timestamp_format,
626 )
627 formatter.timezone = (
628 None
629 if options.timestamp_format is None
630 else ZoneInfo(options.timestamp_format)
631 )
632 formatter.escape_message = options.escape_message
633 formatter.max_message_length = options.max_message_length
634 formatter.get_context_line = lambda _: (
635 f"Request: {request}"
636 if (request := request_ctx_var.get(None))
637 else None
638 )
639 webhook_handler.setFormatter(formatter)
640 root_logger.addHandler(webhook_handler)
642 info_handler = WebhookHandler(
643 logging.INFO,
644 loop=loop,
645 url=options.url,
646 content_type=webhook_content_type,
647 )
648 info_handler.setFormatter(formatter)
649 logging.getLogger("an_website.quotes.create").addHandler(info_handler)
652def setup_apm(app: Application) -> None: # pragma: no cover
653 """Setup APM.""" # noqa: D401
654 config: BetterConfigParser = app.settings["CONFIG"]
655 app.settings["ELASTIC_APM"] = {
656 "ENABLED": config.getboolean("ELASTIC_APM", "ENABLED", fallback=False),
657 "SERVER_URL": config.get(
658 "ELASTIC_APM", "SERVER_URL", fallback="http://localhost:8200"
659 ),
660 "SECRET_TOKEN": config.get(
661 "ELASTIC_APM", "SECRET_TOKEN", fallback=None
662 ),
663 "API_KEY": config.get("ELASTIC_APM", "API_KEY", fallback=None),
664 "VERIFY_SERVER_CERT": config.getboolean(
665 "ELASTIC_APM", "VERIFY_SERVER_CERT", fallback=True
666 ),
667 "USE_CERTIFI": True, # doesn't actually use certifi
668 "SERVICE_NAME": NAME.removesuffix("-dev"),
669 "SERVICE_VERSION": VERSION,
670 "ENVIRONMENT": (
671 "production" if not sys.flags.dev_mode else "development"
672 ),
673 "DEBUG": True,
674 "CAPTURE_BODY": "errors",
675 "TRANSACTION_IGNORE_URLS": [
676 "/api/ping",
677 "/static/*",
678 "/favicon.png",
679 ],
680 "TRANSACTIONS_IGNORE_PATTERNS": ["^OPTIONS "],
681 "PROCESSORS": [
682 "an_website.utils.utils.apm_anonymization_processor",
683 "elasticapm.processors.sanitize_stacktrace_locals",
684 "elasticapm.processors.sanitize_http_request_cookies",
685 "elasticapm.processors.sanitize_http_headers",
686 "elasticapm.processors.sanitize_http_wsgi_env",
687 "elasticapm.processors.sanitize_http_request_body",
688 ],
689 "RUM_SERVER_URL": config.get(
690 "ELASTIC_APM", "RUM_SERVER_URL", fallback=None
691 ),
692 "RUM_SERVER_URL_PREFIX": config.get(
693 "ELASTIC_APM", "RUM_SERVER_URL_PREFIX", fallback=None
694 ),
695 }
697 script_options = [
698 f"serviceName:{app.settings['ELASTIC_APM']['SERVICE_NAME']!r}",
699 f"serviceVersion:{app.settings['ELASTIC_APM']['SERVICE_VERSION']!r}",
700 f"environment:{app.settings['ELASTIC_APM']['ENVIRONMENT']!r}",
701 ]
703 rum_server_url = app.settings["ELASTIC_APM"]["RUM_SERVER_URL"]
705 if rum_server_url is None:
706 script_options.append(
707 f"serverUrl:{app.settings['ELASTIC_APM']['SERVER_URL']!r}"
708 )
709 elif rum_server_url:
710 script_options.append(f"serverUrl:{rum_server_url!r}")
711 else:
712 script_options.append("serverUrl:window.location.origin")
714 if app.settings["ELASTIC_APM"]["RUM_SERVER_URL_PREFIX"]:
715 script_options.append(
716 f"serverUrlPrefix:{app.settings['ELASTIC_APM']['RUM_SERVER_URL_PREFIX']!r}"
717 )
719 app.settings["ELASTIC_APM"]["INLINE_SCRIPT"] = (
720 "elasticApm.init({" + ",".join(script_options) + "})"
721 )
723 app.settings["ELASTIC_APM"]["INLINE_SCRIPT_HASH"] = b64encode(
724 sha256(
725 app.settings["ELASTIC_APM"]["INLINE_SCRIPT"].encode("ASCII")
726 ).digest()
727 ).decode("ASCII")
729 if app.settings["ELASTIC_APM"]["ENABLED"]:
730 app.settings["ELASTIC_APM"]["CLIENT"] = ElasticAPM(app).client
733def setup_app_search(app: Application) -> None: # pragma: no cover
734 """Setup Elastic App Search.""" # noqa: D401
735 config: BetterConfigParser = app.settings["CONFIG"]
736 host = config.get("APP_SEARCH", "HOST", fallback=None)
737 key = config.get("APP_SEARCH", "SEARCH_KEY", fallback=None)
738 verify_certs = config.getboolean(
739 "APP_SEARCH", "VERIFY_CERTS", fallback=True
740 )
741 app.settings["APP_SEARCH_HOST"] = host
742 app.settings["APP_SEARCH_KEY"] = key
743 app.settings["APP_SEARCH_ENGINE"] = config.get(
744 "APP_SEARCH", "ENGINE_NAME", fallback=NAME.removesuffix("-dev")
745 )
747 app_search: object = None
748 try:
749 with catch_warnings():
750 simplefilter("ignore", DeprecationWarning)
751 # pylint: disable-next=import-outside-toplevel
752 from elastic_enterprise_search import ( # type: ignore[import-untyped]
753 AppSearch,
754 )
755 except ModuleNotFoundError:
756 LOGGER.log(
757 logging.ERROR if host else logging.INFO,
758 "elastic-enterprise-search is not installed",
759 exc_info=True,
760 )
761 else:
762 if host:
763 app_search = AppSearch(
764 host,
765 bearer_auth=key,
766 verify_certs=verify_certs,
767 ca_certs=CA_BUNDLE_PATH,
768 )
770 app.settings["APP_SEARCH"] = app_search
773def setup_redis(app: Application) -> None | Redis[str]:
774 """Setup Redis.""" # noqa: D401
775 config: BetterConfigParser = app.settings["CONFIG"]
777 class Kwargs(TypedDict, total=False):
778 """Kwargs of BlockingConnectionPool constructor."""
780 db: int
781 username: None | str
782 password: None | str
783 retry_on_timeout: bool
784 connection_class: type[UnixDomainSocketConnection] | type[SSLConnection]
785 path: str
786 host: str
787 port: int
788 ssl_ca_certs: str
789 ssl_keyfile: None | str
790 ssl_certfile: None | str
791 ssl_check_hostname: bool
792 ssl_cert_reqs: str
794 kwargs: Kwargs = {
795 "db": config.getint("REDIS", "DB", fallback=0),
796 "username": config.get("REDIS", "USERNAME", fallback=None),
797 "password": config.get("REDIS", "PASSWORD", fallback=None),
798 "retry_on_timeout": config.getboolean(
799 "REDIS", "RETRY_ON_TIMEOUT", fallback=False
800 ),
801 }
802 redis_ssl_kwargs: Kwargs = {
803 "connection_class": SSLConnection,
804 "ssl_ca_certs": CA_BUNDLE_PATH,
805 "ssl_keyfile": config.get("REDIS", "SSL_KEYFILE", fallback=None),
806 "ssl_certfile": config.get("REDIS", "SSL_CERTFILE", fallback=None),
807 "ssl_cert_reqs": config.get(
808 "REDIS", "SSL_CERT_REQS", fallback="required"
809 ),
810 "ssl_check_hostname": config.getboolean(
811 "REDIS", "SSL_CHECK_HOSTNAME", fallback=False
812 ),
813 }
814 redis_host_port_kwargs: Kwargs = {
815 "host": config.get("REDIS", "HOST", fallback="localhost"),
816 "port": config.getint("REDIS", "PORT", fallback=6379),
817 }
818 redis_use_ssl = config.getboolean("REDIS", "SSL", fallback=False)
819 redis_unix_socket_path = config.get(
820 "REDIS", "UNIX_SOCKET_PATH", fallback=None
821 )
823 if redis_unix_socket_path is not None:
824 if redis_use_ssl:
825 LOGGER.warning(
826 "SSL is enabled for Redis, but a UNIX socket is used"
827 )
828 if config.has_option("REDIS", "HOST"):
829 LOGGER.warning(
830 "A host is configured for Redis, but a UNIX socket is used"
831 )
832 if config.has_option("REDIS", "PORT"):
833 LOGGER.warning(
834 "A port is configured for Redis, but a UNIX socket is used"
835 )
836 kwargs.update(
837 {
838 "connection_class": UnixDomainSocketConnection,
839 "path": redis_unix_socket_path,
840 }
841 )
842 else:
843 kwargs.update(redis_host_port_kwargs)
844 if redis_use_ssl:
845 kwargs.update(redis_ssl_kwargs)
847 if not config.getboolean("REDIS", "ENABLED", fallback=False):
848 app.settings["REDIS"] = None
849 return None
850 connection_pool = BlockingConnectionPool(
851 client_name=NAME,
852 decode_responses=True,
853 **kwargs,
854 )
855 redis = cast("Redis[str]", Redis(connection_pool=connection_pool))
856 app.settings["REDIS"] = redis
857 return redis
860def signal_handler( # noqa: D103 # pragma: no cover
861 signalnum: int, frame: None | types.FrameType
862) -> None:
863 # pylint: disable=unused-argument, missing-function-docstring
864 if signalnum in {signal.SIGINT, signal.SIGTERM}:
865 EVENT_SHUTDOWN.set()
866 if signalnum == getattr(signal, "SIGHUP", None):
867 EVENT_SHUTDOWN.set()
870def install_signal_handler() -> None: # pragma: no cover
871 """Install the signal handler."""
872 signal.signal(signal.SIGINT, signal_handler)
873 signal.signal(signal.SIGTERM, signal_handler)
874 if hasattr(signal, "SIGHUP"):
875 signal.signal(signal.SIGHUP, signal_handler)
878def supervise(loop: AbstractEventLoop) -> None:
879 """Supervise."""
880 while foobarbaz := background_tasks.HEARTBEAT: # pylint: disable=while-used
881 if time.monotonic() - foobarbaz >= 10:
882 worker = task_id()
883 pid = os.getpid()
885 task = asyncio.current_task(loop)
886 request = task.get_context().get(request_ctx_var) if task else None
888 LOGGER.fatal(
889 "Heartbeat timed out for worker %s (pid %d), "
890 "current request: %s, current task: %s",
891 worker,
892 pid,
893 request,
894 task,
895 )
896 atexit._run_exitfuncs() # pylint: disable=protected-access
897 os.abort()
898 time.sleep(1)
901type EventLoopFactory = Callable[[], asyncio.AbstractEventLoop]
904def get_default_event_loop_factory() -> EventLoopFactory:
905 """Get the preferred event loop factory."""
906 loop_factory = asyncio.new_event_loop
908 if os.environ.get("DISABLE_UVLOOP") not in {
909 "y",
910 "yes",
911 "t",
912 "true",
913 "on",
914 "1",
915 }:
916 with suppress(ModuleNotFoundError):
917 loop_factory = import_module("uvloop").new_event_loop
919 return loop_factory
922def main( # noqa: C901 # pragma: no cover
923 config: BetterConfigParser | None = None,
924 loop_factory: None | EventLoopFactory = None,
925) -> int | str:
926 """
927 Start everything.
929 This is the main function that is called when running this program.
930 """
931 # pylint: disable=too-complex, too-many-branches
932 # pylint: disable=too-many-locals, too-many-statements
933 if loop_factory is None:
934 loop_factory = get_default_event_loop_factory()
936 setproctitle(NAME)
938 install_signal_handler()
940 parser = create_argument_parser()
941 args, _ = parser.parse_known_args(
942 get_arguments_without_help(), ArgparseNamespace()
943 )
945 if args.version:
946 print("Version: ", end="", flush=True, file=sys.stderr)
947 print(VERSION, flush=True)
948 if args.verbose:
949 # pylint: disable-next=import-outside-toplevel
950 from .version.version import (
951 get_file_hashes,
952 get_hash_of_file_hashes,
953 )
955 print()
956 print("Hash der Datei-Hashes:")
957 print(get_hash_of_file_hashes())
959 if args.verbose > 1:
960 print()
961 print("Datei-Hashes:")
962 print(get_file_hashes())
964 return 0
966 config = config or BetterConfigParser.from_path(*args.config)
967 assert config is not None
968 config.add_override_argument_parser(parser)
970 setup_logging(config)
972 LOGGER.info("Starting %s %s", NAME, VERSION)
974 if platform.system() == "Windows":
975 LOGGER.warning(
976 "Running %s on Windows is not officially supported",
977 NAME.removesuffix("-dev"),
978 )
980 ignore_modules(config)
981 app = make_app(config)
982 if isinstance(app, str):
983 return app
985 apply_config_to_app(app, config)
986 setup_elasticsearch(app)
987 setup_app_search(app)
988 setup_redis(app)
989 setup_apm(app)
991 behind_proxy = config.getboolean("GENERAL", "BEHIND_PROXY", fallback=False)
993 server = HTTPServer(
994 app,
995 body_timeout=3600,
996 decompress_request=True,
997 max_body_size=1_000_000_000,
998 ssl_options=get_ssl_context(config),
999 xheaders=behind_proxy,
1000 )
1002 socket_factories: list[Callable[[], Iterable[socket]]] = []
1004 port = config.getint("GENERAL", "PORT", fallback=None)
1006 if port:
1007 socket_factories.append(
1008 partial(
1009 bind_sockets,
1010 port,
1011 "localhost" if behind_proxy else "",
1012 )
1013 )
1015 unix_socket_path = config.get(
1016 "GENERAL",
1017 "UNIX_SOCKET_PATH",
1018 fallback=None,
1019 )
1021 if unix_socket_path:
1022 os.makedirs(unix_socket_path, 0o755, True)
1023 socket_factories.append(
1024 lambda: (
1025 bind_unix_socket(
1026 os.path.join(unix_socket_path, f"{NAME}.sock"),
1027 mode=0o666,
1028 ),
1029 )
1030 )
1032 processes = config.getint(
1033 "GENERAL",
1034 "PROCESSES",
1035 fallback=has_fork_support * (2 if sys.flags.dev_mode else -1),
1036 )
1038 if processes < 0:
1039 processes = os.process_cpu_count() or 0
1041 worker: None | int = None
1043 run_supervisor_thread = config.getboolean(
1044 "GENERAL", "SUPERVISE", fallback=False
1045 )
1046 elasticsearch_is_enabled = config.getboolean(
1047 "ELASTICSEARCH", "ENABLED", fallback=False
1048 )
1049 redis_is_enabled = config.getboolean("REDIS", "ENABLED", fallback=False)
1050 webhook_logging_options = WebhookLoggingOptions(config)
1051 # all config options should be read before forking
1052 if args.save_config_to:
1053 with open(args.save_config_to, "w", encoding="UTF-8") as file:
1054 config.write(file)
1055 config.set_all_options_should_be_parsed()
1056 del config
1057 # show help message if --help is given (after reading config, before forking)
1058 parser.parse_args()
1060 if not socket_factories:
1061 LOGGER.warning("No sockets configured")
1062 return 0
1064 # create sockets after checking for --help
1065 sockets: list[socket] = (
1066 Stream(socket_factories).flat_map(lambda fun: fun()).collect(list)
1067 )
1069 UPTIME.reset()
1070 main_pid = os.getpid()
1072 if processes:
1073 setproctitle(f"{NAME} - Master")
1075 worker = fork_processes(processes)
1077 setproctitle(f"{NAME} - Worker {worker}")
1079 # yeet all children (there should be none, but do it regardless, just in case)
1080 _children.clear()
1082 if "an_website.quotes" in sys.modules:
1083 from .quotes.utils import ( # pylint: disable=import-outside-toplevel
1084 AUTHORS_CACHE,
1085 QUOTES_CACHE,
1086 WRONG_QUOTES_CACHE,
1087 )
1089 del AUTHORS_CACHE.control.created_by_ultra # type: ignore[attr-defined]
1090 del QUOTES_CACHE.control.created_by_ultra # type: ignore[attr-defined]
1091 del WRONG_QUOTES_CACHE.control.created_by_ultra # type: ignore[attr-defined]
1092 del (geoip.__kwdefaults__ or {})["caches"].control.created_by_ultra
1094 if unix_socket_path:
1095 sockets.append(
1096 bind_unix_socket(
1097 os.path.join(unix_socket_path, f"{NAME}.{worker}.sock"),
1098 mode=0o666,
1099 )
1100 )
1102 # get loop after forking
1103 # if not forking allow loop to be set in advance by external code
1104 loop: None | asyncio.AbstractEventLoop
1105 try:
1106 loop = asyncio.get_event_loop()
1107 if loop.is_closed():
1108 loop = None
1109 except RuntimeError:
1110 loop = None
1112 if loop is None:
1113 loop = loop_factory()
1114 asyncio.set_event_loop(loop)
1116 if not loop.get_task_factory():
1117 loop.set_task_factory(asyncio.eager_task_factory)
1119 if perf8 and "PERF8" in os.environ:
1120 loop.run_until_complete(perf8.enable())
1122 setup_webhook_logging(webhook_logging_options, loop)
1124 server.add_sockets(sockets)
1126 tasks = background_tasks.start_background_tasks( # noqa: F841
1127 module_infos=app.settings["MODULE_INFOS"],
1128 loop=loop,
1129 main_pid=main_pid,
1130 app=app,
1131 processes=processes,
1132 elasticsearch_is_enabled=elasticsearch_is_enabled,
1133 redis_is_enabled=redis_is_enabled,
1134 worker=worker,
1135 )
1137 if run_supervisor_thread:
1138 background_tasks.HEARTBEAT = time.monotonic()
1139 threading.Thread(
1140 target=supervise, args=(loop,), name="supervisor", daemon=True
1141 ).start()
1143 try:
1144 loop.run_forever()
1145 EVENT_SHUTDOWN.set()
1146 finally:
1147 try: # pylint: disable=too-many-try-statements
1148 server.stop()
1149 loop.run_until_complete(asyncio.sleep(1))
1150 loop.run_until_complete(server.close_all_connections())
1151 if perf8 and "PERF8" in os.environ:
1152 loop.run_until_complete(perf8.disable())
1153 if redis := app.settings.get("REDIS"):
1154 loop.run_until_complete(
1155 redis.aclose(close_connection_pool=True)
1156 )
1157 if elasticsearch := app.settings.get("ELASTICSEARCH"):
1158 loop.run_until_complete(elasticsearch.close())
1159 finally:
1160 try:
1161 _cancel_all_tasks(loop)
1162 loop.run_until_complete(loop.shutdown_asyncgens())
1163 loop.run_until_complete(loop.shutdown_default_executor())
1164 finally:
1165 loop.close()
1166 background_tasks.HEARTBEAT = 0
1168 return len(tasks)