summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorStas Medvedev <medvedevsa97@gmail.com>2024-06-12 16:56:07 +0300
committerStas Medvedev <medvedevsa97@gmail.com>2024-06-12 16:56:07 +0300
commit0c1a65570be7f34f12a35da45669676f4479abd4 (patch)
treeeb0905e4a6454d154b679830666e31f066a9fea7
parented49bb17b9e93a1406ab51f7dca5906661863627 (diff)
add utils.restricted_exec
add app.tasks
-rw-r--r--app/main.py28
-rw-r--r--app/tasks.py17
-rw-r--r--utils/__init__.py9
-rw-r--r--utils/restricted_exec.py89
4 files changed, 143 insertions, 0 deletions
diff --git a/app/main.py b/app/main.py
index 59e7a8b..b90fef9 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,11 +1,15 @@
from typing import Annotated
+import asyncio
+from datetime import datetime, timedelta
from fastapi import FastAPI, Request, Depends
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from starlette.templating import Jinja2Templates
+from pydantic import BaseModel
from utils import get_avatar_urls, get_client_geo
+from app.tasks import restricted_exec_task
templates = Jinja2Templates(directory="templates")
@@ -33,3 +37,27 @@ async def client_addr(
"partials/client_geo.html",
{"request": request, "client_geo": client_geo},
)
+
+
+class RestrictedExecBase(BaseModel):
+ code: str
+
+
+@app.post('/restricted_exec')
+async def post_restricted_exec(
+ body: RestrictedExecBase
+):
+ result = restricted_exec_task.delay(body.code)
+
+ start_time = datetime.now()
+ while datetime.now() - start_time < timedelta(seconds=5):
+ if result.ready():
+ return result.get()
+
+ await asyncio.sleep(0.1)
+
+ result.revoke(terminate=True, signal='SIGTERM')
+ return [
+ 'Execution timeout, task revoked',
+ result.status
+ ]
diff --git a/app/tasks.py b/app/tasks.py
new file mode 100644
index 0000000..bba3fae
--- /dev/null
+++ b/app/tasks.py
@@ -0,0 +1,17 @@
+from celery import Celery
+
+from utils import restricted_exec
+
+app = Celery(
+ 'tasks',
+ broker='redis://localhost:6379/0',
+ backend='redis://localhost:6379/0',
+ task_send_sent_event=True,
+ worker_send_task_events=True,
+ worker_enable_remote_control=True,
+)
+
+
+@app.task
+def restricted_exec_task(code: str):
+ return restricted_exec.getoutput(code)
diff --git a/utils/__init__.py b/utils/__init__.py
index 22d7058..3f41882 100644
--- a/utils/__init__.py
+++ b/utils/__init__.py
@@ -4,6 +4,15 @@ from typing import Annotated
from fastapi import Request, Header, Depends
import httpx
+from . import restricted_exec
+
+__all__ = [
+ 'restricted_exec',
+ 'get_avatar_urls',
+ 'get_client_host',
+ 'get_client_geo'
+]
+
def get_avatar_urls() -> list[str]:
path = Path("./static") / "avatars"
diff --git a/utils/restricted_exec.py b/utils/restricted_exec.py
new file mode 100644
index 0000000..bf5205f
--- /dev/null
+++ b/utils/restricted_exec.py
@@ -0,0 +1,89 @@
+import ast
+import builtins
+from copy import copy
+
+from RestrictedPython import compile_restricted_exec
+from RestrictedPython.PrintCollector import PrintCollector
+from RestrictedPython import safe_builtins
+from RestrictedPython.Eval import (
+ default_guarded_getiter,
+ default_guarded_getitem,
+ default_guarded_getattr
+)
+from RestrictedPython.Guards import guarded_iter_unpack_sequence
+
+
+SAFE_MODULES = [
+ 'math',
+ 'random',
+ 'math',
+ 'random',
+ 'time',
+ 'itertools',
+ 'functools',
+ 'operator',
+ 'collections',
+ 're',
+ 'json',
+ 'decimal',
+]
+
+RESTRICTED_GLOBALS = {
+ '__builtins__': {
+ **safe_builtins.copy(),
+ '__import__': builtins.__import__
+ },
+ '_print_': PrintCollector,
+ '_getattr_': default_guarded_getattr,
+ '_getitem_': default_guarded_getitem,
+ '_getiter_': default_guarded_getiter,
+ '_iter_unpack_sequence_': guarded_iter_unpack_sequence,
+}
+
+
+def _restricted_exec(source, glb):
+ result = compile_restricted_exec(source)
+ assert result.errors == (), result.errors
+ assert result.code is not None
+ exec(result.code, glb)
+ return glb
+
+
+def validate_code(code) -> list[str]:
+ err = []
+ tree = ast.parse(code)
+ for node in ast.walk(tree):
+ if isinstance(node, (ast.Import, ast.ImportFrom)):
+ for node_name_obj in node.names:
+ if node_name_obj.name not in SAFE_MODULES:
+ msg = 'not allowed, %s'
+ if isinstance(node, ast.Import):
+ msg = msg % ('import %s' % node_name_obj.name)
+ else:
+ msg = msg % ('from %s import %s' %
+ (node.module, node_name_obj.name))
+
+ err.append(msg)
+
+ return err
+
+
+def restricted_exec(code) -> list[str]:
+ glb = _restricted_exec(code, copy(RESTRICTED_GLOBALS))
+
+ if '_print' in glb:
+ return glb['_print'].txt
+ else:
+ return []
+
+
+def getoutput(code: str) -> list[str]:
+ err = []
+ try:
+ err.extend(validate_code(code))
+ return restricted_exec(code)
+
+ except Exception as error:
+ err.append('%s: %s' % (type(error).__name__, error))
+
+ return ['execution errors:', *err]