Browse Source

fix(jupyter): fix kernel_id not set and optimize code

orenzhang 2 months ago
parent
commit
744ffbb1fb
1 changed file with 66 additions and 85 deletions
  1. 66 85
      backend/open_webui/utils/code_interpreter.py

+ 66 - 85
backend/open_webui/utils/code_interpreter.py

@@ -1,14 +1,15 @@
 import asyncio
 import json
 import uuid
+from typing import Optional
+
+import httpx
 import websockets
-import requests
-from urllib.parse import urljoin
 
 
 async def execute_code_jupyter(
-    jupyter_url, code, token=None, password=None, timeout=10
-):
+    jupyter_url: str, code: str, token: str = None, password: str = None, timeout: int = 60
+) -> Optional[dict]:
     """
     Executes Python code in a Jupyter kernel.
     Supports authentication with a token or password.
@@ -20,80 +21,70 @@ async def execute_code_jupyter(
     :return: Dictionary with stdout, stderr, and result
              - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
     """
-    session = requests.Session()  # Maintain cookies
-    headers = {}  # Headers for requests
 
-    # Authenticate using password
+    jupyter_url = jupyter_url.rstrip("/")
+    client = httpx.AsyncClient(base_url=jupyter_url, timeout=timeout, follow_redirects=True)
+    headers = {}
+
+    # password authentication
     if password and not token:
         try:
-            login_url = urljoin(jupyter_url, "/login")
-            response = session.get(login_url)
+            response = await client.get("/login")
             response.raise_for_status()
-            xsrf_token = session.cookies.get("_xsrf")
+            xsrf_token = response.cookies.get("_xsrf")
             if not xsrf_token:
-                raise ValueError("Failed to fetch _xsrf token")
-
-            login_data = {"_xsrf": xsrf_token, "password": password}
-            login_response = session.post(
-                login_url, data=login_data, cookies=session.cookies
-            )
-            login_response.raise_for_status()
+                raise ValueError("_xsrf token not found")
+            response = await client.post("/login", data={"_xsrf": xsrf_token, "password": password})
+            response.raise_for_status()
             headers["X-XSRFToken"] = xsrf_token
         except Exception as e:
-            return {
-                "stdout": "",
-                "stderr": f"Authentication Error: {str(e)}",
-                "result": "",
-            }
+            return {"stdout": "", "stderr": f"Authentication Error: {str(e)}", "result": ""}
 
-    # Construct API URLs with authentication token if provided
-    params = f"?token={token}" if token else ""
-    kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
+    # token authentication
+    params = {"token": token} if token else {}
 
+    kernel_id = ""
     try:
-        response = session.post(kernel_url, headers=headers, cookies=session.cookies)
+        response = await client.post(url="/api/kernels", params=params, headers=headers)
         response.raise_for_status()
         kernel_id = response.json()["id"]
 
-        websocket_url = urljoin(
-            jupyter_url.replace("http", "ws"),
-            f"/api/kernels/{kernel_id}/channels{params}",
-        )
-
+        ws_base = jupyter_url.replace("http", "ws")
+        websocket_url = f"{ws_base}/api/kernels/{kernel_id}/channels" + (f"?token={token}" if token else "")
         ws_headers = {}
         if password and not token:
-            ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
-            cookies = {name: value for name, value in session.cookies.items()}
-            ws_headers["Cookie"] = "; ".join(
-                [f"{name}={value}" for name, value in cookies.items()]
-            )
+            ws_headers = {
+                "X-XSRFToken": client.cookies.get("_xsrf"),
+                "Cookie": "; ".join([f"{name}={value}" for name, value in client.cookies.items()]),
+            }
 
-        async with websockets.connect(
-            websocket_url, additional_headers=ws_headers
-        ) as ws:
+        async with websockets.connect(websocket_url, additional_headers=ws_headers) as ws:
             msg_id = str(uuid.uuid4())
-            execute_request = {
-                "header": {
-                    "msg_id": msg_id,
-                    "msg_type": "execute_request",
-                    "username": "user",
-                    "session": str(uuid.uuid4()),
-                    "date": "",
-                    "version": "5.3",
-                },
-                "parent_header": {},
-                "metadata": {},
-                "content": {
-                    "code": code,
-                    "silent": False,
-                    "store_history": True,
-                    "user_expressions": {},
-                    "allow_stdin": False,
-                    "stop_on_error": True,
-                },
-                "channel": "shell",
-            }
-            await ws.send(json.dumps(execute_request))
+            await ws.send(
+                json.dumps(
+                    {
+                        "header": {
+                            "msg_id": msg_id,
+                            "msg_type": "execute_request",
+                            "username": "user",
+                            "session": str(uuid.uuid4()),
+                            "date": "",
+                            "version": "5.3",
+                        },
+                        "parent_header": {},
+                        "metadata": {},
+                        "content": {
+                            "code": code,
+                            "silent": False,
+                            "store_history": True,
+                            "user_expressions": {},
+                            "allow_stdin": False,
+                            "stop_on_error": True,
+                        },
+                        "channel": "shell",
+                    }
+                )
+            )
 
             stdout, stderr, result = "", "", []
 
@@ -101,32 +92,27 @@ async def execute_code_jupyter(
                 try:
                     message = await asyncio.wait_for(ws.recv(), timeout)
                     message_data = json.loads(message)
-                    if message_data.get("parent_header", {}).get("msg_id") == msg_id:
-                        msg_type = message_data.get("msg_type")
+                    if message_data.get("parent_header", {}).get("msg_id") != msg_id:
+                        continue
 
-                        if msg_type == "stream":
+                    msg_type = message_data.get("msg_type")
+                    match msg_type:
+                        case "stream":
                             if message_data["content"]["name"] == "stdout":
                                 stdout += message_data["content"]["text"]
                             elif message_data["content"]["name"] == "stderr":
                                 stderr += message_data["content"]["text"]
-
-                        elif msg_type in ("execute_result", "display_data"):
+                        case "execute_result" | "display_data":
                             data = message_data["content"]["data"]
                             if "image/png" in data:
-                                result.append(
-                                    f"data:image/png;base64,{data['image/png']}"
-                                )
+                                result.append(f"data:image/png;base64,{data['image/png']}")
                             elif "text/plain" in data:
                                 result.append(data["text/plain"])
-
-                        elif msg_type == "error":
+                        case "error":
                             stderr += "\n".join(message_data["content"]["traceback"])
-
-                        elif (
-                            msg_type == "status"
-                            and message_data["content"]["execution_state"] == "idle"
-                        ):
-                            break
+                        case "status":
+                            if message_data["content"]["execution_state"] == "idle":
+                                break
 
                 except asyncio.TimeoutError:
                     stderr += "\nExecution timed out."
@@ -137,12 +123,7 @@ async def execute_code_jupyter(
 
     finally:
         if kernel_id:
-            requests.delete(
-                f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
-            )
+            await client.delete(f"/api/kernels/{kernel_id}", headers=headers, params=params)
+        await client.aclose()
 
-    return {
-        "stdout": stdout.strip(),
-        "stderr": stderr.strip(),
-        "result": "\n".join(result).strip() if result else "",
-    }
+    return {"stdout": stdout.strip(), "stderr": stderr.strip(), "result": "\n".join(result).strip() if result else ""}