aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--pysite/decorators.py22
-rw-r--r--pysite/views/api/bot/hiphopify.py42
-rw-r--r--pysite/views/api/bot/tags.py37
-rw-r--r--pysite/views/tests/index.py15
-rw-r--r--tests/test_decorators.py31
5 files changed, 86 insertions, 61 deletions
diff --git a/pysite/decorators.py b/pysite/decorators.py
index 0dc1b092..1d840ac7 100644
--- a/pysite/decorators.py
+++ b/pysite/decorators.py
@@ -3,7 +3,7 @@ from json import JSONDecodeError
from flask import request
from schema import Schema, SchemaError
-from werkzeug.exceptions import Forbidden
+from werkzeug.exceptions import BadRequest, Forbidden
from pysite.base_route import APIView, RouteView
from pysite.constants import BOT_API_KEY, CSRF, DEBUG_MODE, ErrorCodes, ValidationTypes
@@ -64,7 +64,10 @@ def api_key(f):
return inner_decorator
-def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json):
+def api_params(
+ schema: Schema,
+ validation_type: ValidationTypes = ValidationTypes.json,
+ allow_duplicate_params: bool = False):
"""
Validate parameters of data passed to the decorated view.
@@ -73,6 +76,10 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType
This will pass the validated data in as the first parameter to the decorated function.
This data will always be a list, and view functions are expected to be able to handle that
in the case of multiple sets of data being provided by the api.
+
+ If `allow_duplicate_params` is set to False (only effects dictionary schemata
+ and parameter validation), then the view will return a 400 Bad Request
+ response if the client submits multiple parameters with the same name.
"""
def inner_decorator(f):
@@ -86,13 +93,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType
data = request.get_json()
- if not isinstance(data, list):
+ if not isinstance(data, list) and isinstance(schema._schema, list):
data = [data]
except JSONDecodeError:
return self.error(ErrorCodes.bad_data_format) # pragma: no cover
- elif validation_type == ValidationTypes.params:
+ elif validation_type == ValidationTypes.params and isinstance(schema._schema, list):
# I really don't like this section here, but I can't think of a better way to do it
multi = request.args # This is a MultiDict, which should be flattened to a list of dicts
@@ -120,6 +127,13 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType
data.append(obj)
+ elif validation_type == ValidationTypes.params and isinstance(schema._schema, dict):
+ if not allow_duplicate_params:
+ for _arg, value in request.args.to_dict(flat=False).items():
+ if len(value) > 1:
+ raise BadRequest("This view does not allow duplicate query arguments")
+ data = request.args.to_dict()
+
else:
raise ValueError(f"Unknown validation type: {validation_type}") # pragma: no cover
diff --git a/pysite/views/api/bot/hiphopify.py b/pysite/views/api/bot/hiphopify.py
index 50a811c6..3a47b64e 100644
--- a/pysite/views/api/bot/hiphopify.py
+++ b/pysite/views/api/bot/hiphopify.py
@@ -12,25 +12,19 @@ from pysite.utils.time import is_expired, parse_duration
log = logging.getLogger(__name__)
-GET_SCHEMA = Schema([
- {
- "user_id": str
- }
-])
-
-POST_SCHEMA = Schema([
- {
- "user_id": str,
- "duration": str,
- Optional("forced_nick"): str
- }
-])
-
-DELETE_SCHEMA = Schema([
- {
- "user_id": str
- }
-])
+GET_SCHEMA = Schema({
+ "user_id": str
+})
+
+POST_SCHEMA = Schema({
+ "user_id": str,
+ "duration": str,
+ Optional("forced_nick"): str
+})
+
+DELETE_SCHEMA = Schema({
+ "user_id": str
+})
class HiphopifyView(APIView, DBMixin):
@@ -55,7 +49,7 @@ class HiphopifyView(APIView, DBMixin):
API key must be provided as header.
"""
- user_id = params[0].get("user_id")
+ user_id = params.get("user_id")
log.debug(f"Checking if user ({user_id}) is permitted to change their nickname.")
data = self.db.get(self.prison_table, user_id) or {}
@@ -83,9 +77,9 @@ class HiphopifyView(APIView, DBMixin):
API key must be provided as header.
"""
- user_id = json_data[0].get("user_id")
- duration = json_data[0].get("duration")
- forced_nick = json_data[0].get("forced_nick")
+ user_id = json_data.get("user_id")
+ duration = json_data.get("duration")
+ forced_nick = json_data.get("forced_nick")
log.debug(f"Attempting to imprison user ({user_id}).")
@@ -146,7 +140,7 @@ class HiphopifyView(APIView, DBMixin):
API key must be provided as header.
"""
- user_id = json_data[0].get("user_id")
+ user_id = json_data.get("user_id")
log.debug(f"Attempting to release user ({user_id}) from hiphop-prison.")
prisoner_data = self.db.get(self.prison_table, user_id)
diff --git a/pysite/views/api/bot/tags.py b/pysite/views/api/bot/tags.py
index 7fdaee3c..4394c224 100644
--- a/pysite/views/api/bot/tags.py
+++ b/pysite/views/api/bot/tags.py
@@ -6,24 +6,18 @@ from pysite.constants import ValidationTypes
from pysite.decorators import api_key, api_params
from pysite.mixins import DBMixin
-GET_SCHEMA = Schema([
- {
- Optional("tag_name"): str
- }
-])
-
-POST_SCHEMA = Schema([
- {
- "tag_name": str,
- "tag_content": str
- }
-])
-
-DELETE_SCHEMA = Schema([
- {
- "tag_name": str
- }
-])
+GET_SCHEMA = Schema({
+ Optional("tag_name"): str
+})
+
+POST_SCHEMA = Schema({
+ "tag_name": str,
+ "tag_content": str
+})
+
+DELETE_SCHEMA = Schema({
+ "tag_name": str
+})
class TagsView(APIView, DBMixin):
@@ -53,7 +47,7 @@ class TagsView(APIView, DBMixin):
tag_name = None
if params:
- tag_name = params[0].get("tag_name")
+ tag_name = params.get("tag_name")
if tag_name:
data = self.db.get(self.table_name, tag_name) or {}
@@ -76,8 +70,6 @@ class TagsView(APIView, DBMixin):
API key must be provided as header.
"""
- json_data = json_data[0]
-
tag_name = json_data.get("tag_name")
tag_content = json_data.get("tag_content")
@@ -102,8 +94,7 @@ class TagsView(APIView, DBMixin):
API key must be provided as header.
"""
- json = data[0]
- tag_name = json.get("tag_name")
+ tag_name = data.get("tag_name")
tag_exists = self.db.get(self.table_name, tag_name)
if tag_exists:
diff --git a/pysite/views/tests/index.py b/pysite/views/tests/index.py
index b96590c0..f99e3f3c 100644
--- a/pysite/views/tests/index.py
+++ b/pysite/views/tests/index.py
@@ -1,20 +1,23 @@
from flask import jsonify
from schema import Schema
-from pysite.base_route import RouteView
+from pysite.base_route import APIView
from pysite.constants import ValidationTypes
from pysite.decorators import api_params
-SCHEMA = Schema([{"test": str}])
+LIST_SCHEMA = Schema([{"test": str}])
+DICT_SCHEMA = Schema({"segfault": str})
-REQUIRED_KEYS = ["test"]
-
-class TestParamsView(RouteView):
+class TestParamsView(APIView):
path = "/testparams"
name = "testparams"
- @api_params(schema=SCHEMA, validation_type=ValidationTypes.params)
+ @api_params(schema=DICT_SCHEMA, validation_type=ValidationTypes.params)
+ def get(self, data):
+ return jsonify(data)
+
+ @api_params(schema=LIST_SCHEMA, validation_type=ValidationTypes.params)
def post(self, data):
jsonified = jsonify(data)
return jsonified
diff --git a/tests/test_decorators.py b/tests/test_decorators.py
index a73052e4..5a3915b8 100644
--- a/tests/test_decorators.py
+++ b/tests/test_decorators.py
@@ -1,12 +1,22 @@
+from schema import Schema
+from werkzeug.datastructures import ImmutableMultiDict
+from werkzeug.exceptions import BadRequest
+
+from pysite.constants import ValidationTypes
+from pysite.decorators import api_params
from tests import SiteTest
+
+class DuckRequest:
+ """A quacking request with the `args` parameter used in schema validation."""
+
+ def __init__(self, args):
+ self.args = args
+
+
class DecoratorTests(SiteTest):
def test_decorator_api_json(self):
""" Check the json validation decorator """
- from pysite.decorators import api_params
- from pysite.constants import ValidationTypes
- from schema import Schema
-
SCHEMA = Schema([{"user_id": int, "role": int}])
@api_params(schema=SCHEMA, validation_type=ValidationTypes.json)
@@ -23,3 +33,16 @@ class DecoratorTests(SiteTest):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.json, [{'test': 'params'}])
+
+ def test_duplicate_params_with_dict_schema_raises_400(self):
+ """Check that duplicate parameters with a dictionary schema return 400 Bad Request"""
+
+ response = self.client.get('/testparams?segfault=yes&segfault=no')
+ self.assert400(response)
+
+ def test_single_params_with_dict_schema(self):
+ """Single parameters with a dictionary schema and `allow_duplicate_keys=False` return 200"""
+
+ response = self.client.get('/testparams?segfault=yes')
+ self.assert200(response)
+ self.assertEqual(response.json, {'segfault': 'yes'})