summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStas Medvedev <medvedevsa97@gmail.com>2024-06-13 02:40:17 +0300
committerStas Medvedev <medvedevsa97@gmail.com>2024-06-13 02:40:17 +0300
commit3281f17b3fc560e40d886b6fb01b9119457f62a6 (patch)
tree1617533ecc4c765766d75711a743d90dcbcf0359
parentd76981975476c561e3164f53d48eea305dd9756a (diff)
add SAFE_MODULE
-rw-r--r--app/main.py3
-rw-r--r--utils/restricted_exec.py26
2 files changed, 14 insertions, 15 deletions
diff --git a/app/main.py b/app/main.py
index b90fef9..ebc4297 100644
--- a/app/main.py
+++ b/app/main.py
@@ -58,6 +58,5 @@ async def post_restricted_exec(
result.revoke(terminate=True, signal='SIGTERM')
return [
- 'Execution timeout, task revoked',
- result.status
+ 'Execution timeout, task revoked'
]
diff --git a/utils/restricted_exec.py b/utils/restricted_exec.py
index bf5205f..d0afd08 100644
--- a/utils/restricted_exec.py
+++ b/utils/restricted_exec.py
@@ -14,6 +14,8 @@ from RestrictedPython.Guards import guarded_iter_unpack_sequence
SAFE_MODULES = [
+ 'base64',
+ 'datetime',
'math',
'random',
'math',
@@ -31,7 +33,7 @@ SAFE_MODULES = [
RESTRICTED_GLOBALS = {
'__builtins__': {
**safe_builtins.copy(),
- '__import__': builtins.__import__
+ '__import__': builtins.__import__,
},
'_print_': PrintCollector,
'_getattr_': default_guarded_getattr,
@@ -52,18 +54,15 @@ def _restricted_exec(source, glb):
def validate_code(code) -> list[str]:
err = []
tree = ast.parse(code)
+ err_template = 'not allowed, "%s"'
for node in ast.walk(tree):
- if isinstance(node, (ast.Import, ast.ImportFrom)):
- for node_name_obj in node.names:
- if node_name_obj.name not in SAFE_MODULES:
- msg = 'not allowed, %s'
- if isinstance(node, ast.Import):
- msg = msg % ('import %s' % node_name_obj.name)
- else:
- msg = msg % ('from %s import %s' %
- (node.module, node_name_obj.name))
-
- err.append(msg)
+ if isinstance(node, ast.ImportFrom):
+ if node.module not in SAFE_MODULES:
+ err.append(err_template % (
+ 'from %s import %s' % (node.module, node.names[0].name)
+ ))
+ elif isinstance(node, ast.Import):
+ err.append(err_template % ('import %s' % node.names[0].name))
return err
@@ -81,7 +80,8 @@ def getoutput(code: str) -> list[str]:
err = []
try:
err.extend(validate_code(code))
- return restricted_exec(code)
+ if not err:
+ return restricted_exec(code)
except Exception as error:
err.append('%s: %s' % (type(error).__name__, error))