From d0e09d2ba567f23d91ac76d1844966bafb9b063a Mon Sep 17 00:00:00 2001 From: Joe Banks Date: Sun, 7 Jul 2024 02:29:26 +0100 Subject: Apply fixable lint settings with Ruff --- backend/routes/forms/unittesting.py | 68 +++++++++++++++++++++---------------- 1 file changed, 38 insertions(+), 30 deletions(-) (limited to 'backend/routes/forms/unittesting.py') diff --git a/backend/routes/forms/unittesting.py b/backend/routes/forms/unittesting.py index a02afea..3239d35 100644 --- a/backend/routes/forms/unittesting.py +++ b/backend/routes/forms/unittesting.py @@ -1,7 +1,8 @@ import base64 -from collections import namedtuple from itertools import count +from pathlib import Path from textwrap import indent +from typing import NamedTuple import httpx from httpx import HTTPStatusError @@ -9,7 +10,7 @@ from httpx import HTTPStatusError from backend.constants import SNEKBOX_URL from backend.models import Form, FormResponse -with open("resources/unittest_template.py") as file: +with Path("resources/unittest_template.py").open(encoding="utf8") as file: TEST_TEMPLATE = file.read() @@ -17,9 +18,12 @@ class BypassDetectedError(Exception): """Detected an attempt at bypassing the unittests.""" -UnittestResult = namedtuple( - "UnittestResult", "question_id question_index return_code passed result" -) +class UnittestResult(NamedTuple): + question_id: str + question_index: int + return_code: int + passed: bool + result: str def filter_unittests(form: Form) -> Form: @@ -46,11 +50,11 @@ def _make_unit_code(units: dict[str, str]) -> str: elif unit_name == "tearDown": result += "\ndef tearDown(self):" else: - name = f"test_{unit_name.removeprefix('#').removeprefix('test_')}" + name = f"test_{unit_name.removeprefix("#").removeprefix("test_")}" result += f"\nasync def {name}(self):" # Unite code - result += f"\n{indent(unit_code, ' ')}" + result += f"\n{indent(unit_code, " ")}" return indent(result, " ") @@ -72,7 +76,8 @@ async def _post_eval(code: str) -> dict[str, str]: async def execute_unittest( - form_response: FormResponse, form: Form + form_response: FormResponse, + form: Form, ) -> tuple[list[UnittestResult], list[BypassDetectedError]]: """Execute all the unittests in this form and return the results.""" unittest_results = [] @@ -80,16 +85,17 @@ async def execute_unittest( for index, question in enumerate(form.questions): if question.type == "code": - # Exit early if the suite doesn't have any tests if question.data["unittests"] is None: - unittest_results.append(UnittestResult( - question_id=question.id, - question_index=index, - return_code=0, - passed=True, - result="" - )) + unittest_results.append( + UnittestResult( + question_id=question.id, + question_index=index, + return_code=0, + passed=True, + result="", + ) + ) continue passed = False @@ -98,7 +104,7 @@ async def execute_unittest( hidden_test_counter = count(1) hidden_tests = { test.removeprefix("#").removeprefix("test_"): next(hidden_test_counter) - for test in question.data["unittests"]["tests"].keys() + for test in question.data["unittests"]["tests"] if test.startswith("#") } @@ -124,18 +130,18 @@ async def execute_unittest( try: passed = bool(int(stdout[0])) except ValueError: - raise BypassDetectedError("Detected a bypass when reading result code.") + msg = "Detected a bypass when reading result code." + raise BypassDetectedError(msg) if passed and stdout.strip() != "1": # Most likely a bypass attempt # A 1 was written to stdout to indicate success, # followed by the actual output - raise BypassDetectedError( - "Detected improper value for stdout in unittest." - ) + msg = "Detected improper value for stdout in unittest." + raise BypassDetectedError(msg) # If the test failed, we have to populate the result string. - elif not passed: + if not passed: failed_tests = stdout[1:].strip().split(";") # Redact failed hidden tests @@ -146,7 +152,7 @@ async def execute_unittest( result = ";".join(failed_tests) else: result = "" - elif return_code in (5, 6, 99): + elif return_code in {5, 6, 99}: result = response["stdout"] # Killed by NsJail elif return_code == 137: @@ -162,12 +168,14 @@ async def execute_unittest( errors.append(error) passed = False - unittest_results.append(UnittestResult( - question_id=question.id, - question_index=index, - return_code=return_code, - passed=passed, - result=result - )) + unittest_results.append( + UnittestResult( + question_id=question.id, + question_index=index, + return_code=return_code, + passed=passed, + result=result, + ) + ) return unittest_results, errors -- cgit v1.2.3