summaryrefslogtreecommitdiff
path: root/utils/restricted_exec.py
blob: bf5205f97836734ddd0ccf2b6f3a0c38c4ff9afc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import ast
import builtins
import operator
import types
from copy import copy

from RestrictedPython import compile_restricted_exec
from RestrictedPython.PrintCollector import PrintCollector
from RestrictedPython import safe_builtins
from RestrictedPython.Eval import (
    default_guarded_getiter,
    default_guarded_getitem,
    default_guarded_getattr
)
from RestrictedPython.Guards import (
    full_write_guard,
    guarded_iter_unpack_sequence,
    guarded_unpack_sequence,
)


# Top-level module names that sandboxed code is allowed to import.
#
# 'operator' is deliberately absent: operator.attrgetter / methodcaller do
# their attribute lookups in C and bypass the `_getattr_` guard, so e.g.
# attrgetter('__class__.__base__.__subclasses__') reaches dunder attributes
# that RestrictedPython otherwise forbids.
#
# NOTE(review): 'json' and 're' hold other stdlib modules as plain attributes
# (e.g. `json.codecs`); they are only safe while attribute access that yields
# a module object is blocked for sandboxed code -- confirm that guard exists.
# NOTE(review): 'time' exposes time.sleep(), so untrusted code can block the
# caller indefinitely; consider a wall-clock limit around execution.
SAFE_MODULES = [
    'math',
    'random',
    'time',
    'itertools',
    'functools',
    'collections',
    're',
    'json',
    'decimal',
]

def _guarded_import(name, glb=None, lcl=None, fromlist=(), level=0):
    """``__import__`` replacement that only admits modules in SAFE_MODULES.

    Relative imports and dotted submodule names are rejected. A
    ``from <module> import <name>`` is also rejected when ``<name>`` resolves
    to a module object (a submodule, or a module re-exported by an allowed
    module), since that would hand sandboxed code an unrestricted module.

    Raises:
        ImportError: if the requested import is not allowed.
    """
    if level != 0 or name not in SAFE_MODULES:
        raise ImportError('not allowed, import %s' % name)
    module = builtins.__import__(name, glb, lcl, fromlist, level)
    # fromlist is None for a plain `import x`; builtins.__import__ has
    # already loaded any submodules named in it, so getattr sees them.
    for attr in fromlist or ():
        if isinstance(getattr(module, attr, None), types.ModuleType):
            raise ImportError('not allowed, from %s import %s' % (name, attr))
    return module


def _guarded_getattr(obj, name):
    """``_getattr_`` guard that refuses to return module objects.

    Allowed modules can expose other modules as ordinary attributes (for
    example ``json.codecs``), which would otherwise be a path from an allowed
    import to unrestricted modules such as ``sys``.
    """
    value = default_guarded_getattr(obj, name)
    if isinstance(value, types.ModuleType):
        raise AttributeError(
            'attribute %r refers to a module and is not accessible' % name)
    return value


# RestrictedPython rewrites `name <op>= value` into
# `name = _inplacevar_('<op>=', name, value)`; map each operator string to
# the matching in-place function.
_INPLACE_OPERATORS = {
    '+=': operator.iadd,
    '-=': operator.isub,
    '*=': operator.imul,
    '/=': operator.itruediv,
    '//=': operator.ifloordiv,
    '%=': operator.imod,
    '**=': operator.ipow,
    '<<=': operator.ilshift,
    '>>=': operator.irshift,
    '&=': operator.iand,
    '|=': operator.ior,
    '^=': operator.ixor,
    '@=': operator.imatmul,
}


def _inplacevar(op, target, value):
    """Apply the augmented-assignment operator *op* to *target* and *value*."""
    return _INPLACE_OPERATORS[op](target, value)


# Globals template for sandboxed execution; callers copy it per run.
RESTRICTED_GLOBALS = {
    '__builtins__': {
        **safe_builtins,
        # Allowlisted import instead of the real builtins.__import__.
        '__import__': _guarded_import,
    },
    '_print_': PrintCollector,
    '_getattr_': _guarded_getattr,
    '_getitem_': default_guarded_getitem,
    '_getiter_': default_guarded_getiter,
    '_iter_unpack_sequence_': guarded_iter_unpack_sequence,
    # Required by RestrictedPython for `a, b = x`, `obj[k] = v` and `n += 1`;
    # without them such code fails with NameError.
    '_unpack_sequence_': guarded_unpack_sequence,
    '_write_': full_write_guard,
    '_inplacevar_': _inplacevar,
}


def _restricted_exec(source, glb):
    """Compile *source* with RestrictedPython and execute it in *glb*.

    Args:
        source: Python source text.
        glb: Globals mapping to execute in; it is mutated in place.

    Returns:
        *glb*, after execution.

    Raises:
        AssertionError: if RestrictedPython reports compile errors (the
            exception argument is the tuple of error strings) or produces
            no code object.
    """
    result = compile_restricted_exec(source)
    # Explicit raises instead of `assert`: asserts are stripped under
    # `python -O`, which would let exec(None) fail with an unrelated
    # TypeError and drop RestrictedPython's messages. AssertionError is kept
    # so the exception type and text seen by callers are unchanged.
    if result.errors:
        raise AssertionError(result.errors)
    if result.code is None:
        raise AssertionError()
    exec(result.code, glb)
    return glb


def validate_code(code: str, *, allowed_modules=None) -> list[str]:
    """Statically report imports of modules outside the allowlist.

    Args:
        code: Python source text to inspect.
        allowed_modules: Iterable of importable module names; defaults to
            ``SAFE_MODULES``.

    Returns:
        One ``'not allowed, import X'`` or ``'not allowed, from X import Y'``
        message per offending imported name; an empty list when every
        import is allowed.

    Raises:
        SyntaxError: if *code* cannot be parsed.
    """
    if allowed_modules is None:
        allowed_modules = SAFE_MODULES
    allowed = set(allowed_modules)

    err = []
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name not in allowed:
                    err.append('not allowed, import %s' % alias.name)
        elif isinstance(node, ast.ImportFrom):
            # The module being imported *from* is what must be allowlisted,
            # not the names pulled out of it. Relative imports (level > 0)
            # are never allowed; render their leading dots in the message.
            module = '.' * node.level + (node.module or '')
            if node.level or node.module not in allowed:
                for alias in node.names:
                    err.append('not allowed, from %s import %s'
                               % (module, alias.name))

    return err


def restricted_exec(code) -> list[str]:
    """Execute *code* in a fresh copy of RESTRICTED_GLOBALS; return its output.

    Printed output is read from the ``txt`` attribute of the object the run
    leaves under the global name ``_print`` (created through the ``_print_``
    PrintCollector factory). If the run left no ``_print`` global, an empty
    list is returned. Compilation and execution errors propagate.
    """
    namespace = copy(RESTRICTED_GLOBALS)
    namespace = _restricted_exec(code, namespace)
    if '_print' not in namespace:
        return []
    collector = namespace['_print']
    return collector.txt


def getoutput(code: str) -> list[str]:
    """Validate and run untrusted *code*, returning what it printed.

    Returns:
        The printed output chunks when validation passes and execution
        succeeds. Otherwise a list whose first item is ``'execution errors:'``,
        followed by one message per validation error, or by
        ``'<ExceptionType>: <message>'`` for an exception raised while
        parsing, compiling or executing the code.
    """
    err: list[str] = []
    try:
        err.extend(validate_code(code))
        # Refuse to execute code that failed validation; otherwise disallowed
        # imports would only be reported, not prevented.
        if not err:
            return restricted_exec(code)

    except Exception as error:  # top-level boundary: report, don't raise
        err.append('%s: %s' % (type(error).__name__, error))

    return ['execution errors:', *err]