aboutsummaryrefslogtreecommitdiffstats
path: root/tests/test_main.py
blob: 196c1966f71a4eba25dd6fecea21bb45723ff574 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import contextlib
import io
import unittest
from argparse import Namespace
from subprocess import CompletedProcess
from unittest.mock import patch

import snekbox.__main__ as snekbox_main


class ArgParseTests(unittest.TestCase):
    """Checks for ``snekbox_main.parse_args`` run against patched ``sys.argv`` values."""

    def test_parse_args(self):
        """Every argv in the table must parse into its paired ``Namespace``.

        The table expects arguments between the code and a ``---`` token to
        land in ``nsjail_args`` and arguments after ``---`` to land in
        ``py_args``; when ``---`` is absent, ``py_args`` is expected to be
        ``["-c"]``.
        """
        # Columns: argv, expected nsjail_args, expected py_args.
        # The expected ``code`` value is "code" in every row.
        cases = [
            (["", "code"], [], ["-c"]),
            (["", "code", "--time_limit", "0"], ["--time_limit", "0"], ["-c"]),
            (["", "code", "---", "-m", "timeit"], [], ["-m", "timeit"]),
            (
                ["", "code", "--time_limit", "0", "---", "-m", "timeit"],
                ["--time_limit", "0"],
                ["-m", "timeit"],
            ),
            (["", "code", "--time_limit", "0", "---"], ["--time_limit", "0"], []),
            (["", "code", "---"], [], []),
        ]

        for argv, nsjail_args, py_args in cases:
            expected = Namespace(code="code", nsjail_args=nsjail_args, py_args=py_args)
            with self.subTest(argv=argv, expected=expected):
                with patch("sys.argv", argv):
                    parsed = snekbox_main.parse_args()
                    self.assertEqual(parsed, expected)

    def test_parse_args_code_missing_exits(self):
        """With no code argument, parsing exits with status 2 and reports ``code`` on stderr."""
        captured = io.StringIO()

        with patch("sys.argv", [""]):
            # stderr is redirected so the error text can be inspected afterwards.
            with self.assertRaises(SystemExit) as raised, contextlib.redirect_stderr(captured):
                snekbox_main.parse_args()

        self.assertEqual(raised.exception.code, 2)
        self.assertIn("the following arguments are required: code", captured.getvalue())


class EntrypointTests(unittest.TestCase):
    """Tests for ``snekbox_main.main`` with ``NsJail`` replaced by an autospec mock."""

    @staticmethod
    def _stub_python3(mock_nsjail, returncode=0):
        """Make the mocked ``NsJail().python3`` return a canned ``CompletedProcess``.

        Args:
            mock_nsjail: The patched ``NsJail`` class mock injected by ``patch``.
            returncode: Return code to place on the canned result.

        Returns:
            The ``python3`` mock, so callers can assert on how it was called.
        """
        python3 = mock_nsjail.return_value.python3
        python3.return_value = CompletedProcess(
            args=[],
            returncode=returncode,
            stdout="output",
            stderr=None,
        )
        return python3

    @patch("sys.argv", ["", "code"])
    @patch("snekbox.__main__.NsJail", autospec=True)
    def test_main_prints_stdout(self, mock_nsjail):
        """``main`` writes the result's stdout followed by a newline to stdout."""
        self._stub_python3(mock_nsjail)

        with contextlib.redirect_stdout(io.StringIO()) as stdout:
            snekbox_main.main()

        self.assertEqual(stdout.getvalue(), "output\n")

    @patch("sys.argv", ["", "code"])
    @patch("snekbox.__main__.NsJail", autospec=True)
    def test_main_exits_with_returncode(self, mock_nsjail):
        """``main`` raises ``SystemExit`` carrying the result's non-zero return code."""
        self._stub_python3(mock_nsjail, returncode=137)

        # Capture stdout so anything main() prints does not leak into the
        # test runner's own output; only the exit code matters here.
        with contextlib.redirect_stdout(io.StringIO()):
            with self.assertRaises(SystemExit) as cm:
                snekbox_main.main()

        self.assertEqual(cm.exception.code, 137)

    @patch("sys.argv", ["", "code", "--time_limit", "0", "---", "-m", "timeit"])
    @patch("snekbox.__main__.NsJail", autospec=True)
    def test_main_forwards_args(self, mock_nsjail):
        """``main`` passes the code, ``nsjail_args`` and ``py_args`` on to ``python3``."""
        python3 = self._stub_python3(mock_nsjail)

        # Capture stdout so anything main() prints does not leak into the
        # test runner's own output; only the forwarded call matters here.
        with contextlib.redirect_stdout(io.StringIO()):
            snekbox_main.main()

        python3.assert_called_once_with(
            "code", nsjail_args=["--time_limit", "0"], py_args=["-m", "timeit"]
        )