diff options
| -rw-r--r-- | pysite/decorators.py | 22 | ||||
| -rw-r--r-- | pysite/views/api/bot/hiphopify.py | 42 | ||||
| -rw-r--r-- | pysite/views/api/bot/tags.py | 37 | ||||
| -rw-r--r-- | pysite/views/tests/index.py | 15 | ||||
| -rw-r--r-- | tests/test_decorators.py | 31 | 
5 files changed, 86 insertions, 61 deletions
| diff --git a/pysite/decorators.py b/pysite/decorators.py index 0dc1b092..1d840ac7 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -3,7 +3,7 @@ from json import JSONDecodeError  from flask import request  from schema import Schema, SchemaError -from werkzeug.exceptions import Forbidden +from werkzeug.exceptions import BadRequest, Forbidden  from pysite.base_route import APIView, RouteView  from pysite.constants import BOT_API_KEY, CSRF, DEBUG_MODE, ErrorCodes, ValidationTypes @@ -64,7 +64,10 @@ def api_key(f):      return inner_decorator -def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json): +def api_params( +        schema: Schema, +        validation_type: ValidationTypes = ValidationTypes.json, +        allow_duplicate_params: bool = False):      """      Validate parameters of data passed to the decorated view. @@ -73,6 +76,10 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType      This will pass the validated data in as the first parameter to the decorated function.      This data will always be a list, and view functions are expected to be able to handle that      in the case of multiple sets of data being provided by the api. + +    If `allow_duplicate_params` is set to False (only effects dictionary schemata +    and parameter validation), then the view will return a 400 Bad Request +    response if the client submits multiple parameters with the same name.      """      def inner_decorator(f): @@ -86,13 +93,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType                      data = request.get_json() -                    if not isinstance(data, list): +                    if not isinstance(data, list) and isinstance(schema._schema, list):                          data = [data]                  except JSONDecodeError:                      return self.error(ErrorCodes.bad_data_format)  # pragma: no cover -            elif validation_type == ValidationTypes.params: +            elif validation_type == ValidationTypes.params and isinstance(schema._schema, list):                  # I really don't like this section here, but I can't think of a better way to do it                  multi = request.args  # This is a MultiDict, which should be flattened to a list of dicts @@ -120,6 +127,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType                          data.append(obj) +            elif validation_type == ValidationTypes.params and isinstance(schema._schema, dict): +                if not allow_duplicate_params: +                    for _arg, value in request.args.to_dict(flat=False).items(): +                        if len(value) > 1: +                            raise BadRequest("This view does not allow duplicate query arguments") +                data = request.args.to_dict() +              else:                  raise ValueError(f"Unknown validation type: {validation_type}")  # pragma: no cover diff --git a/pysite/views/api/bot/hiphopify.py b/pysite/views/api/bot/hiphopify.py index 50a811c6..3a47b64e 100644 --- a/pysite/views/api/bot/hiphopify.py +++ b/pysite/views/api/bot/hiphopify.py @@ -12,25 +12,19 @@ from pysite.utils.time import is_expired, parse_duration  log = logging.getLogger(__name__) -GET_SCHEMA = Schema([ -    { -        "user_id": str -    } -]) - -POST_SCHEMA = Schema([ -    { -        "user_id": str, -        "duration": str, -        Optional("forced_nick"): str -    } -]) - -DELETE_SCHEMA = Schema([ -    { -        "user_id": str -    } -]) +GET_SCHEMA = Schema({ +    "user_id": str +}) + +POST_SCHEMA = Schema({ +    "user_id": str, +    "duration": str, +    Optional("forced_nick"): str +}) + +DELETE_SCHEMA = Schema({ +    "user_id": str +})  class HiphopifyView(APIView, DBMixin): @@ -55,7 +49,7 @@ class HiphopifyView(APIView, DBMixin):          API key must be provided as header.          """ -        user_id = params[0].get("user_id") +        user_id = params.get("user_id")          log.debug(f"Checking if user ({user_id}) is permitted to change their nickname.")          data = self.db.get(self.prison_table, user_id) or {} @@ -83,9 +77,9 @@ class HiphopifyView(APIView, DBMixin):          API key must be provided as header.          """ -        user_id = json_data[0].get("user_id") -        duration = json_data[0].get("duration") -        forced_nick = json_data[0].get("forced_nick") +        user_id = json_data.get("user_id") +        duration = json_data.get("duration") +        forced_nick = json_data.get("forced_nick")          log.debug(f"Attempting to imprison user ({user_id}).") @@ -146,7 +140,7 @@ class HiphopifyView(APIView, DBMixin):          API key must be provided as header.          """ -        user_id = json_data[0].get("user_id") +        user_id = json_data.get("user_id")          log.debug(f"Attempting to release user ({user_id}) from hiphop-prison.")          prisoner_data = self.db.get(self.prison_table, user_id) diff --git a/pysite/views/api/bot/tags.py b/pysite/views/api/bot/tags.py index 7fdaee3c..4394c224 100644 --- a/pysite/views/api/bot/tags.py +++ b/pysite/views/api/bot/tags.py @@ -6,24 +6,18 @@ from pysite.constants import ValidationTypes  from pysite.decorators import api_key, api_params  from pysite.mixins import DBMixin -GET_SCHEMA = Schema([ -    { -        Optional("tag_name"): str -    } -]) - -POST_SCHEMA = Schema([ -    { -        "tag_name": str, -        "tag_content": str -    } -]) - -DELETE_SCHEMA = Schema([ -    { -        "tag_name": str -    } -]) +GET_SCHEMA = Schema({ +    Optional("tag_name"): str +}) + +POST_SCHEMA = Schema({ +    "tag_name": str, +    "tag_content": str +}) + +DELETE_SCHEMA = Schema({ +    "tag_name": str +})  class TagsView(APIView, DBMixin): @@ -53,7 +47,7 @@ class TagsView(APIView, DBMixin):          tag_name = None          if params: -            tag_name = params[0].get("tag_name") +            tag_name = params.get("tag_name")          if tag_name:              data = self.db.get(self.table_name, tag_name) or {} @@ -76,8 +70,6 @@ class TagsView(APIView, DBMixin):          API key must be provided as header.          """ -        json_data = json_data[0] -          tag_name = json_data.get("tag_name")          tag_content = json_data.get("tag_content") @@ -102,8 +94,7 @@ class TagsView(APIView, DBMixin):          API key must be provided as header.          """ -        json = data[0] -        tag_name = json.get("tag_name") +        tag_name = data.get("tag_name")          tag_exists = self.db.get(self.table_name, tag_name)          if tag_exists: diff --git a/pysite/views/tests/index.py b/pysite/views/tests/index.py index b96590c0..f99e3f3c 100644 --- a/pysite/views/tests/index.py +++ b/pysite/views/tests/index.py @@ -1,20 +1,23 @@  from flask import jsonify  from schema import Schema -from pysite.base_route import RouteView +from pysite.base_route import APIView  from pysite.constants import ValidationTypes  from pysite.decorators import api_params -SCHEMA = Schema([{"test": str}]) +LIST_SCHEMA = Schema([{"test": str}]) +DICT_SCHEMA = Schema({"segfault": str}) -REQUIRED_KEYS = ["test"] - -class TestParamsView(RouteView): +class TestParamsView(APIView):      path = "/testparams"      name = "testparams" -    @api_params(schema=SCHEMA, validation_type=ValidationTypes.params) +    @api_params(schema=DICT_SCHEMA, validation_type=ValidationTypes.params) +    def get(self, data): +        return jsonify(data) + +    @api_params(schema=LIST_SCHEMA, validation_type=ValidationTypes.params)      def post(self, data):          jsonified = jsonify(data)          return jsonified diff --git a/tests/test_decorators.py b/tests/test_decorators.py index a73052e4..5a3915b8 100644 --- a/tests/test_decorators.py +++ b/tests/test_decorators.py @@ -1,12 +1,22 @@ +from schema import Schema +from werkzeug.datastructures import ImmutableMultiDict +from werkzeug.exceptions import BadRequest + +from pysite.constants import ValidationTypes +from pysite.decorators import api_params  from tests import SiteTest + +class DuckRequest: +    """A quacking request with the `args` parameter used in schema validation.""" + +    def __init__(self, args): +        self.args = args + +  class DecoratorTests(SiteTest):      def test_decorator_api_json(self):          """ Check the json validation decorator """ -        from pysite.decorators import api_params -        from pysite.constants import ValidationTypes -        from schema import Schema -          SCHEMA = Schema([{"user_id": int, "role": int}])          @api_params(schema=SCHEMA, validation_type=ValidationTypes.json) @@ -23,3 +33,16 @@ class DecoratorTests(SiteTest):          self.assertEqual(response.status_code, 200)          self.assertEqual(response.json, [{'test': 'params'}]) + +    def test_duplicate_params_with_dict_schema_raises_400(self): +        """Check that duplicate parameters with a dictionary schema return 400 Bad Request""" + +        response = self.client.get('/testparams?segfault=yes&segfault=no') +        self.assert400(response) + +    def test_single_params_with_dict_schema(self): +        """Single parameters with a dictionary schema and `allow_duplicate_keys=False` return 200""" + +        response = self.client.get('/testparams?segfault=yes') +        self.assert200(response) +        self.assertEqual(response.json, {'segfault': 'yes'}) | 
