|
@@ -7,7 +7,6 @@ from typing import Optional
|
|
import aiohttp
|
|
import aiohttp
|
|
import websockets
|
|
import websockets
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
-from websockets import ClientConnection
|
|
|
|
|
|
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
from open_webui.env import SRC_LOG_LEVELS
|
|
|
|
|
|
@@ -30,7 +29,14 @@ class JupyterCodeExecuter:
|
|
Execute code in jupyter notebook
|
|
Execute code in jupyter notebook
|
|
"""
|
|
"""
|
|
|
|
|
|
- def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
|
|
|
|
|
|
+ def __init__(
|
|
|
|
+ self,
|
|
|
|
+ base_url: str,
|
|
|
|
+ code: str,
|
|
|
|
+ token: str = "",
|
|
|
|
+ password: str = "",
|
|
|
|
+ timeout: int = 60,
|
|
|
|
+ ):
|
|
"""
|
|
"""
|
|
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
|
|
:param base_url: Jupyter server URL (e.g., "http://localhost:8888")
|
|
:param code: Code to execute
|
|
:param code: Code to execute
|
|
@@ -54,7 +60,9 @@ class JupyterCodeExecuter:
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
|
if self.kernel_id:
|
|
if self.kernel_id:
|
|
try:
|
|
try:
|
|
- async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response:
|
|
|
|
|
|
+ async with self.session.delete(
|
|
|
|
+ f"/api/kernels/{self.kernel_id}", params=self.params
|
|
|
|
+ ) as response:
|
|
response.raise_for_status()
|
|
response.raise_for_status()
|
|
except Exception as err:
|
|
except Exception as err:
|
|
logger.exception("close kernel failed, %s", err)
|
|
logger.exception("close kernel failed, %s", err)
|
|
@@ -81,7 +89,9 @@ class JupyterCodeExecuter:
|
|
self.session.cookie_jar.update_cookies(response.cookies)
|
|
self.session.cookie_jar.update_cookies(response.cookies)
|
|
self.session.headers.update({"X-XSRFToken": xsrf_token})
|
|
self.session.headers.update({"X-XSRFToken": xsrf_token})
|
|
async with self.session.post(
|
|
async with self.session.post(
|
|
- "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
|
|
|
|
|
|
+ "/login",
|
|
|
|
+ data={"_xsrf": xsrf_token, "password": self.password},
|
|
|
|
+ allow_redirects=False,
|
|
) as response:
|
|
) as response:
|
|
response.raise_for_status()
|
|
response.raise_for_status()
|
|
self.session.cookie_jar.update_cookies(response.cookies)
|
|
self.session.cookie_jar.update_cookies(response.cookies)
|
|
@@ -91,7 +101,9 @@ class JupyterCodeExecuter:
|
|
self.params.update({"token": self.token})
|
|
self.params.update({"token": self.token})
|
|
|
|
|
|
async def init_kernel(self) -> None:
|
|
async def init_kernel(self) -> None:
|
|
- async with self.session.post(url="/api/kernels", params=self.params) as response:
|
|
|
|
|
|
+ async with self.session.post(
|
|
|
|
+ url="/api/kernels", params=self.params
|
|
|
|
+ ) as response:
|
|
response.raise_for_status()
|
|
response.raise_for_status()
|
|
kernel_data = await response.json()
|
|
kernel_data = await response.json()
|
|
self.kernel_id = kernel_data["id"]
|
|
self.kernel_id = kernel_data["id"]
|
|
@@ -103,7 +115,12 @@ class JupyterCodeExecuter:
|
|
ws_headers = {}
|
|
ws_headers = {}
|
|
if self.password and not self.token:
|
|
if self.password and not self.token:
|
|
ws_headers = {
|
|
ws_headers = {
|
|
- "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
|
|
|
|
|
|
+ "Cookie": "; ".join(
|
|
|
|
+ [
|
|
|
|
+ f"{cookie.key}={cookie.value}"
|
|
|
|
+ for cookie in self.session.cookie_jar
|
|
|
|
+ ]
|
|
|
|
+ ),
|
|
**self.session.headers,
|
|
**self.session.headers,
|
|
}
|
|
}
|
|
return websocket_url, ws_headers
|
|
return websocket_url, ws_headers
|
|
@@ -112,10 +129,12 @@ class JupyterCodeExecuter:
|
|
# initialize ws
|
|
# initialize ws
|
|
websocket_url, ws_headers = self.init_ws()
|
|
websocket_url, ws_headers = self.init_ws()
|
|
# execute
|
|
# execute
|
|
- async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
|
|
|
|
|
|
+ async with websockets.connect(
|
|
|
|
+ websocket_url, additional_headers=ws_headers
|
|
|
|
+ ) as ws:
|
|
await self.execute_in_jupyter(ws)
|
|
await self.execute_in_jupyter(ws)
|
|
|
|
|
|
- async def execute_in_jupyter(self, ws: ClientConnection) -> None:
|
|
|
|
|
|
+ async def execute_in_jupyter(self, ws) -> None:
|
|
# send message
|
|
# send message
|
|
msg_id = uuid.uuid4().hex
|
|
msg_id = uuid.uuid4().hex
|
|
await ws.send(
|
|
await ws.send(
|
|
@@ -184,6 +203,8 @@ class JupyterCodeExecuter:
|
|
async def execute_code_jupyter(
|
|
async def execute_code_jupyter(
|
|
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
|
|
base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
|
|
) -> dict:
|
|
) -> dict:
|
|
- async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
|
|
|
|
|
|
+ async with JupyterCodeExecuter(
|
|
|
|
+ base_url, code, token, password, timeout
|
|
|
|
+ ) as executor:
|
|
result = await executor.run()
|
|
result = await executor.run()
|
|
return result.model_dump()
|
|
return result.model_dump()
|