Coverage for an_website/emoji_chat/pub_sub_provider.py: 55.172%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-11-16 19:56 +0000

1# This program is free software: you can redistribute it and/or modify 

2# it under the terms of the GNU Affero General Public License as 

3# published by the Free Software Foundation, either version 3 of the 

4# License, or (at your option) any later version. 

5# 

6# This program is distributed in the hope that it will be useful, 

7# but WITHOUT ANY WARRANTY; without even the implied warranty of 

8# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 

9# GNU Affero General Public License for more details. 

10# 

11# You should have received a copy of the GNU Affero General Public License 

12# along with this program. If not, see <https://www.gnu.org/licenses/>. 

13 

14"""A provider of the Redis PubSub class.""" 

15 

16from __future__ import annotations 

17 

18import logging 

19from collections.abc import Collection 

20from dataclasses import dataclass 

21from typing import Any, Final 

22 

23from redis.asyncio.client import PubSub, Redis 

24 

25from .. import EVENT_REDIS 

26 

27LOGGER: Final = logging.getLogger(__name__) 

28 

29 

@dataclass(slots=True)
class PubSubProvider:
    """Provide a Redis PubSub object subscribed to a fixed set of channels.

    Instances are async callables: ``await provider()`` returns a PubSub
    created from the client stored in ``settings["REDIS"]``. The PubSub
    is cached and handed out again on later calls for as long as the same
    client is found in the settings; when a different client shows up, the
    old PubSub is closed and a new one is created and subscribed.
    """

    # Channel names passed to ``PubSub.subscribe`` for every new PubSub.
    channels: Collection[str]
    # Settings mapping; only the "REDIS" key is read by this class.
    settings: dict[str, Any]
    # Worker identifier; used only in log messages.
    worker: int | None
    # Cached, successfully subscribed PubSub (None until the first success).
    _ps: PubSub | None = None
    # The Redis client that ``_ps`` was created from.
    _redis: Redis[str] | None = None

    async def __call__(self) -> PubSub:
        """Get a PubSub object subscribed to ``self.channels``.

        Returns:
            The cached PubSub if it belongs to the client currently stored
            in ``settings["REDIS"]``, otherwise a newly created one.

        Raises:
            Whatever ``PubSub.close`` or ``PubSub.subscribe`` raises. When
            subscribing fails, nothing is cached, so the next call starts
            over with a fresh PubSub.
        """
        if not self.settings.get("REDIS"):
            LOGGER.error("Redis not available on worker %s", self.worker)

        # NOTE(review): presumably EVENT_REDIS is set once the client has
        # been stored in the settings - confirm against the code setting it.
        await EVENT_REDIS.wait()

        redis: Redis[str] = self.settings["REDIS"]

        if self._ps is not None:
            if self._redis == redis:
                return self._ps
            LOGGER.info(
                "Closing old PubSub connection on worker %s", self.worker
            )
            await self._ps.close()
            # Forget the closed object so it can never be handed out again,
            # even if creating the replacement below fails.
            self._ps = None
            self._redis = None

        pubsub = redis.pubsub()
        try:
            await pubsub.subscribe(*self.channels)
        except Exception:
            # Do not cache a PubSub that never got subscribed; otherwise
            # the fast path above would keep returning it on every call.
            await pubsub.close()
            raise

        # Cache only after the subscription succeeded.
        self._ps = pubsub
        self._redis = redis
        return pubsub