diff options
| author | 2018-04-09 10:34:00 +0100 | |
|---|---|---|
| committer | 2018-04-09 10:34:00 +0100 | |
| commit | afc30354493ef346138281c4115118b8b0dde01b (patch) | |
| tree | 9e3ee401d89a461c2372d6f16aebf6baa6367feb /pysite | |
| parent | Contribs too picky, need new payment logos (diff) | |
| parent | Added Python 3 cheat sheet to resources.json (diff) | |
Merge remote-tracking branch 'origin/master'
Diffstat (limited to 'pysite')
| -rw-r--r-- | pysite/base_route.py | 6 | ||||
| -rw-r--r-- | pysite/constants.py | 3 | ||||
| -rw-r--r-- | pysite/database.py | 63 | ||||
| -rw-r--r-- | pysite/decorators.py | 11 | ||||
| -rw-r--r-- | pysite/mixins.py | 6 | ||||
| -rw-r--r-- | pysite/oauth.py | 2 | ||||
| -rw-r--r-- | pysite/views/error_handlers/http_4xx.py | 6 | ||||
| -rw-r--r-- | pysite/views/error_handlers/http_5xx.py | 6 | ||||
| -rw-r--r-- | pysite/views/wiki/edit.py | 35 | ||||
| -rw-r--r-- | pysite/views/wiki/page.py | 13 | ||||
| -rw-r--r-- | pysite/views/wiki/render.py | 4 | 
11 files changed, 101 insertions, 54 deletions
| diff --git a/pysite/base_route.py b/pysite/base_route.py index 494875ed..1d30669d 100644 --- a/pysite/base_route.py +++ b/pysite/base_route.py @@ -6,7 +6,7 @@ from flask import Blueprint, Response, jsonify, render_template, url_for  from flask.views import MethodView  from werkzeug.exceptions import default_exceptions -from pysite.constants import ErrorCodes +from pysite.constants import DEBUG_MODE, ErrorCodes  from pysite.mixins import OauthMixin @@ -52,6 +52,7 @@ class BaseView(MethodView, OauthMixin):          context["view"] = self          context["logged_in"] = self.logged_in          context["static_file"] = self._static_file +        context["debug"] = DEBUG_MODE          return render_template(template_names, **context) @@ -204,4 +205,5 @@ class ErrorView(BaseView):                  else:                      blueprint.errorhandler(code)(cls.as_view(cls.name))          else: -            raise RuntimeError("Error views must have an `error_code` that is either an `int` or an iterable")  # pragma: no cover # noqa: E501 +            raise RuntimeError( +                "Error views must have an `error_code` that is either an `int` or an iterable")  # pragma: no cover # noqa: E501 diff --git a/pysite/constants.py b/pysite/constants.py index 69633127..e41b33bf 100644 --- a/pysite/constants.py +++ b/pysite/constants.py @@ -19,6 +19,8 @@ class ValidationTypes(Enum):      params = "params" +DEBUG_MODE = "FLASK_DEBUG" in environ +  OWNER_ROLE = 267627879762755584  ADMIN_ROLE = 267628507062992896  MODERATOR_ROLE = 267629731250176001 @@ -26,6 +28,7 @@ DEVOPS_ROLE = 409416496733880320  HELPER_ROLE = 267630620367257601  ALL_STAFF_ROLES = (OWNER_ROLE, ADMIN_ROLE, MODERATOR_ROLE, DEVOPS_ROLE) +EDITOR_ROLES = ALL_STAFF_ROLES + (HELPER_ROLE,)  SERVER_ID = 267624335836053506 diff --git a/pysite/database.py b/pysite/database.py index 4c2153fe..82e1e84e 100644 --- a/pysite/database.py +++ b/pysite/database.py @@ -8,17 +8,27 @@ from flask import abort  from rethinkdb.ast import RqlMethodQuery, Table, UserError  from rethinkdb.net import DefaultConnection +ALL_TABLES = { +    # table: primary_key + +    "oauth_data": "id", +    "tags": "tag_name", +    "users": "user_id", +    "wiki": "slug", +} +  class RethinkDB: -    def __init__(self, loop_type: str = "gevent"): +    def __init__(self, loop_type: Optional[str] = "gevent"):          self.host = os.environ.get("RETHINKDB_HOST", "127.0.0.1")          self.port = os.environ.get("RETHINKDB_PORT", "28016")          self.database = os.environ.get("RETHINKDB_DATABASE", "pythondiscord") -        self.log = logging.getLogger() +        self.log = logging.getLogger(__name__)          self.conn = None -        rethinkdb.set_loop_type(loop_type) +        if loop_type: +            rethinkdb.set_loop_type(loop_type)          with self.get_connection(connect_database=False) as conn:              try: @@ -27,7 +37,16 @@ class RethinkDB:              except rethinkdb.RqlRuntimeError:                  self.log.debug(f"Database found: '{self.database}'") -    def get_connection(self, connect_database: bool=True) -> DefaultConnection: +    def create_tables(self) -> int: +        created = 0 + +        for table, primary_key in ALL_TABLES.items(): +            if self.create_table(table, primary_key): +                created += 1 + +        return created + +    def get_connection(self, connect_database: bool = True) -> DefaultConnection:          """          Grab a connection to the RethinkDB server, optionally without selecting a database @@ -63,8 +82,8 @@ class RethinkDB:      # region: Convenience wrappers -    def create_table(self, table_name: str, primary_key: str="id", durability: str="hard", shards: int=1, -                     replicas: Union[int, Dict[str, int]]=1, primary_replica_tag: Optional[str]=None) -> bool: +    def create_table(self, table_name: str, primary_key: str = "id", durability: str = "hard", shards: int = 1, +                     replicas: Union[int, Dict[str, int]] = 1, primary_replica_tag: Optional[str] = None) -> bool:          """          Attempt to create a new table on the current database @@ -106,7 +125,7 @@ class RethinkDB:      def delete(self,                 table_name: str,                 primary_key: Union[str, None] = None, -               durability: str="hard", +               durability: str = "hard",                 return_changes: Union[bool, str] = False) -> dict:          """          Delete one or all documents from a table. This can only delete @@ -170,10 +189,13 @@ class RethinkDB:          :return: The RethinkDB table object for the table          """ +        if table_name not in ALL_TABLES: +            self.log.warning(f"Table not declared in database.py: {table_name}") +          return rethinkdb.table(table_name) -    def run(self, query: RqlMethodQuery, *, new_connection: bool=False, -            connect_database: bool=True, coerce: type=None) -> Union[rethinkdb.Cursor, List, Dict, object]: +    def run(self, query: RqlMethodQuery, *, new_connection: bool = False, +            connect_database: bool = True, coerce: type = None) -> Union[rethinkdb.Cursor, List, Dict, object]:          """          Run a query using a table object obtained from a call to `query()` @@ -215,14 +237,14 @@ class RethinkDB:      # region: RethinkDB wrapper functions      def insert(self, table_name: str, *objects: Dict[str, Any], -               durability: str="hard", -               return_changes: Union[bool, str]=False, +               durability: str = "hard", +               return_changes: Union[bool, str] = False,                 conflict: Union[  # Any of...                     str, Callable[  # ...str, or a callable that...                         [Dict[str, Any], Dict[str, Any]],  # ...takes two dicts with string keys and any values...                         Dict[str, Any]  # ...and returns a dict with string keys and any values                     ] -               ]="error") -> Union[List, Dict]:  # flake8: noqa +               ] = "error") -> Union[List, Dict]:  # flake8: noqa          """          Insert an object or a set of objects into a table @@ -262,7 +284,7 @@ class RethinkDB:          return dict(result) if result else None  # pragma: no cover -    def get_all(self, table_name: str, *keys: str, index: str="id") -> List[Any]: +    def get_all(self, table_name: str, *keys: str, index: str = "id") -> List[Any]:          """          Get a list of documents matching a set of keys, on a specific index @@ -278,7 +300,7 @@ class RethinkDB:              coerce=list          ) -    def wait(self, table_name: str, wait_for: str="all_replicas_ready", timeout: int=0) -> bool: +    def wait(self, table_name: str, wait_for: str = "all_replicas_ready", timeout: int = 0) -> bool:          """          Wait until an operation has happened on a specific table; will block the current function @@ -312,9 +334,9 @@ class RethinkDB:          return result.get("synced", 0) > 0  # pragma: no cover -    def changes(self, table_name: str, squash: Union[bool, int]=False, changefeed_queue_size: int=100_000, -                include_initial: Optional[bool]=None, include_states: bool=False, -                include_types: bool=False) -> Iterator[Dict[str, Any]]: +    def changes(self, table_name: str, squash: Union[bool, int] = False, changefeed_queue_size: int = 100_000, +                include_initial: Optional[bool] = None, include_states: bool = False, +                include_types: bool = False) -> Iterator[Dict[str, Any]]:          """          A complicated function allowing you to follow a changefeed for a specific table @@ -420,8 +442,9 @@ class RethinkDB:              self.query(table_name).without(*selectors)          ) -    def between(self, table_name: str, *, lower: Any=rethinkdb.minval, upper: Any=rethinkdb.maxval, -                index: Optional[str]=None, left_bound: str="closed", right_bound: str ="open") -> List[Dict[str, Any]]: +    def between(self, table_name: str, *, lower: Any = rethinkdb.minval, upper: Any = rethinkdb.maxval, +                index: Optional[str] = None, left_bound: str = "closed", right_bound: str = "open") -> List[ +        Dict[str, Any]]:          """          Get all documents between two keys @@ -483,7 +506,7 @@ class RethinkDB:          )      def filter(self, table_name: str, predicate: Callable[[Dict[str, Any]], bool], -               default: Union[bool, UserError]=False) -> List[Dict[str, Any]]: +               default: Union[bool, UserError] = False) -> List[Dict[str, Any]]:          """          Return all documents in a table for which `predicate` returns true. diff --git a/pysite/decorators.py b/pysite/decorators.py index 426d4846..0b02ebde 100644 --- a/pysite/decorators.py +++ b/pysite/decorators.py @@ -8,7 +8,7 @@ from schema import Schema, SchemaError  from werkzeug.exceptions import Forbidden  from pysite.base_route import APIView, BaseView -from pysite.constants import CSRF, ErrorCodes, ValidationTypes +from pysite.constants import CSRF, DEBUG_MODE, ErrorCodes, ValidationTypes  def csrf(f): @@ -32,9 +32,11 @@ def require_roles(*roles: int):          def inner(self: BaseView, *args, **kwargs):              data = self.user_data -            if data: +            if DEBUG_MODE: +                return f(self, *args, **kwargs) +            elif data:                  for role in roles: -                    if role in data.get("roles", []): +                    if DEBUG_MODE or role in data.get("roles", []):                          return f(self, *args, **kwargs)                  if isinstance(self, APIView): @@ -42,6 +44,7 @@ def require_roles(*roles: int):                  raise Forbidden()              return redirect(url_for("discord.login")) +          return inner      return inner_decorator @@ -128,5 +131,7 @@ def api_params(schema: Schema, validation_type: ValidationTypes = ValidationType                  return self.error(ErrorCodes.incorrect_parameters)              return f(self, data, *args, **kwargs) +          return inner +      return inner_decorator diff --git a/pysite/mixins.py b/pysite/mixins.py index a2730ae4..efbc2d0c 100644 --- a/pysite/mixins.py +++ b/pysite/mixins.py @@ -4,6 +4,7 @@ from weakref import ref  from flask import Blueprint  from rethinkdb.ast import Table +from pysite.constants import DEBUG_MODE  from pysite.database import RethinkDB @@ -51,7 +52,9 @@ class DBMixin:              raise RuntimeError("Routes using DBViewMixin must define `table_name`")          cls._db = ref(manager.db) -        manager.db.create_table(cls.table_name, primary_key=cls.table_primary_key) + +        if DEBUG_MODE: +            manager.db.create_table(cls.table_name, primary_key=cls.table_primary_key)      @property      def table(self) -> Table: @@ -89,7 +92,6 @@ class OauthMixin:      @classmethod      def setup(cls: "OauthMixin", manager: "pysite.route_manager.RouteManager", blueprint: Blueprint): -          if hasattr(super(), "setup"):              super().setup(manager, blueprint)  # pragma: no cover diff --git a/pysite/oauth.py b/pysite/oauth.py index ef86aa8a..86a2024d 100644 --- a/pysite/oauth.py +++ b/pysite/oauth.py @@ -84,4 +84,4 @@ class OauthBackend(BaseBackend):      def logout(self):          sess_id = session.get("session_id")          if sess_id and self.db.get(OAUTH_DATABASE, sess_id):  # If user exists in db, -            self.db.delete(OAUTH_DATABASE, sess_id)           # remove them (at least, their session) +            self.db.delete(OAUTH_DATABASE, sess_id)  # remove them (at least, their session) diff --git a/pysite/views/error_handlers/http_4xx.py b/pysite/views/error_handlers/http_4xx.py index 48ae7f0f..2d6c54c6 100644 --- a/pysite/views/error_handlers/http_4xx.py +++ b/pysite/views/error_handlers/http_4xx.py @@ -11,7 +11,6 @@ class Error400View(ErrorView):      error_code = range(400, 430)      def __init__(self): -          # Direct errors for all methods at self.return_error          methods = [              'get', 'post', 'put', @@ -27,7 +26,6 @@ class Error400View(ErrorView):          return self.render(              "errors/error.html", code=error.code, req=request, error_title=error_desc, -            error_message=error_desc + -            " If you believe we have made a mistake, please " -            "<a href='https://github.com/discord-python/site/issues'>open an issue on our GitHub</a>." +            error_message=f"{error_desc} If you believe we have made a mistake, please " +                          "<a href='https://github.com/discord-python/site/issues'>open an issue on our GitHub</a>."          ), error.code diff --git a/pysite/views/error_handlers/http_5xx.py b/pysite/views/error_handlers/http_5xx.py index 14c016c5..46c65e38 100644 --- a/pysite/views/error_handlers/http_5xx.py +++ b/pysite/views/error_handlers/http_5xx.py @@ -36,7 +36,7 @@ class Error500View(ErrorView):          return self.render(              "errors/error.html", code=error.code, req=request, error_title=error_desc,              error_message="An error occurred while processing this request, please try " -            "again later. If you believe we have made a mistake, please " -            "<a href='https://github.com/discord-python/site/issues'>file an issue on our" -            " GitHub</a>." +                          "again later. If you believe we have made a mistake, please " +                          "<a href='https://github.com/discord-python/site/issues'>file an issue on our" +                          " GitHub</a>."          ), error.code diff --git a/pysite/views/wiki/edit.py b/pysite/views/wiki/edit.py index 1a100b8b..0a0af15b 100644 --- a/pysite/views/wiki/edit.py +++ b/pysite/views/wiki/edit.py @@ -1,9 +1,10 @@  # coding=utf-8 -from flask import url_for +from docutils.core import publish_parts +from flask import request, url_for  from werkzeug.utils import redirect  from pysite.base_route import RouteView -from pysite.constants import ALL_STAFF_ROLES +from pysite.constants import EDITOR_ROLES  from pysite.decorators import csrf, require_roles  from pysite.mixins import DBMixin @@ -15,28 +16,38 @@ class EditView(RouteView, DBMixin):      table_name = "wiki"      table_primary_key = "slug" -    @require_roles(*ALL_STAFF_ROLES) +    @require_roles(*EDITOR_ROLES)      def get(self, page):          rst = ""          title = "" +        preview = "<p>Preview will appear here.</p>"          obj = self.db.get(self.table_name, page)          if obj:              rst = obj["rst"]              title = obj["title"] +            preview = obj["html"] -        return self.render("wiki/page_edit.html", page=page, rst=rst, title=title) +        return self.render("wiki/page_edit.html", page=page, rst=rst, title=title, preview=preview) -    @require_roles(*ALL_STAFF_ROLES) +    @require_roles(*EDITOR_ROLES)      @csrf      def post(self, page): -        # rst = request.form["rst"] -        # obj = { -        #     "slug": page, -        #     "title": request.form["title"], -        #     "rst": request.form["rst"], -        #     "html": "" -        # } +        rst = request.form["rst"] +        obj = { +            "slug": page, +            "title": request.form["title"], +            "rst": rst, +            "html": publish_parts( +                source=rst, writer_name="html5", settings_overrides={"halt_level": 2} +            )["html_body"] +        } + +        self.db.insert( +            self.table_name, +            obj, +            conflict="replace" +        )          return redirect(url_for("wiki.page", page=page), code=303)  # Redirect, ensuring a GET diff --git a/pysite/views/wiki/page.py b/pysite/views/wiki/page.py index 01c8fa8a..a7f60f02 100644 --- a/pysite/views/wiki/page.py +++ b/pysite/views/wiki/page.py @@ -1,8 +1,9 @@  # coding=utf-8  from flask import redirect, url_for +from werkzeug.exceptions import NotFound  from pysite.base_route import RouteView -from pysite.constants import ALL_STAFF_ROLES +from pysite.constants import DEBUG_MODE, EDITOR_ROLES  from pysite.mixins import DBMixin @@ -18,19 +19,21 @@ class PageView(RouteView, DBMixin):          if obj is None:              if self.is_staff(): -                return redirect(url_for("wiki.edit", page=page)) +                return redirect(url_for("wiki.edit", page=page, can_edit=False)) -            return self.render("wiki/page_missing.html", page=page) -        return self.render("wiki/page_view.html", page=page, data=obj) +            raise NotFound() +        return self.render("wiki/page_view.html", page=page, data=obj, can_edit=self.is_staff())      def is_staff(self): +        if DEBUG_MODE: +            return True          if not self.logged_in:              return False          roles = self.user_data.get("roles", [])          for role in roles: -            if role in ALL_STAFF_ROLES: +            if role in EDITOR_ROLES:                  return True          return False diff --git a/pysite/views/wiki/render.py b/pysite/views/wiki/render.py index 73c38731..131db1d3 100644 --- a/pysite/views/wiki/render.py +++ b/pysite/views/wiki/render.py @@ -7,7 +7,7 @@ from flask import jsonify  from schema import Schema  from pysite.base_route import APIView -from pysite.constants import ALL_STAFF_ROLES, ValidationTypes +from pysite.constants import EDITOR_ROLES, ValidationTypes  from pysite.decorators import api_params, csrf, require_roles  SCHEMA = Schema([{ @@ -22,7 +22,7 @@ class RenderView(APIView):      name = "render"      @csrf -    @require_roles(*ALL_STAFF_ROLES) +    @require_roles(*EDITOR_ROLES)      @api_params(schema=SCHEMA, validation_type=ValidationTypes.json)      def post(self, data):          if not len(data): | 
