aboutsummaryrefslogtreecommitdiffstats
path: root/pysite/decorators.py
diff options
context:
space:
mode:
Diffstat (limited to 'pysite/decorators.py')
-rw-r--r--pysite/decorators.py22
1 files changed, 18 insertions, 4 deletions
diff --git a/pysite/decorators.py b/pysite/decorators.py
index 0dc1b092..1d840ac7 100644
--- a/pysite/decorators.py
+++ b/pysite/decorators.py
@@ -3,7 +3,7 @@ from json import JSONDecodeError
from flask import request
from schema import Schema, SchemaError
-from werkzeug.exceptions import Forbidden
+from werkzeug.exceptions import BadRequest, Forbidden
from pysite.base_route import APIView, RouteView
from pysite.constants import BOT_API_KEY, CSRF, DEBUG_MODE, ErrorCodes, ValidationTypes
@@ -64,7 +64,10 @@ def api_key(f):
return inner_decorator
-def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json):
+def api_params(
+ schema: Schema,
+ validation_type: ValidationTypes = ValidationTypes.json,
+ allow_duplicate_params: bool = False):
"""
Validate parameters of data passed to the decorated view.
@@ -73,6 +76,10 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType
This will pass the validated data in as the first parameter to the decorated function.
This data will always be a list, and view functions are expected to be able to handle that
in the case of multiple sets of data being provided by the api.
+
+ If `allow_duplicate_params` is set to False (only effects dictionary schemata
+ and parameter validation), then the view will return a 400 Bad Request
+ response if the client submits multiple parameters with the same name.
"""
def inner_decorator(f):
@@ -86,13 +93,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType
data = request.get_json()
- if not isinstance(data, list):
+ if not isinstance(data, list) and isinstance(schema._schema, list):
data = [data]
except JSONDecodeError:
return self.error(ErrorCodes.bad_data_format) # pragma: no cover
- elif validation_type == ValidationTypes.params:
+ elif validation_type == ValidationTypes.params and isinstance(schema._schema, list):
# I really don't like this section here, but I can't think of a better way to do it
multi = request.args # This is a MultiDict, which should be flattened to a list of dicts
@@ -120,6 +127,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType
data.append(obj)
+ elif validation_type == ValidationTypes.params and isinstance(schema._schema, dict):
+ if not allow_duplicate_params:
+ for _arg, value in request.args.to_dict(flat=False).items():
+ if len(value) > 1:
+ raise BadRequest("This view does not allow duplicate query arguments")
+ data = request.args.to_dict()
+
else:
raise ValueError(f"Unknown validation type: {validation_type}") # pragma: no cover