From 0c1a65570be7f34f12a35da45669676f4479abd4 Mon Sep 17 00:00:00 2001 From: Stas Medvedev Date: Wed, 12 Jun 2024 16:56:07 +0300 Subject: add utils.restricted_exec add app.tasks --- utils/__init__.py | 9 +++++ utils/restricted_exec.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 98 insertions(+) create mode 100644 utils/restricted_exec.py (limited to 'utils') diff --git a/utils/__init__.py b/utils/__init__.py index 22d7058..3f41882 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -4,6 +4,15 @@ from typing import Annotated from fastapi import Request, Header, Depends import httpx +from . import restricted_exec + +__all__ = [ + 'restricted_exec', + 'get_avatar_urls', + 'get_client_host', + 'get_client_geo' +] + def get_avatar_urls() -> list[str]: path = Path("./static") / "avatars" diff --git a/utils/restricted_exec.py b/utils/restricted_exec.py new file mode 100644 index 0000000..bf5205f --- /dev/null +++ b/utils/restricted_exec.py @@ -0,0 +1,89 @@ +import ast +import builtins +from copy import copy + +from RestrictedPython import compile_restricted_exec +from RestrictedPython.PrintCollector import PrintCollector +from RestrictedPython import safe_builtins +from RestrictedPython.Eval import ( + default_guarded_getiter, + default_guarded_getitem, + default_guarded_getattr +) +from RestrictedPython.Guards import guarded_iter_unpack_sequence + + +SAFE_MODULES = [ + 'math', + 'random', + 'math', + 'random', + 'time', + 'itertools', + 'functools', + 'operator', + 'collections', + 're', + 'json', + 'decimal', +] + +RESTRICTED_GLOBALS = { + '__builtins__': { + **safe_builtins.copy(), + '__import__': builtins.__import__ + }, + '_print_': PrintCollector, + '_getattr_': default_guarded_getattr, + '_getitem_': default_guarded_getitem, + '_getiter_': default_guarded_getiter, + '_iter_unpack_sequence_': guarded_iter_unpack_sequence, +} + + +def _restricted_exec(source, glb): + result = compile_restricted_exec(source) + assert result.errors == (), result.errors + assert result.code is not None + exec(result.code, glb) + return glb + + +def validate_code(code) -> list[str]: + err = [] + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + for node_name_obj in node.names: + if node_name_obj.name not in SAFE_MODULES: + msg = 'not allowed, %s' + if isinstance(node, ast.Import): + msg = msg % ('import %s' % node_name_obj.name) + else: + msg = msg % ('from %s import %s' % + (node.module, node_name_obj.name)) + + err.append(msg) + + return err + + +def restricted_exec(code) -> list[str]: + glb = _restricted_exec(code, copy(RESTRICTED_GLOBALS)) + + if '_print' in glb: + return glb['_print'].txt + else: + return [] + + +def getoutput(code: str) -> list[str]: + err = [] + try: + err.extend(validate_code(code)) + return restricted_exec(code) + + except Exception as error: + err.append('%s: %s' % (type(error).__name__, error)) + + return ['execution errors:', *err] -- cgit v1.2.3