summaryrefslogtreecommitdiff
path: root/utils/restricted_exec.py
blob: d0afd0844925e6228d49f491c77ec1b31f736ebe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import ast
import builtins
from copy import copy

from RestrictedPython import compile_restricted_exec
from RestrictedPython.PrintCollector import PrintCollector
from RestrictedPython import safe_builtins
from RestrictedPython.Eval import (
    default_guarded_getiter,
    default_guarded_getitem,
    default_guarded_getattr
)
from RestrictedPython.Guards import guarded_iter_unpack_sequence


# Allow-list of modules that restricted code may import names from; it is
# consulted by validate_code for ``from <module> import ...`` statements.
# Each module is listed exactly once.
SAFE_MODULES = [
    'base64',
    'datetime',
    'math',
    'random',
    'time',
    'itertools',
    'functools',
    'operator',
    'collections',
    're',
    'json',
    'decimal',
]

# Template globals for restricted execution: a builtins table derived from
# ``safe_builtins`` plus the RestrictedPython guard hooks.
# NOTE(review): ``__import__`` here is the unrestricted builtin, so the only
# module allow-list check visible in this file is validate_code's AST scan;
# code that reaches restricted_exec without going through it is not checked.
RESTRICTED_GLOBALS = dict(
    __builtins__=dict(safe_builtins, __import__=builtins.__import__),
    _print_=PrintCollector,
    _getattr_=default_guarded_getattr,
    _getitem_=default_guarded_getitem,
    _getiter_=default_guarded_getiter,
    _iter_unpack_sequence_=guarded_iter_unpack_sequence,
)


def _restricted_exec(source, glb):
    """Compile ``source`` with RestrictedPython and execute it in ``glb``.

    Args:
        source: Python source text to compile in restricted mode.
        glb: Globals mapping the compiled code is executed in.

    Returns:
        The same ``glb`` object that was passed in, after execution.

    Raises:
        AssertionError: If compilation reported errors (the errors tuple is
            the exception argument) or produced no code object.
    """
    result = compile_restricted_exec(source)
    # Explicit raises instead of ``assert``: assert statements are removed
    # under ``python -O``, which would silently drop these checks.
    # AssertionError is kept so callers see the same exception type and
    # message as before.
    if result.errors:
        raise AssertionError(result.errors)
    if result.code is None:
        raise AssertionError
    exec(result.code, glb)
    return glb


def validate_code(code) -> list[str]:
    """Scan ``code`` for import statements that restricted code may not use.

    ``from <module> import ...`` is accepted only for absolute imports whose
    module is listed in SAFE_MODULES; every plain ``import ...`` statement
    is rejected.

    Args:
        code: Python source text; parsed with ``ast.parse``.

    Returns:
        One ``'not allowed, "<statement>"'`` message per rejected import
        statement, naming every imported name, or an empty list when nothing
        was rejected.

    Raises:
        SyntaxError: If ``ast.parse`` cannot parse ``code``.
    """
    err = []
    err_template = 'not allowed, "%s"'
    for node in ast.walk(ast.parse(code)):
        if isinstance(node, ast.ImportFrom):
            # A relative import (level > 0) never refers to an allow-listed
            # top-level module, even when its name matches one, so reject it
            # before consulting the allow-list.
            if node.level or node.module not in SAFE_MODULES:
                origin = '.' * node.level + (node.module or '')
                names = ', '.join(alias.name for alias in node.names)
                err.append(err_template % (
                    'from %s import %s' % (origin, names)
                ))
        elif isinstance(node, ast.Import):
            names = ', '.join(alias.name for alias in node.names)
            err.append(err_template % ('import %s' % names))

    return err


def restricted_exec(code) -> list[str]:
    """Run ``code`` under the restricted globals and return its print output.

    Each call executes against a fresh shallow copy of RESTRICTED_GLOBALS.

    Returns:
        The ``txt`` attribute of the ``_print`` entry found in the globals
        after execution (presumably the print collector's captured output),
        or an empty list when there is no ``_print`` entry.
    """
    env = _restricted_exec(code, copy(RESTRICTED_GLOBALS))
    try:
        collector = env['_print']
    except KeyError:
        # No ``_print`` entry was left in the globals.
        return []
    return collector.txt


def getoutput(code: str) -> list[str]:
    """Validate and run ``code``, returning its output or an error report.

    Returns:
        The result of restricted_exec when validation finds no problems and
        execution succeeds. Otherwise a list starting with
        ``'execution errors:'`` followed by either the validation messages
        or a single ``'<ExceptionName>: <message>'`` entry for an exception
        raised during validation or execution.
    """
    try:
        problems = validate_code(code)
        if not problems:
            return restricted_exec(code)
    except Exception as exc:
        # Top-level boundary: any failure is reported as text, not raised.
        problems = ['%s: %s' % (type(exc).__name__, exc)]

    return ['execution errors:', *problems]