aboutsummaryrefslogtreecommitdiffstats
path: root/pydis_site/apps/api/models/bot/metricity.py
blob: 7d5d5f0be673314a773678c6cb02d3a3c8198800 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
from django.db import connections

# Length of one activity block, in seconds (10 minutes).
BLOCK_INTERVAL = 60 * 10

# Channel IDs whose messages are left out of the activity counts.
# Kept as a list (rather than another container) because of psycopg3's type adaptation.
EXCLUDE_CHANNELS = [
    "267659945086812160",  # Bot commands
    "607247579608121354",  # SeasonalBot commands
]


class NotFoundError(Exception):
    """
    Raised when an entity cannot be found.

    The query helpers in this module raise it when a lookup returns no rows.
    """



class Metricity:
    """
    Abstraction for a connection to the metricity database.

    Holds a cursor on Django's `metricity` database connection. The class is a
    context manager: leaving the `with` block closes the cursor.

    All queries use unquoted `%s` placeholders so that the database driver does
    the quoting. IDs are coerced with `str()` before being bound, which keeps the
    comparison textual (as the previously hand-quoted SQL did) for callers that
    pass integer IDs.
    """

    def __init__(self):
        """Open a cursor on the `metricity` database connection."""
        self.cursor = connections['metricity'].cursor()

    def __enter__(self):
        """Return this instance for use inside a `with` block."""
        return self

    def __exit__(self, *_):
        """Close the cursor; exception details are ignored and not suppressed."""
        self.cursor.close()

    def user(self, user_id: str) -> dict:
        """
        Query a user's data.

        Returns a dict with the single key `joined_at`.
        Raises NotFoundError if no row matches `user_id`.
        """
        # TODO: Swap this back to some sort of verified at date
        # The placeholder must not be wrapped in quotes: the driver quotes the value itself,
        # and quoting it twice yields broken (and injectable) SQL for string arguments.
        query = "SELECT joined_at FROM users WHERE id = %s"
        self.cursor.execute(query, [str(user_id)])
        values = self.cursor.fetchone()

        if not values:
            raise NotFoundError

        return {'joined_at': values[0]}

    def total_messages(self, user_id: str) -> int:
        """
        Query total number of messages for a user.

        Deleted messages and messages in `EXCLUDE_CHANNELS` are not counted.
        Raises NotFoundError if the query returns no row.
        """
        self.cursor.execute(
            """
            SELECT
                COUNT(*)
            FROM messages
            WHERE
                author_id = %s
                AND NOT is_deleted
                AND channel_id != ALL(%s)
            """,
            [str(user_id), EXCLUDE_CHANNELS]
        )
        values = self.cursor.fetchone()

        if not values:
            raise NotFoundError

        return values[0]

    def total_message_blocks(self, user_id: str) -> int:
        """
        Query number of 10 minute blocks during which the user has been active.

        This metric prevents users from spamming to achieve the message total threshold.

        Deleted messages and messages in `EXCLUDE_CHANNELS` are not counted.
        Raises NotFoundError if the query returns no row.
        """
        self.cursor.execute(
            """
            SELECT
                COUNT(*)
            FROM (
                SELECT
                    (floor((extract('epoch' from created_at) / %s )) * %s) AS interval
                FROM messages
                WHERE
                    author_id = %s
                    AND NOT is_deleted
                    AND channel_id != ALL(%s)
                GROUP BY interval
            ) block_query;
            """,
            [BLOCK_INTERVAL, BLOCK_INTERVAL, str(user_id), EXCLUDE_CHANNELS]
        )
        values = self.cursor.fetchone()

        if not values:
            raise NotFoundError

        return values[0]

    def top_channel_activity(self, user_id: str) -> list[tuple[str, int]]:
        """
        Query the top three channels in which the user is most active.

        Help channels are grouped under "the help channels",
        off-topic channels are grouped under "off-topic",
        and channels with "voice" in their name are grouped under "voice chats".

        Deleted messages are not counted; unlike the other message queries,
        `EXCLUDE_CHANNELS` is not applied here.
        Raises NotFoundError if the user has no matching messages.
        """
        # `%%` is a literal percent sign for ILIKE; only the bare `%s` is a bound parameter.
        self.cursor.execute(
            """
            SELECT
                CASE
                    WHEN channels.name ILIKE 'help-%%' THEN 'the help channels'
                    WHEN channels.name ILIKE 'ot%%' THEN 'off-topic'
                    WHEN channels.name ILIKE '%%voice%%' THEN 'voice chats'
                    ELSE channels.name
                END,
                COUNT(1)
            FROM
                messages
                LEFT JOIN channels ON channels.id = messages.channel_id
            WHERE
                author_id = %s AND NOT messages.is_deleted
            GROUP BY
                1
            ORDER BY
                2 DESC
            LIMIT
                3;
            """,
            [str(user_id)]
        )

        values = self.cursor.fetchall()

        if not values:
            raise NotFoundError

        return values

    def total_messages_in_past_n_days(
        self,
        user_ids: list[str],
        days: int
    ) -> list[tuple[str, int]]:
        """
        Query activity by a list of users in the past `days` days.

        Deleted messages and messages in `EXCLUDE_CHANNELS` are not counted.

        Returns a list of (user_id, message_count) tuples.
        """
        # The day count is bound as a real parameter and scaled by a one-day interval,
        # rather than being spliced into the middle of a quoted interval literal.
        self.cursor.execute(
            """
            SELECT
                author_id, COUNT(*)
            FROM messages
            WHERE
                author_id = ANY(%s)
                AND NOT is_deleted
                AND channel_id != ALL(%s)
                AND created_at > now() - %s * interval '1 day'
            GROUP BY author_id
            """,
            [user_ids, EXCLUDE_CHANNELS, days]
        )
        values = self.cursor.fetchall()

        return values