aboutsummaryrefslogtreecommitdiffstats
path: root/arthur/apis/directory/ldap.py
blob: c919f551e81fc407f828e52f211301ce64047b66 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
"""API utilities for fetching data from the LDAP directory."""

from dataclasses import dataclass
from functools import cache

try:
    from bonsai import LDAPClient, LDAPDN, LDAPSearchScope

    BONSAI_AVAILABLE = True
except ImportError:
    BONSAI_AVAILABLE = False

from arthur.config import CONFIG
from arthur.constants import LDAP_ROLE_MAPPING


@dataclass
class LDAPUser:
    """
    A user in the LDAP directory.

    Only `uid` is mandatory. The remaining fields default to `None` and are
    left unset by lookups that do not fetch them (for example group member
    listings, which only know each member's DN).
    """

    # Value of the entry's `uid` attribute, or the leading RDN value of a member DN.
    uid: str
    # Value of the `employeeNumber` attribute; this module matches it against Discord IDs.
    employee_number: str | None = None
    # Value of the `displayName` attribute.
    display_name: str | None = None
    # Names of the user's `memberOf` groups that are present in LDAP_ROLE_MAPPING.
    groups: list[str] | None = None


def prepare_dn(dn: str) -> str:
    """
    Qualify a relative DN with the configured LDAP base DN.

    The result is `dn` followed by a comma and `CONFIG.ldap_base_dn`.
    """
    base_dn = CONFIG.ldap_base_dn

    return "{},{}".format(dn, base_dn)


def get_cn(dn: str) -> str:
    """
    Return the value of the leading attribute of a DN.

    For a DN such as `cn=admins,cn=groups,...` this is the `cn` value; for any
    other DN it is the value of whichever attribute comes first.
    """
    leading_rdn = LDAPDN(dn).rdns[0]
    # An RDN is a sequence of (attribute, value) pairs; take the first pair's value.
    first_pair = leading_rdn[0]

    return first_pair[1]


@cache
def create_client() -> "LDAPClient":
    """
    Build the LDAP client used by this module from the configured settings.

    The client is created for `CONFIG.ldap_host` with `tls=True`, given the
    configured bind user and password under the "SIMPLE" mechanism, and pointed
    at the configured CA certificate. Because the function is wrapped in
    `cache`, repeated calls return the same client object.
    """
    ldap = LDAPClient(str(CONFIG.ldap_host), tls=True)

    bind_user = CONFIG.ldap_bind_user
    bind_password = CONFIG.ldap_bind_password.get_secret_value()
    ldap.set_credentials("SIMPLE", bind_user, bind_password)

    ca_cert_path = CONFIG.ldap_certificate_location.as_posix()
    ldap.set_ca_cert(ca_cert_path)

    return ldap


# Location of user entries (relative to the base DN) and the attributes that
# must be requested to build an LDAPUser from one of them.
_USERS_DN = "cn=users,cn=accounts"
_USER_ATTRIBUTES = ("uid", "employeeNumber", "displayName", "memberOf")


def _user_from_entry(entry) -> LDAPUser:
    """
    Convert a directory entry fetched with `_USER_ATTRIBUTES` into an LDAPUser.

    Group DNs from `memberOf` are reduced to their leading RDN value and kept
    only when that name is present in LDAP_ROLE_MAPPING.

    Raises KeyError if the entry has no `uid` attribute.
    """
    group_names = (get_cn(group_dn) for group_dn in entry.get("memberOf", ()))

    return LDAPUser(
        uid=entry["uid"][0],
        employee_number=entry.get("employeeNumber", [None])[0],
        # Read leniently, like employeeNumber: an entry without displayName
        # yields None instead of a KeyError that would abort the whole lookup.
        display_name=entry.get("displayName", [None])[0],
        groups=[name for name in group_names if name in LDAP_ROLE_MAPPING],
    )


async def find_users() -> list[LDAPUser]:
    """
    Find all users in the LDAP directory.

    Searches the users subtree for entries whose `mail` matches `*@pydis.wtf`
    and returns one LDAPUser per entry, with `groups` limited to the group
    names present in LDAP_ROLE_MAPPING.
    """
    client = create_client()

    async with client.connect(is_async=True) as conn:
        entries = await conn.search(
            prepare_dn(_USERS_DN),
            LDAPSearchScope.SUBTREE,
            "(mail=*@pydis.wtf)",
            list(_USER_ATTRIBUTES),
        )

        return [_user_from_entry(entry) for entry in entries]


async def find_by_discord_id(discord_id: int) -> LDAPUser | None:
    """
    Find a user in the LDAP directory by their Discord ID.

    The Discord ID is matched against the `employeeNumber` attribute. Returns
    the first matching user, or None when the search yields no entries.
    """
    client = create_client()

    async with client.connect(is_async=True) as conn:
        entries = await conn.search(
            prepare_dn(_USERS_DN),
            LDAPSearchScope.SUBTREE,
            f"(employeeNumber={discord_id})",
            list(_USER_ATTRIBUTES),
        )

        entry = entries[0] if entries else None

        if not entry:
            return None

        return _user_from_entry(entry)


async def get_group_members(group_name: str) -> list[LDAPUser]:
    """
    List the members of the named group in the LDAP directory.

    The returned users carry only a UID (taken from the leading RDN of each
    member DN); every other field is left at its default. An empty list is
    returned when the search yields no group entry.
    """
    ldap = create_client()

    async with ldap.connect(is_async=True) as connection:
        group_dn = prepare_dn(f"cn={group_name},cn=groups,cn=accounts")
        matches = await connection.search(
            group_dn,
            LDAPSearchScope.SUBTREE,
            "(objectClass=groupOfNames)",
            ["member"],
        )

        if matches:
            member_dns = matches[0].get("member", [])
            return [LDAPUser(uid=get_cn(member_dn)) for member_dn in member_dns]

        return []