aboutsummaryrefslogtreecommitdiffstats
path: root/botcore/redis_message_relay.py
blob: f90e4dab1f3c865331507125316bdf1fbb1f999a (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import asyncio
import inspect
import json
import logging
from typing import Callable, Union

import aioredis


class RedisMessageRelay:
    """Common base for relay endpoints that pair a redis list with a pubsub channel."""

    def __init__(
            self,
            name: str,
            redis_channel: Union[aioredis.Channel, str],
            redis_list: str,
            redis_pool: aioredis.Redis,
    ) -> None:
        # Identifying fields: a label, the pubsub channel and the list key.
        self.name, self.channel, self.redis_list = name, redis_channel, redis_list

        # Redis handle used by subclasses for every list / pubsub command.
        self.redis = redis_pool

    def __repr__(self) -> str:
        # Uses the runtime class name so subclasses show up under their own name.
        cls_name = self.__class__.__name__
        return f"<{cls_name} name={self.name} channel={self.channel} list={self.redis_list}>"


class RedisMessageProducer(RedisMessageRelay):
    """Relay endpoint that queues messages on the redis list and signals the channel."""

    def __init__(
            self,
            name: str,
            redis_channel: Union[aioredis.Channel, str],
            redis_list: str,
            redis_pool: aioredis.Redis,
    ) -> None:
        # Nothing producer-specific to set up; everything is stored by the base class.
        super().__init__(name, redis_channel, redis_list, redis_pool)

    async def relay(self, data: dict) -> int:
        """
        Append `data` (JSON-encoded) to the redis list, then publish a notification.

        returns:
            - The value reported by `publish_json`: the number of subscribers
              the notification was delivered to.
        """
        payload = json.dumps(data)
        await self.redis.rpush(self.redis_list, payload)

        # The notification carries no message content; it only tells the
        # consumer that something new is waiting on the list.
        notification = {"pushed": True}
        return await self.redis.publish_json(self.channel, notification)


class RedisMessageConsumer(RedisMessageRelay):
    """
    Relay endpoint that pops messages from the redis list and hands them to a callback.

    Each message is JSON-decoded before the callback sees it, both for the
    backlog found at start-up and for messages announced over the channel.
    """

    def __init__(
            self,
            name: str,
            redis_channel: Union[aioredis.Channel, str],
            redis_list: str,
            redis_pool: aioredis.Redis,
            callback: Callable,
            loop: asyncio.AbstractEventLoop
    ) -> None:
        """
        Store the relay settings plus the message callback.

        `callback` is invoked with one decoded message at a time; it may be a
        plain callable or anything that returns an awaitable.
        `loop` is only stored on the instance; no method of this class uses it.
        """
        super().__init__(name, redis_channel, redis_list, redis_pool)

        self.callback = callback
        self.loop = loop

    async def listen(self) -> None:
        """
        Subscribe to the channel and process messages until the channel stops.

        `self.channel` is replaced by the first object returned from
        `subscribe`, which `wait_until_message` then reads from.
        """
        res = await self.redis.subscribe(self.channel)
        self.channel = res[0]
        await self.wait_until_message()

    async def pre_callback(self, data: dict) -> None:
        """
        Invoke the callback with `data`, awaiting the result when it is awaitable.

        Checking the returned value (rather than the callable) also covers
        `functools.partial` wrappers and objects with an async `__call__`.
        Exceptions raised by the callback propagate to the caller.
        """
        result = self.callback(data)
        if inspect.isawaitable(result):
            await result

    async def _consume_next(self) -> bool:
        """
        Pop one entry from the redis list and dispatch it.

        returns:
            - False when the list was empty, True when an entry was removed
              (whether or not it could be decoded).
        """
        serialised = await self.redis.lpop(self.redis_list)
        if serialised is None:
            # Empty list: a normal condition, e.g. the entry this notification
            # referred to was already consumed while draining the backlog.
            return False

        try:
            data = json.loads(serialised)
        except (ValueError, TypeError):
            # ValueError covers both invalid JSON and undecodable bytes.
            # A malformed entry is dropped (as before) but no longer silently.
            logging.getLogger(__name__).warning(
                "%r dropped an undecodable message: %r", self, serialised
            )
            return True

        await self.pre_callback(data)
        return True

    async def read_redis_list(self) -> None:
        """Pop a single message from the redis list and pass it to the callback."""
        await self._consume_next()

    async def wait_until_message(self) -> None:
        """
        Process the backlog, then handle messages as they are announced.

        Expects `self.channel` to be a channel object (as set by `listen`).
        """
        # Empty queue before receiving messages. Popping entries one by one
        # avoids the lrange + delete gap in which a freshly pushed message
        # could be deleted unseen, and decodes backlog entries exactly like
        # live ones.
        while await self._consume_next():
            pass

        while await self.channel.wait_message():
            data = await self.channel.get_json()

            # Ignore notifications that are not the expected {"pushed": ...}
            # mapping instead of letting a KeyError/TypeError end the loop.
            if isinstance(data, dict) and data.get("pushed"):
                await self.read_redis_list()