diff options
| -rw-r--r-- | bot/__main__.py | 3 | ||||
| -rw-r--r-- | bot/api.py | 30 | ||||
| -rw-r--r-- | bot/cogs/events.py | 17 | ||||
| -rw-r--r-- | bot/cogs/tags.py | 87 |
4 files changed, 74 insertions, 63 deletions
diff --git a/bot/__main__.py b/bot/__main__.py index 30d1b4c9a..0055e19ba 100644 --- a/bot/__main__.py +++ b/bot/__main__.py @@ -1,3 +1,4 @@ +import asyncio import logging import socket @@ -5,6 +6,7 @@ from aiohttp import AsyncResolver, ClientSession, TCPConnector from discord import Game from discord.ext.commands import Bot, when_mentioned_or +from bot.api import APIClient from bot.constants import Bot as BotConfig, DEBUG_MODE from bot.utils.service_discovery import wait_for_rmq @@ -27,6 +29,7 @@ bot.http_session = ClientSession( family=socket.AF_INET, ) ) +bot.api_client = APIClient(loop=asyncio.get_event_loop()) log.info("Waiting for RabbitMQ...") has_rmq = wait_for_rmq() diff --git a/bot/api.py b/bot/api.py new file mode 100644 index 000000000..6b9598da2 --- /dev/null +++ b/bot/api.py @@ -0,0 +1,30 @@ +from urllib.parse import quote as quote_url + +import aiohttp + +from .constants import Keys, URLs + + +class APIClient: + def __init__(self, **kwargs): + auth_headers = { + 'Authorization': f"Token {Keys.site_api}" + } + + if 'headers' in kwargs: + kwargs['headers'].update(auth_headers) + else: + kwargs['headers'] = auth_headers + + self.session = aiohttp.ClientSession( + **kwargs, + raise_for_status=True + ) + + @staticmethod + def _url_for(endpoint: str): + return f"{URLs.site_schema}{URLs.site_api}/{quote_url(endpoint)}" + + async def get(self, endpoint: str, *args, **kwargs): + async with self.session.get(self._url_for(endpoint), *args, **kwargs) as resp: + return await resp.json() diff --git a/bot/cogs/events.py b/bot/cogs/events.py index 0b9b75a00..281e212ff 100644 --- a/bot/cogs/events.py +++ b/bot/cogs/events.py @@ -1,5 +1,6 @@ import logging +from aiohttp import ClientResponseError from discord import Colour, Embed, Member, Object from discord.ext.commands import ( BadArgument, Bot, BotMissingPermissions, @@ -134,10 +135,18 @@ class Events: f"Here's what I'm missing: **{e.missing_perms}**" ) elif isinstance(e, CommandInvokeError): - await ctx.send( - f"Sorry, an unexpected error occurred. Please let us know!\n\n```{e}```" - ) - raise e.original + if isinstance(e.original, ClientResponseError): + return await ctx.send("There was some response error but I can't put my finger on what exactly.") + if e.original.status == 404: + await ctx.send("There does not seem to be anything matching your query.") + else: + await ctx.send("BEEP BEEP UNKNOWN API ERROR!=?!??!?!?!?") + + else: + await ctx.send( + f"Sorry, an unexpected error occurred. Please let us know!\n\n```{e}```" + ) + raise e.original raise e async def on_ready(self): diff --git a/bot/cogs/tags.py b/bot/cogs/tags.py index 7499b2b1c..baed44a8c 100644 --- a/bot/cogs/tags.py +++ b/bot/cogs/tags.py @@ -85,7 +85,7 @@ class Tags: def __init__(self, bot: Bot): self.bot = bot self.tag_cooldowns = {} - self.headers = {"X-API-KEY": Keys.site_api} + self.headers = {"Authorization": f"Token {Keys.site_api}"} async def get_tag_data(self, tag_name=None) -> dict: """ @@ -97,12 +97,13 @@ class Tags: if not, returns a list of dicts with all tag data. """ - params = {} if tag_name: - params["tag_name"] = tag_name + url = f'{URLs.site_tags_api}/{tag_name}' + else: + url = URLs.site_tags_api - response = await self.bot.http_session.get(URLs.site_tags_api, headers=self.headers, params=params) + response = await self.bot.http_session.get(url, headers=self.headers) tag_data = await response.json() return tag_data @@ -196,64 +197,32 @@ class Tags: f"Cooldown ends in {time_left:.1f} seconds.") return - tags = [] - - embed = Embed() - embed.colour = Colour.red() - tag_data = await self.get_tag_data(tag_name) - - # If we found something, prepare that data - if tag_data: - embed.colour = Colour.blurple() - - if tag_name: - log.debug(f"{ctx.author} requested the tag '{tag_name}'") - embed.title = tag_name + if tag_name is not None: + tag = await self.bot.api_client.get(f'/bot/tags/{tag_name}') + if ctx.channel.id not in TEST_CHANNELS: + self.tag_cooldowns[tag_name] = { + "time": time.time(), + "channel": ctx.channel.id + } + await ctx.send(embed=Embed.from_data(tag['embed'])) - if ctx.channel.id not in TEST_CHANNELS: - self.tag_cooldowns[tag_name] = { - "time": time.time(), - "channel": ctx.channel.id - } - - else: - embed.title = "**Current tags**" - - if isinstance(tag_data, list): - log.debug(f"{ctx.author} requested a list of all tags") - tags = [f"**»** {tag['tag_name']}" for tag in tag_data] - tags = sorted(tags) - - else: - embed.description = tag_data['tag_content'] - - # If not, prepare an error message. else: - embed.colour = Colour.red() - - if isinstance(tag_data, dict): - log.warning(f"{ctx.author} requested the tag '{tag_name}', but it could not be found.") - embed.description = f"**{tag_name}** is an unknown tag name. Please check the spelling and try again." + tags = await self.bot.api_client.get('/bot/tags') + if not tags: + await ctx.send(embed=Embed( + description="**There are no tags in the database!**", + colour=Colour.red() + )) else: - log.warning(f"{ctx.author} requested a list of all tags, but the tags database was empty!") - embed.description = "**There are no tags in the database!**" - - if tag_name: - embed.set_footer(text="To show a list of all tags, use !tags.") - embed.title = "Tag not found." - - # Paginate if this is a list of all tags - if tags: - log.debug(f"Returning a paginated list of all tags.") - return await LinePaginator.paginate( - (lines for lines in tags), - ctx, embed, - footer_text="To show a tag, type !tags <tagname>.", - empty=False, - max_lines=15 - ) - - return await ctx.send(embed=embed) + embed = Embed(title="**Current tags**") + await LinePaginator.paginate( + sorted(f"**»** {tag['title']}" for tag in tags), + ctx, + embed, + footer_text="To show a tag, type !tags <tagname>.", + empty=False, + max_lines=15 + ) @tags_group.command(name='set', aliases=('add', 'edit', 's')) @with_role(Roles.admin, Roles.owner, Roles.moderator) |