diff options
Diffstat (limited to '')
| -rw-r--r-- | pysite/base_route.py | 8 | ||||
| -rw-r--r-- | pysite/constants.py | 10 | ||||
| -rw-r--r-- | pysite/decorators.py | 75 | ||||
| -rw-r--r-- | pysite/views/api/bot/tag.py | 8 | ||||
| -rw-r--r-- | pysite/views/api/bot/user.py | 36 | 
5 files changed, 106 insertions, 31 deletions
| diff --git a/pysite/base_route.py b/pysite/base_route.py index 5c197a42..8e6648ee 100644 --- a/pysite/base_route.py +++ b/pysite/base_route.py @@ -112,8 +112,12 @@ class APIView(RouteView):          elif error_code is ErrorCodes.invalid_api_key:              data["error_message"] = "Invalid API-key"              http_code = 401 -        elif error_code is ErrorCodes.missing_parameters: -            data["error_message"] = "Not all required parameters were provided" +        elif error_code is ErrorCodes.bad_data_format: +            data["error_message"] = "Input data in incorrect format" +            http_code = 400 +        elif error_code is ErrorCodes.incorrect_parameters: +            data["error_message"] = "Incorrect parameters provided" +            http_code = 400          response = jsonify(data)          response.status_code = http_code diff --git a/pysite/constants.py b/pysite/constants.py index 0c9c8ecb..59febcc9 100644 --- a/pysite/constants.py +++ b/pysite/constants.py @@ -1,13 +1,19 @@  # coding=utf-8 -from enum import IntEnum +from enum import Enum, IntEnum  class ErrorCodes(IntEnum):      unknown_route = 0      unauthorized = 1      invalid_api_key = 2 -    missing_parameters = 3 +    incorrect_parameters = 3 +    bad_data_format = 4 + + +class ValidationTypes(Enum): +    json = "json" +    params = "params"  OWNER_ROLE = 267627879762755584 diff --git a/pysite/decorators.py b/pysite/decorators.py index 8d0cf7f4..c404b375 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -1,13 +1,16 @@  # coding=utf-8  import os  from functools import wraps +from json import JSONDecodeError  from flask import request -from pysite.constants import ErrorCodes +from schema import Schema, SchemaError +from pysite.constants import ErrorCodes, ValidationTypes -def valid_api_key(f): + +def api_key(f):      """      Decorator to check if X-API-Key is valid. @@ -15,9 +18,73 @@ def valid_api_key(f):      """      @wraps(f) -    def has_valid_api_key(self, *args, **kwargs): +    def inner(self, *args, **kwargs):          if not request.headers.get("X-API-Key") == os.environ.get("API_KEY"):              return self.error(ErrorCodes.invalid_api_key)          return f(self, *args, **kwargs) -    return has_valid_api_key +    return inner + + +def api_params(schema: Schema, validation_type: ValidationTypes = ValidationTypes.json): +    """ +    Validate parameters of data passed to the decorated view. + +    Should only be applied to functions on APIView routes. + +    This will pass the validated data in as the first parameter to the decorated function. +    This data will always be a list, and view functions are expected to be able to handle that +    in the case of multiple sets of data being provided by the api. +    """ + +    def inner_decorator(f): + +        @wraps(f) +        def inner(self, *args, **kwargs): +            if validation_type == ValidationTypes.json: +                try: +                    if not request.is_json: +                        return self.error(ErrorCodes.bad_data_format) + +                    data = list(request.get_json()) +                except JSONDecodeError: +                    return self.error(ErrorCodes.bad_data_format) + +            elif validation_type == ValidationTypes.params: +                # I really don't like this section here, but I can't think of a better way to do it +                multi = request.args  # This is a MultiDict, which should be flattened to a list of dicts + +                # We'll assume that there's always an equal number of values for each param +                # Anything else doesn't really make sense anyway +                data = [] +                longest = None + +                for key, items in multi.lists(): +                    # Make sure every key has the same number of values +                    if longest is None: +                        # First iteration, store it +                        longest = len(items) + +                    elif len(items) != longest: +                        # At least one key has a different number of values +                        return self.error(ErrorCodes.bad_data_format) + +                for i in range(longest):  # Now we know all keys have the same number of values... +                    obj = {}  # New dict to store this set of values + +                    for key, items in multi.lists(): +                        obj[key] = items[i]  # Store the item at that specific index + +                    data.append(obj) + +            else: +                raise ValueError(f"Unknown validation type: {validation_type}") + +            try: +                schema.validate(data) +            except SchemaError: +                return self.error(ErrorCodes.incorrect_parameters) + +            return f(self, data, *args, **kwargs) +        return inner +    return inner_decorator diff --git a/pysite/views/api/bot/tag.py b/pysite/views/api/bot/tag.py index b2ce145b..363f98fe 100644 --- a/pysite/views/api/bot/tag.py +++ b/pysite/views/api/bot/tag.py @@ -4,7 +4,7 @@ from flask import jsonify, request  from pysite.base_route import APIView, DBViewMixin  from pysite.constants import ErrorCodes -from pysite.decorators import valid_api_key +from pysite.decorators import api_key  class TagView(APIView, DBViewMixin): @@ -13,7 +13,7 @@ class TagView(APIView, DBViewMixin):      table_name = "tag"      table_primary_key = "tag_name" -    @valid_api_key +    @api_key      def get(self):          """          Data must be provided as params, @@ -29,7 +29,7 @@ class TagView(APIView, DBViewMixin):          return jsonify(data) -    @valid_api_key +    @api_key      def post(self):          """          Data must be provided as JSON. @@ -51,6 +51,6 @@ class TagView(APIView, DBViewMixin):                  }              )          else: -            return self.error(ErrorCodes.missing_parameters) +            return self.error(ErrorCodes.incorrect_parameters)          return jsonify({"success": True}) diff --git a/pysite/views/api/bot/user.py b/pysite/views/api/bot/user.py index aad58e05..5e9dc444 100644 --- a/pysite/views/api/bot/user.py +++ b/pysite/views/api/bot/user.py @@ -1,10 +1,20 @@  # coding=utf-8 -from flask import jsonify, request +from flask import jsonify + +from schema import Schema  from pysite.base_route import APIView, DBViewMixin -from pysite.constants import ErrorCodes -from pysite.decorators import valid_api_key +from pysite.constants import ValidationTypes +from pysite.decorators import api_key, api_params + + +SCHEMA = Schema([ +    { +        "user_id": int, +        "role": int +    } +])  REQUIRED_KEYS = [      "user_id", @@ -18,24 +28,12 @@ class UserView(APIView, DBViewMixin):      table_name = "users"      table_primary_key = "user_id" -    @valid_api_key -    def post(self): -        data = request.get_json() - -        if not isinstance(data, list): -            data = [data] - +    @api_key +    @api_params(schema=SCHEMA, validation_type=ValidationTypes.json) +    def post(self, data):          for user in data: -            if not all(k in user for k in REQUIRED_KEYS): -                print(user) -                return self.error(ErrorCodes.missing_parameters) -              self.db.insert( -                self.table_name, -                { -                    "user_id": user["user_id"], -                    "role": user["role"], -                }, +                self.table_name, user,                  conflict="update",                  durability="soft"              ) | 
