123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- import asyncio
- import json
- import logging
- import uuid
- from typing import Optional
- import aiohttp
- import websockets
- from pydantic import BaseModel
- from open_webui.env import SRC_LOG_LEVELS
# Module-level logger; its verbosity is taken from the "MAIN" entry of the
# SRC_LOG_LEVELS mapping imported from open_webui.env.
logger = logging.getLogger(__name__)
logger.setLevel(SRC_LOG_LEVELS["MAIN"])
class ResultModel(BaseModel):
    """
    Execute Code Result Model

    Holds the three text outputs collected from one code execution. Every
    field defaults to an empty string.
    """

    # Text captured from the "stdout" stream.
    stdout: Optional[str] = ""
    # Text captured from the "stderr" stream, plus error/timeout messages.
    stderr: Optional[str] = ""
    # Rendered execution results (plain text or image data URIs).
    result: Optional[str] = ""
- class JupyterCodeExecuter:
- """
- Execute code in jupyter notebook
- """
- def __init__(
- self,
- base_url: str,
- code: str,
- token: str = "",
- password: str = "",
- timeout: int = 60,
- ):
- """
- :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
- :param code: Code to execute
- :param token: Jupyter authentication token (optional)
- :param password: Jupyter password (optional)
- :param timeout: WebSocket timeout in seconds (default: 60s)
- """
- self.base_url = base_url.rstrip("/")
- self.code = code
- self.token = token
- self.password = password
- self.timeout = timeout
- self.kernel_id = ""
- self.session = aiohttp.ClientSession(base_url=self.base_url)
- self.params = {}
- self.result = ResultModel()
- async def __aenter__(self):
- return self
- async def __aexit__(self, exc_type, exc_val, exc_tb):
- if self.kernel_id:
- try:
- async with self.session.delete(
- f"/api/kernels/{self.kernel_id}", params=self.params
- ) as response:
- response.raise_for_status()
- except Exception as err:
- logger.exception("close kernel failed, %s", err)
- await self.session.close()
- async def run(self) -> ResultModel:
- try:
- await self.sign_in()
- await self.init_kernel()
- await self.execute_code()
- except Exception as err:
- logger.exception("execute code failed, %s", err)
- self.result.stderr = f"Error: {err}"
- return self.result
- async def sign_in(self) -> None:
- # password authentication
- if self.password and not self.token:
- async with self.session.get("/login") as response:
- response.raise_for_status()
- xsrf_token = response.cookies["_xsrf"].value
- if not xsrf_token:
- raise ValueError("_xsrf token not found")
- self.session.cookie_jar.update_cookies(response.cookies)
- self.session.headers.update({"X-XSRFToken": xsrf_token})
- async with self.session.post(
- "/login",
- data={"_xsrf": xsrf_token, "password": self.password},
- allow_redirects=False,
- ) as response:
- response.raise_for_status()
- self.session.cookie_jar.update_cookies(response.cookies)
- # token authentication
- if self.token:
- self.params.update({"token": self.token})
- async def init_kernel(self) -> None:
- async with self.session.post(
- url="/api/kernels", params=self.params
- ) as response:
- response.raise_for_status()
- kernel_data = await response.json()
- self.kernel_id = kernel_data["id"]
- def init_ws(self) -> (str, dict):
- ws_base = self.base_url.replace("http", "ws")
- ws_params = "?" + "&".join([f"{key}={val}" for key, val in self.params.items()])
- websocket_url = f"{ws_base}/api/kernels/{self.kernel_id}/channels{ws_params if len(ws_params) > 1 else ''}"
- ws_headers = {}
- if self.password and not self.token:
- ws_headers = {
- "Cookie": "; ".join(
- [
- f"{cookie.key}={cookie.value}"
- for cookie in self.session.cookie_jar
- ]
- ),
- **self.session.headers,
- }
- return websocket_url, ws_headers
- async def execute_code(self) -> None:
- # initialize ws
- websocket_url, ws_headers = self.init_ws()
- # execute
- async with websockets.connect(
- websocket_url, additional_headers=ws_headers
- ) as ws:
- await self.execute_in_jupyter(ws)
- async def execute_in_jupyter(self, ws) -> None:
- # send message
- msg_id = uuid.uuid4().hex
- await ws.send(
- json.dumps(
- {
- "header": {
- "msg_id": msg_id,
- "msg_type": "execute_request",
- "username": "user",
- "session": uuid.uuid4().hex,
- "date": "",
- "version": "5.3",
- },
- "parent_header": {},
- "metadata": {},
- "content": {
- "code": self.code,
- "silent": False,
- "store_history": True,
- "user_expressions": {},
- "allow_stdin": False,
- "stop_on_error": True,
- },
- "channel": "shell",
- }
- )
- )
- # parse message
- stdout, stderr, result = "", "", []
- while True:
- try:
- # wait for message
- message = await asyncio.wait_for(ws.recv(), self.timeout)
- message_data = json.loads(message)
- # msg id not match, skip
- if message_data.get("parent_header", {}).get("msg_id") != msg_id:
- continue
- # check message type
- msg_type = message_data.get("msg_type")
- match msg_type:
- case "stream":
- if message_data["content"]["name"] == "stdout":
- stdout += message_data["content"]["text"]
- elif message_data["content"]["name"] == "stderr":
- stderr += message_data["content"]["text"]
- case "execute_result" | "display_data":
- data = message_data["content"]["data"]
- if "image/png" in data:
- result.append(f"data:image/png;base64,{data['image/png']}")
- elif "text/plain" in data:
- result.append(data["text/plain"])
- case "error":
- stderr += "\n".join(message_data["content"]["traceback"])
- case "status":
- if message_data["content"]["execution_state"] == "idle":
- break
- except asyncio.TimeoutError:
- stderr += "\nExecution timed out."
- break
- self.result.stdout = stdout.strip()
- self.result.stderr = stderr.strip()
- self.result.result = "\n".join(result).strip() if result else ""
async def execute_code_jupyter(
    base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
) -> dict:
    """
    Run ``code`` on a Jupyter server and return the outcome as a plain dict.

    :param base_url: Jupyter server URL
    :param code: Code to execute
    :param token: Jupyter authentication token (optional)
    :param password: Jupyter password (optional)
    :param timeout: timeout in seconds handed to the executor
    :return: the executor's result model dumped to a dict
    """
    async with JupyterCodeExecuter(
        base_url=base_url,
        code=code,
        token=token,
        password=password,
        timeout=timeout,
    ) as executor:
        outcome = await executor.run()
    return outcome.model_dump()
|