diff options
-rw-r--r-- | pysite/decorators.py | 22 | ||||
-rw-r--r-- | pysite/views/api/bot/hiphopify.py | 42 | ||||
-rw-r--r-- | pysite/views/api/bot/tags.py | 37 | ||||
-rw-r--r-- | pysite/views/tests/index.py | 15 | ||||
-rw-r--r-- | tests/test_decorators.py | 31 |
5 files changed, 86 insertions, 61 deletions
diff --git a/pysite/decorators.py b/pysite/decorators.py index 0dc1b092..1d840ac7 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -3,7 +3,7 @@ from json import JSONDecodeError from flask import request from schema import Schema, SchemaError -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import BadRequest, Forbidden from pysite.base_route import APIView, RouteView from pysite.constants import BOT_API_KEY, CSRF, DEBUG_MODE, ErrorCodes, ValidationTypes @@ -64,7 +64,10 @@ def api_key(f): return inner_decorator -def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json): +def api_params( + schema: Schema, + validation_type: ValidationTypes = ValidationTypes.json, + allow_duplicate_params: bool = False): """ Validate parameters of data passed to the decorated view. @@ -73,6 +76,10 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType This will pass the validated data in as the first parameter to the decorated function. This data will always be a list, and view functions are expected to be able to handle that in the case of multiple sets of data being provided by the api. + + If `allow_duplicate_params` is set to False (only effects dictionary schemata + and parameter validation), then the view will return a 400 Bad Request + response if the client submits multiple parameters with the same name. """ def inner_decorator(f): @@ -86,13 +93,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType data = request.get_json() - if not isinstance(data, list): + if not isinstance(data, list) and isinstance(schema._schema, list): data = [data] except JSONDecodeError: return self.error(ErrorCodes.bad_data_format) # pragma: no cover - elif validation_type == ValidationTypes.params: + elif validation_type == ValidationTypes.params and isinstance(schema._schema, list): # I really don't like this section here, but I can't think of a better way to do it multi = request.args # This is a MultiDict, which should be flattened to a list of dicts @@ -120,6 +127,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType data.append(obj) + elif validation_type == ValidationTypes.params and isinstance(schema._schema, dict): + if not allow_duplicate_params: + for _arg, value in request.args.to_dict(flat=False).items(): + if len(value) > 1: + raise BadRequest("This view does not allow duplicate query arguments") + data = request.args.to_dict() + else: raise ValueError(f"Unknown validation type: {validation_type}") # pragma: no cover diff --git a/pysite/views/api/bot/hiphopify.py b/pysite/views/api/bot/hiphopify.py index 50a811c6..3a47b64e 100644 --- a/pysite/views/api/bot/hiphopify.py +++ b/pysite/views/api/bot/hiphopify.py @@ -12,25 +12,19 @@ from pysite.utils.time import is_expired, parse_duration log = logging.getLogger(__name__) -GET_SCHEMA = Schema([ - { - "user_id": str - } -]) - -POST_SCHEMA = Schema([ - { - "user_id": str, - "duration": str, - Optional("forced_nick"): str - } -]) - -DELETE_SCHEMA = Schema([ - { - "user_id": str - } -]) +GET_SCHEMA = Schema({ + "user_id": str +}) + +POST_SCHEMA = Schema({ + "user_id": str, + "duration": str, + Optional("forced_nick"): str +}) + +DELETE_SCHEMA = Schema({ + "user_id": str +}) class HiphopifyView(APIView, DBMixin): @@ -55,7 +49,7 @@ class HiphopifyView(APIView, DBMixin): API key must be provided as header. """ - user_id = params[0].get("user_id") + user_id = params.get("user_id") log.debug(f"Checking if user ({user_id}) is permitted to change their nickname.") data = self.db.get(self.prison_table, user_id) or {} @@ -83,9 +77,9 @@ class HiphopifyView(APIView, DBMixin): API key must be provided as header. """ - user_id = json_data[0].get("user_id") - duration = json_data[0].get("duration") - forced_nick = json_data[0].get("forced_nick") + user_id = json_data.get("user_id") + duration = json_data.get("duration") + forced_nick = json_data.get("forced_nick") log.debug(f"Attempting to imprison user ({user_id}).") @@ -146,7 +140,7 @@ class HiphopifyView(APIView, DBMixin): API key must be provided as header. """ - user_id = json_data[0].get("user_id") + user_id = json_data.get("user_id") log.debug(f"Attempting to release user ({user_id}) from hiphop-prison.") prisoner_data = self.db.get(self.prison_table, user_id) diff --git a/pysite/views/api/bot/tags.py b/pysite/views/api/bot/tags.py index 7fdaee3c..4394c224 100644 --- a/pysite/views/api/bot/tags.py +++ b/pysite/views/api/bot/tags.py @@ -6,24 +6,18 @@ from pysite.constants import ValidationTypes from pysite.decorators import api_key, api_params from pysite.mixins import DBMixin -GET_SCHEMA = Schema([ - { - Optional("tag_name"): str - } -]) - -POST_SCHEMA = Schema([ - { - "tag_name": str, - "tag_content": str - } -]) - -DELETE_SCHEMA = Schema([ - { - "tag_name": str - } -]) +GET_SCHEMA = Schema({ + Optional("tag_name"): str +}) + +POST_SCHEMA = Schema({ + "tag_name": str, + "tag_content": str +}) + +DELETE_SCHEMA = Schema({ + "tag_name": str +}) class TagsView(APIView, DBMixin): @@ -53,7 +47,7 @@ class TagsView(APIView, DBMixin): tag_name = None if params: - tag_name = params[0].get("tag_name") + tag_name = params.get("tag_name") if tag_name: data = self.db.get(self.table_name, tag_name) or {} @@ -76,8 +70,6 @@ class TagsView(APIView, DBMixin): API key must be provided as header. """ - json_data = json_data[0] - tag_name = json_data.get("tag_name") tag_content = json_data.get("tag_content") @@ -102,8 +94,7 @@ class TagsView(APIView, DBMixin): API key must be provided as header. """ - json = data[0] - tag_name = json.get("tag_name") + tag_name = data.get("tag_name") tag_exists = self.db.get(self.table_name, tag_name) if tag_exists: diff --git a/pysite/views/tests/index.py b/pysite/views/tests/index.py index b96590c0..f99e3f3c 100644 --- a/pysite/views/tests/index.py +++ b/pysite/views/tests/index.py @@ -1,20 +1,23 @@ from flask import jsonify from schema import Schema -from pysite.base_route import RouteView +from pysite.base_route import APIView from pysite.constants import ValidationTypes from pysite.decorators import api_params -SCHEMA = Schema([{"test": str}]) +LIST_SCHEMA = Schema([{"test": str}]) +DICT_SCHEMA = Schema({"segfault": str}) -REQUIRED_KEYS = ["test"] - -class TestParamsView(RouteView): +class TestParamsView(APIView): path = "/testparams" name = "testparams" - @api_params(schema=SCHEMA, validation_type=ValidationTypes.params) + @api_params(schema=DICT_SCHEMA, validation_type=ValidationTypes.params) + def get(self, data): + return jsonify(data) + + @api_params(schema=LIST_SCHEMA, validation_type=ValidationTypes.params) def post(self, data): jsonified = jsonify(data) return jsonified diff --git a/tests/test_decorators.py b/tests/test_decorators.py index a73052e4..5a3915b8 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,12 +1,22 @@ +from schema import Schema +from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.exceptions import BadRequest + +from pysite.constants import ValidationTypes +from pysite.decorators import api_params from tests import SiteTest + +class DuckRequest: + """A quacking request with the `args` parameter used in schema validation.""" + + def __init__(self, args): + self.args = args + + class DecoratorTests(SiteTest): def test_decorator_api_json(self): """ Check the json validation decorator """ - from pysite.decorators import api_params - from pysite.constants import ValidationTypes - from schema import Schema - SCHEMA = Schema([{"user_id": int, "role": int}]) @api_params(schema=SCHEMA, validation_type=ValidationTypes.json) @@ -23,3 +33,16 @@ class DecoratorTests(SiteTest): self.assertEqual(response.status_code, 200) self.assertEqual(response.json, [{'test': 'params'}]) + + def test_duplicate_params_with_dict_schema_raises_400(self): + """Check that duplicate parameters with a dictionary schema return 400 Bad Request""" + + response = self.client.get('/testparams?segfault=yes&segfault=no') + self.assert400(response) + + def test_single_params_with_dict_schema(self): + """Single parameters with a dictionary schema and `allow_duplicate_keys=False` return 200""" + + response = self.client.get('/testparams?segfault=yes') + self.assert200(response) + self.assertEqual(response.json, {'segfault': 'yes'}) |