diff options
Diffstat (limited to 'pysite/decorators.py')
-rw-r--r-- | pysite/decorators.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/pysite/decorators.py b/pysite/decorators.py index 0dc1b092..1d840ac7 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -3,7 +3,7 @@ from json import JSONDecodeError from flask import request from schema import Schema, SchemaError -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import BadRequest, Forbidden from pysite.base_route import APIView, RouteView from pysite.constants import BOT_API_KEY, CSRF, DEBUG_MODE, ErrorCodes, ValidationTypes @@ -64,7 +64,10 @@ def api_key(f): return inner_decorator -def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json): +def api_params( + schema: Schema, + validation_type: ValidationTypes = ValidationTypes.json, + allow_duplicate_params: bool = False): """ Validate parameters of data passed to the decorated view. @@ -73,6 +76,10 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType This will pass the validated data in as the first parameter to the decorated function. This data will always be a list, and view functions are expected to be able to handle that in the case of multiple sets of data being provided by the api. + + If `allow_duplicate_params` is set to False (only effects dictionary schemata + and parameter validation), then the view will return a 400 Bad Request + response if the client submits multiple parameters with the same name. """ def inner_decorator(f): @@ -86,13 +93,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType data = request.get_json() - if not isinstance(data, list): + if not isinstance(data, list) and isinstance(schema._schema, list): data = [data] except JSONDecodeError: return self.error(ErrorCodes.bad_data_format) # pragma: no cover - elif validation_type == ValidationTypes.params: + elif validation_type == ValidationTypes.params and isinstance(schema._schema, list): # I really don't like this section here, but I can't think of a better way to do it multi = request.args # This is a MultiDict, which should be flattened to a list of dicts @@ -120,6 +127,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType data.append(obj) + elif validation_type == ValidationTypes.params and isinstance(schema._schema, dict): + if not allow_duplicate_params: + for _arg, value in request.args.to_dict(flat=False).items(): + if len(value) > 1: + raise BadRequest("This view does not allow duplicate query arguments") + data = request.args.to_dict() + else: raise ValueError(f"Unknown validation type: {validation_type}") # pragma: no cover |