Coverage for an_website/utils/elasticsearch_setup.py: 47.368%
76 statements
« prev ^ index » next coverage.py v7.8.2, created at 2025-06-01 08:32 +0000
1# This program is free software: you can redistribute it and/or modify
2# it under the terms of the GNU Affero General Public License as
3# published by the Free Software Foundation, either version 3 of the
4# License, or (at your option) any later version.
5#
6# This program is distributed in the hope that it will be useful,
7# but WITHOUT ANY WARRANTY; without even the implied warranty of
8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
9# GNU Affero General Public License for more details.
10#
11# You should have received a copy of the GNU Affero General Public License
12# along with this program. If not, see <https://www.gnu.org/licenses/>.
13"""Functions for setting up Elasticsearch."""
14from __future__ import annotations
16import asyncio
17import logging
18from collections.abc import Awaitable, Callable
19from typing import Final, Literal, TypeAlias, TypedDict, cast
21import orjson
22from elastic_transport import ObjectApiResponse
23from elasticsearch import AsyncElasticsearch, NotFoundError
24from tornado.web import Application
26from .. import CA_BUNDLE_PATH, DIR
27from .better_config_parser import BetterConfigParser
28from .fix_static_path_impl import recurse_directory
29from .utils import none_to_default
31LOGGER: Final = logging.getLogger(__name__)
33ES_WHAT_LITERAL: TypeAlias = Literal[ # pylint: disable=invalid-name
34 "component_templates", "index_templates", "ingest_pipelines"
35]
36ES_WHAT_LITERALS: tuple[ES_WHAT_LITERAL, ...] = (
37 "ingest_pipelines",
38 "component_templates",
39 "index_templates",
40)
41type AnyArgsAsyncMethod = Callable[..., Awaitable[ObjectApiResponse[object]]]
async def setup_elasticsearch_configs(
    elasticsearch: AsyncElasticsearch,
    prefix: str,
) -> None:
    """Setup Elasticsearch configs.

    For every kind in ``ES_WHAT_LITERALS`` (in that order), every ``.json``
    file below ``DIR / "elasticsearch" / <kind>`` is loaded, each literal
    ``{prefix}`` in it is replaced with ``prefix``, and the result is passed
    to ``setup_elasticsearch_config``. All configs of one kind are applied
    concurrently. Kinds are applied one after another, presumably so that
    later kinds can reference configs created by earlier ones.

    The config name is ``<prefix>-<relative path without .json>``, with
    ``/`` replaced by ``-``.
    """  # noqa: D401
    # Iterate the tuple itself instead of a hard-coded range, so that adding
    # a kind to ES_WHAT_LITERALS can never be silently skipped here.
    for what in ES_WHAT_LITERALS:
        spam: list[Awaitable[None | ObjectApiResponse[object]]] = []

        base_path = DIR / "elasticsearch" / what

        for rel_path in recurse_directory(
            base_path, lambda path: path.name.endswith(".json")
        ):
            path = base_path / rel_path
            if not path.is_file():
                LOGGER.warning("%s is not a file", path)
                continue

            body = orjson.loads(
                path.read_bytes().replace(b"{prefix}", prefix.encode("ASCII"))
            )

            # e.g. "foo/bar.json" -> "<prefix>-foo-bar"
            stem = rel_path.removesuffix(".json")
            name = f"{prefix}-{stem.replace('/', '-')}"

            spam.append(
                setup_elasticsearch_config(
                    elasticsearch, what, body, name, rel_path
                )
            )

        # wait for every config of this kind before moving on to the next
        await asyncio.gather(*spam)
async def setup_elasticsearch_config(
    es: AsyncElasticsearch,
    what: ES_WHAT_LITERAL,
    body: dict[str, object],
    name: str,
    path: str = "<unknown>",
) -> None | ObjectApiResponse[object]:
    """Setup Elasticsearch config.

    Fetch the config called ``name`` of kind ``what`` from Elasticsearch and
    compare its ``version`` with the one in ``body``. A missing ``version``
    counts as 1, and a config that does not exist counts as version 0.

    If the local version is newer, the config is put and the response of the
    put call is returned. If the stored version is newer, a warning naming
    ``path`` is logged. Otherwise nothing happens. ``None`` is returned
    whenever no put was made.

    Raises ``AssertionError`` if ``what`` is not a known config kind.
    """  # noqa: D401
    if what == "component_templates":
        get: AnyArgsAsyncMethod = es.cluster.get_component_template
        put: AnyArgsAsyncMethod = es.cluster.put_component_template
    elif what == "index_templates":
        get = es.indices.get_index_template
        put = es.indices.put_index_template
    elif what == "ingest_pipelines":
        get = es.ingest.get_pipeline
        put = es.ingest.put_pipeline
    else:
        raise AssertionError(f"Unknown Elasticsearch config kind: {what!r}")

    # the version shipped with the app; a missing "version" counts as 1
    wanted_version = cast(int, body.get("version", 1))

    try:
        if what == "ingest_pipelines":
            # pipelines are addressed by id, and the response is keyed by it
            current = await get(id=name)
            current_version = current[name].get("version", 1)
        else:
            # templates are addressed by name; only fetch name and version
            current = await get(
                name=name, filter_path=f"{what}.name,{what}.version"
            )
            current_version = current[what][0].get("version", 1)
    except NotFoundError:
        # not present yet: any local version is newer
        current_version = 0

    if current_version < wanted_version:
        if what == "ingest_pipelines":
            return await put(id=name, body=body)
        return await put(name=name, body=body)

    if current_version > wanted_version:
        LOGGER.warning(
            "%s has version %s. The version in Elasticsearch is %s!",
            path,
            wanted_version,
            current_version,
        )

    return None
def setup_elasticsearch(app: Application) -> None | AsyncElasticsearch:
    """Setup Elasticsearch.

    Reads the ``ELASTICSEARCH`` section of ``app.settings["CONFIG"]``.

    If ``ENABLED`` is not true, ``app.settings["ELASTICSEARCH"]`` is set to
    ``None`` and ``None`` is returned; no other option is read.

    Otherwise an ``AsyncElasticsearch`` client is built from the options
    ``HOSTS``, ``CLOUD_ID``, ``VERIFY_CERTS``, ``API_KEY``, ``BEARER_AUTH``,
    ``CLIENT_CERT``, ``CLIENT_KEY``, ``RETRY_ON_TIMEOUT``, ``USERNAME`` and
    ``PASSWORD``, and it uses ``CA_BUNDLE_PATH`` as CA certificates. The
    client is stored in ``app.settings["ELASTICSEARCH"]`` and returned.
    """  # noqa: D401
    config: BetterConfigParser = app.settings["CONFIG"]

    # Check this before touching any other option: when Elasticsearch is
    # disabled, the remaining options are neither read nor parsed, so e.g. an
    # unparsable VERIFY_CERTS cannot break startup.
    if not config.getboolean("ELASTICSEARCH", "ENABLED", fallback=False):
        app.settings["ELASTICSEARCH"] = None
        return None

    # pylint: disable-next=import-outside-toplevel
    from elastic_transport.client_utils import DEFAULT, DefaultType

    basic_auth: tuple[str | None, str | None] = (
        config.get("ELASTICSEARCH", "USERNAME", fallback=None),
        config.get("ELASTICSEARCH", "PASSWORD", fallback=None),
    )

    class Kwargs(TypedDict):
        """Kwargs of AsyncElasticsearch constructor."""

        hosts: tuple[str, ...] | None
        cloud_id: None | str
        verify_certs: bool
        api_key: None | str
        bearer_auth: None | str
        client_cert: str | DefaultType
        client_key: str | DefaultType
        retry_on_timeout: bool | DefaultType

    kwargs: Kwargs = {
        "hosts": (
            tuple(config.getset("ELASTICSEARCH", "HOSTS"))
            if config.has_option("ELASTICSEARCH", "HOSTS")
            else None
        ),
        "cloud_id": config.get("ELASTICSEARCH", "CLOUD_ID", fallback=None),
        "verify_certs": config.getboolean(
            "ELASTICSEARCH", "VERIFY_CERTS", fallback=True
        ),
        "api_key": config.get("ELASTICSEARCH", "API_KEY", fallback=None),
        "bearer_auth": config.get(
            "ELASTICSEARCH", "BEARER_AUTH", fallback=None
        ),
        # unset options become DEFAULT so the client applies its own default
        "client_cert": none_to_default(
            config.get("ELASTICSEARCH", "CLIENT_CERT", fallback=None), DEFAULT
        ),
        "client_key": none_to_default(
            config.get("ELASTICSEARCH", "CLIENT_KEY", fallback=None), DEFAULT
        ),
        "retry_on_timeout": none_to_default(
            config.getboolean(
                "ELASTICSEARCH", "RETRY_ON_TIMEOUT", fallback=None
            ),
            DEFAULT,
        ),
    }

    elasticsearch = AsyncElasticsearch(
        # basic auth is only used when both USERNAME and PASSWORD are set
        basic_auth=(
            None if None in basic_auth else cast(tuple[str, str], basic_auth)
        ),
        ca_certs=CA_BUNDLE_PATH,
        **kwargs,
    )
    app.settings["ELASTICSEARCH"] = elasticsearch
    return elasticsearch