Timothy Jaeryang Baek 2 月之前
父節點
當前提交
e628bfe6ff
共有 1 個文件被更改,包括 30 次插入9 次删除
  1. 30 9
      backend/open_webui/utils/code_interpreter.py

+ 30 - 9
backend/open_webui/utils/code_interpreter.py

@@ -7,7 +7,6 @@ from typing import Optional
 import aiohttp
 import aiohttp
 import websockets
 import websockets
 from pydantic import BaseModel
 from pydantic import BaseModel
-from websockets import ClientConnection
 
 
 from open_webui.env import SRC_LOG_LEVELS
 from open_webui.env import SRC_LOG_LEVELS
 
 
@@ -30,7 +29,14 @@ class JupyterCodeExecuter:
     Execute code in jupyter notebook
     Execute code in jupyter notebook
     """
     """
 
 
-    def __init__(self, base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60):
+    def __init__(
+        self,
+        base_url: str,
+        code: str,
+        token: str = "",
+        password: str = "",
+        timeout: int = 60,
+    ):
         """
         """
         :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
         :param base_url: Jupyter server URL (e.g., "http://localhost:8888")
         :param code: Code to execute
         :param code: Code to execute
@@ -54,7 +60,9 @@ class JupyterCodeExecuter:
     async def __aexit__(self, exc_type, exc_val, exc_tb):
     async def __aexit__(self, exc_type, exc_val, exc_tb):
         if self.kernel_id:
         if self.kernel_id:
             try:
             try:
-                async with self.session.delete(f"/api/kernels/{self.kernel_id}", params=self.params) as response:
+                async with self.session.delete(
+                    f"/api/kernels/{self.kernel_id}", params=self.params
+                ) as response:
                     response.raise_for_status()
                     response.raise_for_status()
             except Exception as err:
             except Exception as err:
                 logger.exception("close kernel failed, %s", err)
                 logger.exception("close kernel failed, %s", err)
@@ -81,7 +89,9 @@ class JupyterCodeExecuter:
                 self.session.cookie_jar.update_cookies(response.cookies)
                 self.session.cookie_jar.update_cookies(response.cookies)
                 self.session.headers.update({"X-XSRFToken": xsrf_token})
                 self.session.headers.update({"X-XSRFToken": xsrf_token})
             async with self.session.post(
             async with self.session.post(
-                "/login", data={"_xsrf": xsrf_token, "password": self.password}, allow_redirects=False
+                "/login",
+                data={"_xsrf": xsrf_token, "password": self.password},
+                allow_redirects=False,
             ) as response:
             ) as response:
                 response.raise_for_status()
                 response.raise_for_status()
                 self.session.cookie_jar.update_cookies(response.cookies)
                 self.session.cookie_jar.update_cookies(response.cookies)
@@ -91,7 +101,9 @@ class JupyterCodeExecuter:
             self.params.update({"token": self.token})
             self.params.update({"token": self.token})
 
 
     async def init_kernel(self) -> None:
     async def init_kernel(self) -> None:
-        async with self.session.post(url="/api/kernels", params=self.params) as response:
+        async with self.session.post(
+            url="/api/kernels", params=self.params
+        ) as response:
             response.raise_for_status()
             response.raise_for_status()
             kernel_data = await response.json()
             kernel_data = await response.json()
             self.kernel_id = kernel_data["id"]
             self.kernel_id = kernel_data["id"]
@@ -103,7 +115,12 @@ class JupyterCodeExecuter:
         ws_headers = {}
         ws_headers = {}
         if self.password and not self.token:
         if self.password and not self.token:
             ws_headers = {
             ws_headers = {
-                "Cookie": "; ".join([f"{cookie.key}={cookie.value}" for cookie in self.session.cookie_jar]),
+                "Cookie": "; ".join(
+                    [
+                        f"{cookie.key}={cookie.value}"
+                        for cookie in self.session.cookie_jar
+                    ]
+                ),
                 **self.session.headers,
                 **self.session.headers,
             }
             }
         return websocket_url, ws_headers
         return websocket_url, ws_headers
@@ -112,10 +129,12 @@ class JupyterCodeExecuter:
         # initialize ws
         # initialize ws
         websocket_url, ws_headers = self.init_ws()
         websocket_url, ws_headers = self.init_ws()
         # execute
         # execute
-        async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
+        async with websockets.connect(
+            websocket_url, additional_headers=ws_headers
+        ) as ws:
             await self.execute_in_jupyter(ws)
             await self.execute_in_jupyter(ws)
 
 
-    async def execute_in_jupyter(self, ws: ClientConnection) -> None:
+    async def execute_in_jupyter(self, ws) -> None:
         # send message
         # send message
         msg_id = uuid.uuid4().hex
         msg_id = uuid.uuid4().hex
         await ws.send(
         await ws.send(
@@ -184,6 +203,8 @@ class JupyterCodeExecuter:
 async def execute_code_jupyter(
 async def execute_code_jupyter(
     base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
     base_url: str, code: str, token: str = "", password: str = "", timeout: int = 60
 ) -> dict:
 ) -> dict:
-    async with JupyterCodeExecuter(base_url, code, token, password, timeout) as executor:
+    async with JupyterCodeExecuter(
+        base_url, code, token, password, timeout
+    ) as executor:
         result = await executor.run()
         result = await executor.run()
         return result.model_dump()
         return result.model_dump()