aboutsummaryrefslogtreecommitdiffstats
path: root/pysite/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--pysite/decorators.py75
1 files changed, 71 insertions, 4 deletions
diff --git a/pysite/decorators.py b/pysite/decorators.py
index 8d0cf7f4..c404b375 100644
--- a/pysite/decorators.py
+++ b/pysite/decorators.py
@@ -1,13 +1,16 @@
# coding=utf-8
import os
from functools import wraps
+from json import JSONDecodeError
from flask import request
-from pysite.constants import ErrorCodes
+from schema import Schema, SchemaError
+from pysite.constants import ErrorCodes, ValidationTypes
-def valid_api_key(f):
+
+def api_key(f):
"""
Decorator to check if X-API-Key is valid.
@@ -15,9 +18,73 @@ def valid_api_key(f):
"""
@wraps(f)
- def has_valid_api_key(self, *args, **kwargs):
+ def inner(self, *args, **kwargs):
if not request.headers.get("X-API-Key") == os.environ.get("API_KEY"):
return self.error(ErrorCodes.invalid_api_key)
return f(self, *args, **kwargs)
- return has_valid_api_key
+ return inner
+
+
+def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json):
+ """
+ Validate parameters of data passed to the decorated view.
+
+ Should only be applied to functions on APIView routes.
+
+ This will pass the validated data in as the first parameter to the decorated function.
+ This data will always be a list, and view functions are expected to be able to handle that
+ in the case of multiple sets of data being provided by the api.
+ """
+
+ def inner_decorator(f):
+
+ @wraps(f)
+ def inner(self, *args, **kwargs):
+ if validation_type == ValidationTypes.json:
+ try:
+ if not request.is_json:
+ return self.error(ErrorCodes.bad_data_format)
+
+ data = list(request.get_json())
+ except JSONDecodeError:
+ return self.error(ErrorCodes.bad_data_format)
+
+ elif validation_type == ValidationTypes.params:
+ # I really don't like this section here, but I can't think of a better way to do it
+ multi = request.args # This is a MultiDict, which should be flattened to a list of dicts
+
+ # We'll assume that there's always an equal number of values for each param
+ # Anything else doesn't really make sense anyway
+ data = []
+ longest = None
+
+ for key, items in multi.lists():
+ # Make sure every key has the same number of values
+ if longest is None:
+ # First iteration, store it
+ longest = len(items)
+
+ elif len(items) != longest:
+ # At least one key has a different number of values
+ return self.error(ErrorCodes.bad_data_format)
+
+ for i in range(longest): # Now we know all keys have the same number of values...
+ obj = {} # New dict to store this set of values
+
+ for key, items in multi.lists():
+ obj[key] = items[i] # Store the item at that specific index
+
+ data.append(obj)
+
+ else:
+ raise ValueError(f"Unknown validation type: {validation_type}")
+
+ try:
+ schema.validate(data)
+ except SchemaError:
+ return self.error(ErrorCodes.incorrect_parameters)
+
+ return f(self, data, *args, **kwargs)
+ return inner
+ return inner_decorator