aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pysite/base_route.py8
-rw-r--r--pysite/constants.py10
-rw-r--r--pysite/decorators.py75
-rw-r--r--pysite/views/api/bot/tag.py8
-rw-r--r--pysite/views/api/bot/user.py36
-rw-r--r--requirements.txt1
6 files changed, 107 insertions, 31 deletions
diff --git a/pysite/base_route.py b/pysite/base_route.py
index 5c197a42..8e6648ee 100644
--- a/pysite/base_route.py
+++ b/pysite/base_route.py
@@ -112,8 +112,12 @@ class APIView(RouteView):
elif error_code is ErrorCodes.invalid_api_key:
data["error_message"] = "Invalid API-key"
http_code = 401
- elif error_code is ErrorCodes.missing_parameters:
- data["error_message"] = "Not all required parameters were provided"
+ elif error_code is ErrorCodes.bad_data_format:
+ data["error_message"] = "Input data in incorrect format"
+ http_code = 400
+ elif error_code is ErrorCodes.incorrect_parameters:
+ data["error_message"] = "Incorrect parameters provided"
+ http_code = 400
response = jsonify(data)
response.status_code = http_code
diff --git a/pysite/constants.py b/pysite/constants.py
index 0c9c8ecb..59febcc9 100644
--- a/pysite/constants.py
+++ b/pysite/constants.py
@@ -1,13 +1,19 @@
# coding=utf-8
-from enum import IntEnum
+from enum import Enum, IntEnum
class ErrorCodes(IntEnum):
unknown_route = 0
unauthorized = 1
invalid_api_key = 2
- missing_parameters = 3
+ incorrect_parameters = 3
+ bad_data_format = 4
+
+
+class ValidationTypes(Enum):
+ json = "json"
+ params = "params"
OWNER_ROLE = 267627879762755584
diff --git a/pysite/decorators.py b/pysite/decorators.py
index 8d0cf7f4..c404b375 100644
--- a/pysite/decorators.py
+++ b/pysite/decorators.py
@@ -1,13 +1,16 @@
# coding=utf-8
import os
from functools import wraps
+from json import JSONDecodeError
from flask import request
-from pysite.constants import ErrorCodes
+from schema import Schema, SchemaError
+from pysite.constants import ErrorCodes, ValidationTypes
-def valid_api_key(f):
+
+def api_key(f):
"""
Decorator to check if X-API-Key is valid.
@@ -15,9 +18,73 @@ def valid_api_key(f):
"""
@wraps(f)
- def has_valid_api_key(self, *args, **kwargs):
+ def inner(self, *args, **kwargs):
if not request.headers.get("X-API-Key") == os.environ.get("API_KEY"):
return self.error(ErrorCodes.invalid_api_key)
return f(self, *args, **kwargs)
- return has_valid_api_key
+ return inner
+
+
+def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json):
+ """
+ Validate parameters of data passed to the decorated view.
+
+ Should only be applied to functions on APIView routes.
+
+ This will pass the validated data in as the first parameter to the decorated function.
+ This data will always be a list, and view functions are expected to be able to handle that
+ in the case of multiple sets of data being provided by the api.
+ """
+
+ def inner_decorator(f):
+
+ @wraps(f)
+ def inner(self, *args, **kwargs):
+ if validation_type == ValidationTypes.json:
+ try:
+ if not request.is_json:
+ return self.error(ErrorCodes.bad_data_format)
+
+ data = list(request.get_json())
+ except JSONDecodeError:
+ return self.error(ErrorCodes.bad_data_format)
+
+ elif validation_type == ValidationTypes.params:
+ # I really don't like this section here, but I can't think of a better way to do it
+ multi = request.args # This is a MultiDict, which should be flattened to a list of dicts
+
+ # We'll assume that there's always an equal number of values for each param
+ # Anything else doesn't really make sense anyway
+ data = []
+ longest = None
+
+ for key, items in multi.lists():
+ # Make sure every key has the same number of values
+ if longest is None:
+ # First iteration, store it
+ longest = len(items)
+
+ elif len(items) != longest:
+ # At least one key has a different number of values
+ return self.error(ErrorCodes.bad_data_format)
+
+ for i in range(longest): # Now we know all keys have the same number of values...
+ obj = {} # New dict to store this set of values
+
+ for key, items in multi.lists():
+ obj[key] = items[i] # Store the item at that specific index
+
+ data.append(obj)
+
+ else:
+ raise ValueError(f"Unknown validation type: {validation_type}")
+
+ try:
+ schema.validate(data)
+ except SchemaError:
+ return self.error(ErrorCodes.incorrect_parameters)
+
+ return f(self, data, *args, **kwargs)
+ return inner
+ return inner_decorator
diff --git a/pysite/views/api/bot/tag.py b/pysite/views/api/bot/tag.py
index b2ce145b..363f98fe 100644
--- a/pysite/views/api/bot/tag.py
+++ b/pysite/views/api/bot/tag.py
@@ -4,7 +4,7 @@ from flask import jsonify, request
from pysite.base_route import APIView, DBViewMixin
from pysite.constants import ErrorCodes
-from pysite.decorators import valid_api_key
+from pysite.decorators import api_key
class TagView(APIView, DBViewMixin):
@@ -13,7 +13,7 @@ class TagView(APIView, DBViewMixin):
table_name = "tag"
table_primary_key = "tag_name"
- @valid_api_key
+ @api_key
def get(self):
"""
Data must be provided as params,
@@ -29,7 +29,7 @@ class TagView(APIView, DBViewMixin):
return jsonify(data)
- @valid_api_key
+ @api_key
def post(self):
"""
Data must be provided as JSON.
@@ -51,6 +51,6 @@ class TagView(APIView, DBViewMixin):
}
)
else:
- return self.error(ErrorCodes.missing_parameters)
+ return self.error(ErrorCodes.incorrect_parameters)
return jsonify({"success": True})
diff --git a/pysite/views/api/bot/user.py b/pysite/views/api/bot/user.py
index aad58e05..5e9dc444 100644
--- a/pysite/views/api/bot/user.py
+++ b/pysite/views/api/bot/user.py
@@ -1,10 +1,20 @@
# coding=utf-8
-from flask import jsonify, request
+from flask import jsonify
+
+from schema import Schema
from pysite.base_route import APIView, DBViewMixin
-from pysite.constants import ErrorCodes
-from pysite.decorators import valid_api_key
+from pysite.constants import ValidationTypes
+from pysite.decorators import api_key, api_params
+
+
+SCHEMA = Schema([
+ {
+ "user_id": int,
+ "role": int
+ }
+])
REQUIRED_KEYS = [
"user_id",
@@ -18,24 +28,12 @@ class UserView(APIView, DBViewMixin):
table_name = "users"
table_primary_key = "user_id"
- @valid_api_key
- def post(self):
- data = request.get_json()
-
- if not isinstance(data, list):
- data = [data]
-
+ @api_key
+ @api_params(schema=SCHEMA, validation_type=ValidationTypes.json)
+ def post(self, data):
for user in data:
- if not all(k in user for k in REQUIRED_KEYS):
- print(user)
- return self.error(ErrorCodes.missing_parameters)
-
self.db.insert(
- self.table_name,
- {
- "user_id": user["user_id"],
- "role": user["role"],
- },
+ self.table_name, user,
conflict="update",
durability="soft"
)
diff --git a/requirements.txt b/requirements.txt
index 4ae185cd..205154bb 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,3 +5,4 @@ gevent
gevent-websocket
wsaccel
ujson
+schema