Explorar el Código

refac: code intepreter

Timothy Jaeryang Baek hace 2 meses
padre
commit
a273cba0fb

+ 18 - 23
backend/open_webui/utils/code_interpreter.py

@@ -18,6 +18,7 @@ async def execute_code_jupyter(
     :param password: Jupyter password (optional)
     :param timeout: WebSocket timeout in seconds (default: 10s)
     :return: Dictionary with stdout, stderr, and result
+             - Images are prefixed with "base64:image/png," and separated by newlines if multiple.
     """
     session = requests.Session()  # Maintain cookies
     headers = {}  # Headers for requests
@@ -28,20 +29,15 @@ async def execute_code_jupyter(
             login_url = urljoin(jupyter_url, "/login")
             response = session.get(login_url)
             response.raise_for_status()
-
-            # Retrieve `_xsrf` token
             xsrf_token = session.cookies.get("_xsrf")
             if not xsrf_token:
                 raise ValueError("Failed to fetch _xsrf token")
 
-            # Send login request
             login_data = {"_xsrf": xsrf_token, "password": password}
             login_response = session.post(
                 login_url, data=login_data, cookies=session.cookies
             )
             login_response.raise_for_status()
-
-            # Update headers with `_xsrf`
             headers["X-XSRFToken"] = xsrf_token
         except Exception as e:
             return {
@@ -55,18 +51,15 @@ async def execute_code_jupyter(
     kernel_url = urljoin(jupyter_url, f"/api/kernels{params}")
 
     try:
-        # Include cookies if authenticating with password
         response = session.post(kernel_url, headers=headers, cookies=session.cookies)
         response.raise_for_status()
         kernel_id = response.json()["id"]
 
-        # Construct WebSocket URL
         websocket_url = urljoin(
             jupyter_url.replace("http", "ws"),
             f"/api/kernels/{kernel_id}/channels{params}",
         )
 
-        # **IMPORTANT:** Include authentication cookies for WebSockets
         ws_headers = {}
         if password and not token:
             ws_headers["X-XSRFToken"] = session.cookies.get("_xsrf")
@@ -75,13 +68,10 @@ async def execute_code_jupyter(
                 [f"{name}={value}" for name, value in cookies.items()]
             )
 
-        # Connect to the WebSocket
         async with websockets.connect(
             websocket_url, additional_headers=ws_headers
         ) as ws:
             msg_id = str(uuid.uuid4())
-
-            # Send execution request
             execute_request = {
                 "header": {
                     "msg_id": msg_id,
@@ -105,37 +95,47 @@ async def execute_code_jupyter(
             }
             await ws.send(json.dumps(execute_request))
 
-            # Collect execution results
-            stdout, stderr, result = "", "", None
+            stdout, stderr, result = "", "", []
+
             while True:
                 try:
                     message = await asyncio.wait_for(ws.recv(), timeout)
                     message_data = json.loads(message)
                     if message_data.get("parent_header", {}).get("msg_id") == msg_id:
                         msg_type = message_data.get("msg_type")
+
                         if msg_type == "stream":
                             if message_data["content"]["name"] == "stdout":
                                 stdout += message_data["content"]["text"]
                             elif message_data["content"]["name"] == "stderr":
                                 stderr += message_data["content"]["text"]
+
                         elif msg_type in ("execute_result", "display_data"):
-                            result = message_data["content"]["data"].get(
-                                "text/plain", ""
-                            )
+                            data = message_data["content"]["data"]
+                            if "image/png" in data:
+                                result.append(
+                                    f"data:image/png;base64,{data['image/png']}"
+                                )
+                            elif "text/plain" in data:
+                                result.append(data["text/plain"])
+
                         elif msg_type == "error":
                             stderr += "\n".join(message_data["content"]["traceback"])
+
                         elif (
                             msg_type == "status"
                             and message_data["content"]["execution_state"] == "idle"
                         ):
                             break
+
                 except asyncio.TimeoutError:
                     stderr += "\nExecution timed out."
                     break
+
     except Exception as e:
         return {"stdout": "", "stderr": f"Error: {str(e)}", "result": ""}
+
     finally:
-        # Shutdown the kernel
         if kernel_id:
             requests.delete(
                 f"{kernel_url}/{kernel_id}", headers=headers, cookies=session.cookies
@@ -144,10 +144,5 @@ async def execute_code_jupyter(
     return {
         "stdout": stdout.strip(),
         "stderr": stderr.strip(),
-        "result": result.strip() if result else "",
+        "result": "\n".join(result).strip() if result else "",
     }
-
-
-# Example Usage
-# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", token="your-token"))
-# asyncio.run(execute_code_jupyter("http://localhost:8888", "print('Hello, world!')", password="your-password"))

+ 32 - 0
backend/open_webui/utils/middleware.py

@@ -1723,6 +1723,38 @@ async def process_chat_response(
                                                 )
 
                                         output["stdout"] = "\n".join(stdoutLines)
+
+                                    result = output.get("result", "")
+
+                                    if result:
+                                        resultLines = result.split("\n")
+                                        for idx, line in enumerate(resultLines):
+                                            if "data:image/png;base64" in line:
+                                                id = str(uuid4())
+
+                                                # ensure the path exists
+                                                os.makedirs(
+                                                    os.path.join(CACHE_DIR, "images"),
+                                                    exist_ok=True,
+                                                )
+
+                                                image_path = os.path.join(
+                                                    CACHE_DIR,
+                                                    f"images/{id}.png",
+                                                )
+
+                                                with open(image_path, "wb") as f:
+                                                    f.write(
+                                                        base64.b64decode(
+                                                            line.split(",")[1]
+                                                        )
+                                                    )
+
+                                                resultLines[idx] = (
+                                                    f"![Output Image {idx}](/cache/images/{id}.png)"
+                                                )
+
+                                        output["result"] = "\n".join(resultLines)
                         except Exception as e:
                             output = str(e)