From 60694f47d02c37e3e54933a0b94cb202926c4b3d Mon Sep 17 00:00:00 2001 From: Gareth Coles Date: Fri, 16 Feb 2018 16:01:00 +0000 Subject: API schema validation decorator #yzuf (#15) * API schema validation decorator * Remove stray comma * Remove unnecessary conditional * Only cast to list when needed to --- pysite/base_route.py | 8 +++-- pysite/constants.py | 10 ++++-- pysite/decorators.py | 75 +++++++++++++++++++++++++++++++++++++++++--- pysite/views/api/bot/tag.py | 8 ++--- pysite/views/api/bot/user.py | 36 ++++++++++----------- 5 files changed, 106 insertions(+), 31 deletions(-) (limited to 'pysite') diff --git a/pysite/base_route.py b/pysite/base_route.py index 5c197a42..8e6648ee 100644 --- a/pysite/base_route.py +++ b/pysite/base_route.py @@ -112,8 +112,12 @@ class APIView(RouteView): elif error_code is ErrorCodes.invalid_api_key: data["error_message"] = "Invalid API-key" http_code = 401 - elif error_code is ErrorCodes.missing_parameters: - data["error_message"] = "Not all required parameters were provided" + elif error_code is ErrorCodes.bad_data_format: + data["error_message"] = "Input data in incorrect format" + http_code = 400 + elif error_code is ErrorCodes.incorrect_parameters: + data["error_message"] = "Incorrect parameters provided" + http_code = 400 response = jsonify(data) response.status_code = http_code diff --git a/pysite/constants.py b/pysite/constants.py index 0c9c8ecb..59febcc9 100644 --- a/pysite/constants.py +++ b/pysite/constants.py @@ -1,13 +1,19 @@ # coding=utf-8 -from enum import IntEnum +from enum import Enum, IntEnum class ErrorCodes(IntEnum): unknown_route = 0 unauthorized = 1 invalid_api_key = 2 - missing_parameters = 3 + incorrect_parameters = 3 + bad_data_format = 4 + + +class ValidationTypes(Enum): + json = "json" + params = "params" OWNER_ROLE = 267627879762755584 diff --git a/pysite/decorators.py b/pysite/decorators.py index 8d0cf7f4..c404b375 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -1,13 +1,16 @@ # coding=utf-8 import os from functools import wraps +from json import JSONDecodeError from flask import request -from pysite.constants import ErrorCodes +from schema import Schema, SchemaError +from pysite.constants import ErrorCodes, ValidationTypes -def valid_api_key(f): + +def api_key(f): """ Decorator to check if X-API-Key is valid. @@ -15,9 +18,73 @@ def valid_api_key(f): """ @wraps(f) - def has_valid_api_key(self, *args, **kwargs): + def inner(self, *args, **kwargs): if not request.headers.get("X-API-Key") == os.environ.get("API_KEY"): return self.error(ErrorCodes.invalid_api_key) return f(self, *args, **kwargs) - return has_valid_api_key + return inner + + +def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json): + """ + Validate parameters of data passed to the decorated view. + + Should only be applied to functions on APIView routes. + + This will pass the validated data in as the first parameter to the decorated function. + This data will always be a list, and view functions are expected to be able to handle that + in the case of multiple sets of data being provided by the api. + """ + + def inner_decorator(f): + + @wraps(f) + def inner(self, *args, **kwargs): + if validation_type == ValidationTypes.json: + try: + if not request.is_json: + return self.error(ErrorCodes.bad_data_format) + + data = list(request.get_json()) + except JSONDecodeError: + return self.error(ErrorCodes.bad_data_format) + + elif validation_type == ValidationTypes.params: + # I really don't like this section here, but I can't think of a better way to do it + multi = request.args # This is a MultiDict, which should be flattened to a list of dicts + + # We'll assume that there's always an equal number of values for each param + # Anything else doesn't really make sense anyway + data = [] + longest = None + + for key, items in multi.lists(): + # Make sure every key has the same number of values + if longest is None: + # First iteration, store it + longest = len(items) + + elif len(items) != longest: + # At least one key has a different number of values + return self.error(ErrorCodes.bad_data_format) + + for i in range(longest): # Now we know all keys have the same number of values... + obj = {} # New dict to store this set of values + + for key, items in multi.lists(): + obj[key] = items[i] # Store the item at that specific index + + data.append(obj) + + else: + raise ValueError(f"Unknown validation type: {validation_type}") + + try: + schema.validate(data) + except SchemaError: + return self.error(ErrorCodes.incorrect_parameters) + + return f(self, data, *args, **kwargs) + return inner + return inner_decorator diff --git a/pysite/views/api/bot/tag.py b/pysite/views/api/bot/tag.py index b2ce145b..363f98fe 100644 --- a/pysite/views/api/bot/tag.py +++ b/pysite/views/api/bot/tag.py @@ -4,7 +4,7 @@ from flask import jsonify, request from pysite.base_route import APIView, DBViewMixin from pysite.constants import ErrorCodes -from pysite.decorators import valid_api_key +from pysite.decorators import api_key class TagView(APIView, DBViewMixin): @@ -13,7 +13,7 @@ class TagView(APIView, DBViewMixin): table_name = "tag" table_primary_key = "tag_name" - @valid_api_key + @api_key def get(self): """ Data must be provided as params, @@ -29,7 +29,7 @@ class TagView(APIView, DBViewMixin): return jsonify(data) - @valid_api_key + @api_key def post(self): """ Data must be provided as JSON. @@ -51,6 +51,6 @@ class TagView(APIView, DBViewMixin): } ) else: - return self.error(ErrorCodes.missing_parameters) + return self.error(ErrorCodes.incorrect_parameters) return jsonify({"success": True}) diff --git a/pysite/views/api/bot/user.py b/pysite/views/api/bot/user.py index aad58e05..5e9dc444 100644 --- a/pysite/views/api/bot/user.py +++ b/pysite/views/api/bot/user.py @@ -1,10 +1,20 @@ # coding=utf-8 -from flask import jsonify, request +from flask import jsonify + +from schema import Schema from pysite.base_route import APIView, DBViewMixin -from pysite.constants import ErrorCodes -from pysite.decorators import valid_api_key +from pysite.constants import ValidationTypes +from pysite.decorators import api_key, api_params + + +SCHEMA = Schema([ + { + "user_id": int, + "role": int + } +]) REQUIRED_KEYS = [ "user_id", @@ -18,24 +28,12 @@ class UserView(APIView, DBViewMixin): table_name = "users" table_primary_key = "user_id" - @valid_api_key - def post(self): - data = request.get_json() - - if not isinstance(data, list): - data = [data] - + @api_key + @api_params(schema=SCHEMA, validation_type=ValidationTypes.json) + def post(self, data): for user in data: - if not all(k in user for k in REQUIRED_KEYS): - print(user) - return self.error(ErrorCodes.missing_parameters) - self.db.insert( - self.table_name, - { - "user_id": user["user_id"], - "role": user["role"], - }, + self.table_name, user, conflict="update", durability="soft" ) -- cgit v1.2.3