From 0c1a65570be7f34f12a35da45669676f4479abd4 Mon Sep 17 00:00:00 2001 From: Stas Medvedev Date: Wed, 12 Jun 2024 16:56:07 +0300 Subject: add utils.restricted_exec add app.tasks --- app/main.py | 28 +++++++++++++++ app/tasks.py | 17 +++++++++ utils/__init__.py | 9 +++++ utils/restricted_exec.py | 89 ++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 143 insertions(+) create mode 100644 app/tasks.py create mode 100644 utils/restricted_exec.py diff --git a/app/main.py b/app/main.py index 59e7a8b..b90fef9 100644 --- a/app/main.py +++ b/app/main.py @@ -1,11 +1,15 @@ from typing import Annotated +import asyncio +from datetime import datetime, timedelta from fastapi import FastAPI, Request, Depends from fastapi.responses import HTMLResponse from fastapi.staticfiles import StaticFiles from starlette.templating import Jinja2Templates +from pydantic import BaseModel from utils import get_avatar_urls, get_client_geo +from app.tasks import restricted_exec_task templates = Jinja2Templates(directory="templates") @@ -33,3 +37,27 @@ async def client_addr( "partials/client_geo.html", {"request": request, "client_geo": client_geo}, ) + + +class RestrictedExecBase(BaseModel): + code: str + + +@app.post('/restricted_exec') +async def post_restricted_exec( + body: RestrictedExecBase +): + result = restricted_exec_task.delay(body.code) + + start_time = datetime.now() + while datetime.now() - start_time < timedelta(seconds=5): + if result.ready(): + return result.get() + + await asyncio.sleep(0.1) + + result.revoke(terminate=True, signal='SIGTERM') + return [ + 'Execution timeout, task revoked', + result.status + ] diff --git a/app/tasks.py b/app/tasks.py new file mode 100644 index 0000000..bba3fae --- /dev/null +++ b/app/tasks.py @@ -0,0 +1,17 @@ +from celery import Celery + +from utils import restricted_exec + +app = Celery( + 'tasks', + broker='redis://localhost:6379/0', + backend='redis://localhost:6379/0', + task_send_sent_event=True, + worker_send_task_events=True, + worker_enable_remote_control=True, +) + + +@app.task +def restricted_exec_task(code: str): + return restricted_exec.getoutput(code) diff --git a/utils/__init__.py b/utils/__init__.py index 22d7058..3f41882 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -4,6 +4,15 @@ from typing import Annotated from fastapi import Request, Header, Depends import httpx +from . import restricted_exec + +__all__ = [ + 'restricted_exec', + 'get_avatar_urls', + 'get_client_host', + 'get_client_geo' +] + def get_avatar_urls() -> list[str]: path = Path("./static") / "avatars" diff --git a/utils/restricted_exec.py b/utils/restricted_exec.py new file mode 100644 index 0000000..bf5205f --- /dev/null +++ b/utils/restricted_exec.py @@ -0,0 +1,89 @@ +import ast +import builtins +from copy import copy + +from RestrictedPython import compile_restricted_exec +from RestrictedPython.PrintCollector import PrintCollector +from RestrictedPython import safe_builtins +from RestrictedPython.Eval import ( + default_guarded_getiter, + default_guarded_getitem, + default_guarded_getattr +) +from RestrictedPython.Guards import guarded_iter_unpack_sequence + + +SAFE_MODULES = [ + 'math', + 'random', + 'math', + 'random', + 'time', + 'itertools', + 'functools', + 'operator', + 'collections', + 're', + 'json', + 'decimal', +] + +RESTRICTED_GLOBALS = { + '__builtins__': { + **safe_builtins.copy(), + '__import__': builtins.__import__ + }, + '_print_': PrintCollector, + '_getattr_': default_guarded_getattr, + '_getitem_': default_guarded_getitem, + '_getiter_': default_guarded_getiter, + '_iter_unpack_sequence_': guarded_iter_unpack_sequence, +} + + +def _restricted_exec(source, glb): + result = compile_restricted_exec(source) + assert result.errors == (), result.errors + assert result.code is not None + exec(result.code, glb) + return glb + + +def validate_code(code) -> list[str]: + err = [] + tree = ast.parse(code) + for node in ast.walk(tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + for node_name_obj in node.names: + if node_name_obj.name not in SAFE_MODULES: + msg = 'not allowed, %s' + if isinstance(node, ast.Import): + msg = msg % ('import %s' % node_name_obj.name) + else: + msg = msg % ('from %s import %s' % + (node.module, node_name_obj.name)) + + err.append(msg) + + return err + + +def restricted_exec(code) -> list[str]: + glb = _restricted_exec(code, copy(RESTRICTED_GLOBALS)) + + if '_print' in glb: + return glb['_print'].txt + else: + return [] + + +def getoutput(code: str) -> list[str]: + err = [] + try: + err.extend(validate_code(code)) + return restricted_exec(code) + + except Exception as error: + err.append('%s: %s' % (type(error).__name__, error)) + + return ['execution errors:', *err] -- cgit v1.2.3