Browse Source

refac: code interpreter

Timothy Jaeryang Baek 2 months ago
parent
commit
6ee924924e

+ 3 - 1
backend/open_webui/config.py

@@ -1335,8 +1335,10 @@ DEFAULT_CODE_INTERPRETER_PROMPT = """
    - When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.  
    - When coding, **always aim to print meaningful outputs** (e.g., results, tables, summaries, or visuals) to better interpret and verify the findings. Avoid relying on implicit outputs; prioritize explicit and clear print statements so the results are effectively communicated to the user.  
    - After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**  
    - After obtaining the printed output, **always provide a concise analysis, interpretation, or next steps to help the user understand the findings or refine the outcome further.**  
    - If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.  
    - If the results are unclear, unexpected, or require validation, refine the code and execute it again as needed. Always aim to deliver meaningful insights from the results, iterating if necessary.  
+   - If a link is provided for an image, audio, or any file, include it in the response exactly as given to ensure the user has access to the original resource.  
    - All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
    - All responses should be communicated in the chat's primary language, ensuring seamless understanding. If the chat is multilingual, default to English for clarity.
-
+   - **If a link to an image, audio, or any file is provided in markdown format, explicitly display it as part of the response to ensure the user can access it easily.**
+   
 Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
 Ensure that the tools are effectively utilized to achieve the highest-quality analysis for the user."""
 
 
 
 

+ 99 - 60
backend/open_webui/utils/middleware.py

@@ -1,6 +1,8 @@
 import time
 import time
 import logging
 import logging
 import sys
 import sys
+import os
+import base64
 
 
 import asyncio
 import asyncio
 from aiocache import cached
 from aiocache import cached
@@ -69,6 +71,7 @@ from open_webui.utils.plugin import load_function_module_by_id
 from open_webui.tasks import create_task
 from open_webui.tasks import create_task
 
 
 from open_webui.config import (
 from open_webui.config import (
+    CACHE_DIR,
     DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     DEFAULT_TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     DEFAULT_CODE_INTERPRETER_PROMPT,
     DEFAULT_CODE_INTERPRETER_PROMPT,
 )
 )
@@ -1241,7 +1244,9 @@ async def process_chat_response(
 
 
             # We might want to disable this by default
             # We might want to disable this by default
             DETECT_REASONING = True
             DETECT_REASONING = True
-            DETECT_CODE_INTERPRETER = True
+            DETECT_CODE_INTERPRETER = metadata.get("features", {}).get(
+                "code_interpreter", False
+            )
 
 
             reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
             reasoning_tags = ["think", "reason", "reasoning", "thought", "Thought"]
             code_interpreter_tags = ["code_interpreter"]
             code_interpreter_tags = ["code_interpreter"]
@@ -1386,74 +1391,108 @@ async def process_chat_response(
 
 
                 await stream_body_handler(response)
                 await stream_body_handler(response)
 
 
-                MAX_RETRIES = 5
-                retries = 0
+                if DETECT_CODE_INTERPRETER:
+                    MAX_RETRIES = 5
+                    retries = 0
 
 
-                while (
-                    content_blocks[-1]["type"] == "code_interpreter"
-                    and retries < MAX_RETRIES
-                ):
-                    retries += 1
-                    log.debug(f"Attempt count: {retries}")
+                    while (
+                        content_blocks[-1]["type"] == "code_interpreter"
+                        and retries < MAX_RETRIES
+                    ):
+                        retries += 1
+                        log.debug(f"Attempt count: {retries}")
 
 
-                    output = ""
-                    try:
-                        if content_blocks[-1]["attributes"].get("type") == "code":
-                            output = await event_caller(
-                                {
-                                    "type": "execute:python",
-                                    "data": {
-                                        "id": str(uuid4()),
-                                        "code": content_blocks[-1]["content"],
-                                    },
-                                }
-                            )
-                    except Exception as e:
-                        output = str(e)
+                        output = ""
+                        try:
+                            if content_blocks[-1]["attributes"].get("type") == "code":
+                                output = await event_caller(
+                                    {
+                                        "type": "execute:python",
+                                        "data": {
+                                            "id": str(uuid4()),
+                                            "code": content_blocks[-1]["content"],
+                                        },
+                                    }
+                                )
 
 
-                    content_blocks[-1]["output"] = output
-                    content_blocks.append(
-                        {
-                            "type": "text",
-                            "content": "",
-                        }
-                    )
+                                if isinstance(output, dict):
+                                    stdout = output.get("stdout", "")
+
+                                    if stdout:
+                                        stdoutLines = stdout.split("\n")
+                                        for idx, line in enumerate(stdoutLines):
+                                            if "data:image/png;base64" in line:
+                                                id = str(uuid4())
+
+                                                # ensure the path exists
+                                                os.makedirs(
+                                                    os.path.join(CACHE_DIR, "images"),
+                                                    exist_ok=True,
+                                                )
+
+                                                image_path = os.path.join(
+                                                    CACHE_DIR,
+                                                    f"images/{id}.png",
+                                                )
+
+                                                with open(image_path, "wb") as f:
+                                                    f.write(
+                                                        base64.b64decode(
+                                                            line.split(",")[1]
+                                                        )
+                                                    )
+
+                                                stdoutLines[idx] = (
+                                                    f"![Output Image {idx}](/cache/images/{id}.png)"
+                                                )
+
+                                        output["stdout"] = "\n".join(stdoutLines)
+                        except Exception as e:
+                            output = str(e)
 
 
-                    await event_emitter(
-                        {
-                            "type": "chat:completion",
-                            "data": {
-                                "content": serialize_content_blocks(content_blocks),
-                            },
-                        }
-                    )
+                        content_blocks[-1]["output"] = output
+                        content_blocks.append(
+                            {
+                                "type": "text",
+                                "content": "",
+                            }
+                        )
 
 
-                    try:
-                        res = await generate_chat_completion(
-                            request,
+                        await event_emitter(
                             {
                             {
-                                "model": model_id,
-                                "stream": True,
-                                "messages": [
-                                    *form_data["messages"],
-                                    {
-                                        "role": "assistant",
-                                        "content": serialize_content_blocks(
-                                            content_blocks, raw=True
-                                        ),
-                                    },
-                                ],
-                            },
-                            user,
+                                "type": "chat:completion",
+                                "data": {
+                                    "content": serialize_content_blocks(content_blocks),
+                                },
+                            }
                         )
                         )
 
 
-                        if isinstance(res, StreamingResponse):
-                            await stream_body_handler(res)
-                        else:
+                        try:
+                            res = await generate_chat_completion(
+                                request,
+                                {
+                                    "model": model_id,
+                                    "stream": True,
+                                    "messages": [
+                                        *form_data["messages"],
+                                        {
+                                            "role": "assistant",
+                                            "content": serialize_content_blocks(
+                                                content_blocks, raw=True
+                                            ),
+                                        },
+                                    ],
+                                },
+                                user,
+                            )
+
+                            if isinstance(res, StreamingResponse):
+                                await stream_body_handler(res)
+                            else:
+                                break
+                        except Exception as e:
+                            log.debug(e)
                             break
                             break
-                    except Exception as e:
-                        log.debug(e)
-                        break
 
 
                 title = Chats.get_chat_title_by_id(metadata["chat_id"])
                 title = Chats.get_chat_title_by_id(metadata["chat_id"])
                 data = {
                 data = {

+ 42 - 5
src/lib/components/chat/Messages/CodeBlock.svelte

@@ -50,6 +50,7 @@
 	let stdout = null;
 	let stdout = null;
 	let stderr = null;
 	let stderr = null;
 	let result = null;
 	let result = null;
+	let files = null;
 
 
 	let copied = false;
 	let copied = false;
 	let saved = false;
 	let saved = false;
@@ -110,7 +111,7 @@
 	};
 	};
 
 
 	const executePython = async (code) => {
 	const executePython = async (code) => {
-		if (!code.includes('input') && !code.includes('matplotlib')) {
+		if (!code.includes('input')) {
 			executePythonAsWorker(code);
 			executePythonAsWorker(code);
 		} else {
 		} else {
 			result = null;
 			result = null;
@@ -211,7 +212,8 @@ __builtins__.input = input`);
 			code.includes('re') ? 'regex' : null,
 			code.includes('re') ? 'regex' : null,
 			code.includes('seaborn') ? 'seaborn' : null,
 			code.includes('seaborn') ? 'seaborn' : null,
 			code.includes('sympy') ? 'sympy' : null,
 			code.includes('sympy') ? 'sympy' : null,
-			code.includes('tiktoken') ? 'tiktoken' : null
+			code.includes('tiktoken') ? 'tiktoken' : null,
+			code.includes('matplotlib') ? 'matplotlib' : null
 		].filter(Boolean);
 		].filter(Boolean);
 
 
 		console.log(packages);
 		console.log(packages);
@@ -238,7 +240,31 @@ __builtins__.input = input`);
 
 
 			console.log(id, data);
 			console.log(id, data);
 
 
-			data['stdout'] && (stdout = data['stdout']);
+			if (data['stdout']) {
+				stdout = data['stdout'];
+				const stdoutLines = stdout.split('\n');
+
+				for (const [idx, line] of stdoutLines.entries()) {
+					if (line.startsWith('data:image/png;base64')) {
+						if (files) {
+							files.push({
+								type: 'image/png',
+								data: line
+							});
+						} else {
+							files = [
+								{
+									type: 'image/png',
+									data: line
+								}
+							];
+						}
+
+						stdout = stdout.replace(`${line}\n`, ``);
+					}
+				}
+			}
+
 			data['stderr'] && (stderr = data['stderr']);
 			data['stderr'] && (stderr = data['stderr']);
 			data['result'] && (result = data['result']);
 			data['result'] && (result = data['result']);
 
 
@@ -430,10 +456,21 @@ __builtins__.input = input`);
 								<div class="text-sm">{stdout || stderr}</div>
 								<div class="text-sm">{stdout || stderr}</div>
 							</div>
 							</div>
 						{/if}
 						{/if}
-						{#if result}
+						{#if result || files}
 							<div class=" ">
 							<div class=" ">
 								<div class=" text-gray-500 text-xs mb-1">RESULT</div>
 								<div class=" text-gray-500 text-xs mb-1">RESULT</div>
-								<div class="text-sm">{`${JSON.stringify(result)}`}</div>
+								{#if result}
+									<div class="text-sm">{`${JSON.stringify(result)}`}</div>
+								{/if}
+								{#if files}
+									<div class="flex flex-col gap-2">
+										{#each files as file}
+											{#if file.type.startsWith('image')}
+												<img src={file.data} alt="Output" />
+											{/if}
+										{/each}
+									</div>
+								{/if}
 							</div>
 							</div>
 						{/if}
 						{/if}
 					{/if}
 					{/if}

+ 29 - 0
src/lib/workers/pyodide.worker.ts

@@ -77,6 +77,35 @@ self.onmessage = async (event) => {
 	await loadPyodideAndPackages(self.packages);
 	await loadPyodideAndPackages(self.packages);
 
 
 	try {
 	try {
+		// check if matplotlib is imported in the code
+		if (code.includes('matplotlib')) {
+			// Override plt.show() to return base64 image
+			await self.pyodide.runPythonAsync(`import base64
+import os
+from io import BytesIO
+
+# before importing matplotlib
+# to avoid the wasm backend (which needs js.document', not available in worker)
+os.environ["MPLBACKEND"] = "AGG"
+
+import matplotlib.pyplot
+
+_old_show = matplotlib.pyplot.show
+assert _old_show, "matplotlib.pyplot.show"
+
+def show(*, block=None):
+	buf = BytesIO()
+	matplotlib.pyplot.savefig(buf, format="png")
+	buf.seek(0)
+	# encode to a base64 str
+	img_str = base64.b64encode(buf.read()).decode('utf-8')
+	matplotlib.pyplot.clf()
+	buf.close()
+	print(f"data:image/png;base64,{img_str}")
+
+matplotlib.pyplot.show = show`);
+		}
+
 		self.result = await self.pyodide.runPythonAsync(code);
 		self.result = await self.pyodide.runPythonAsync(code);
 
 
 		// Safely process and recursively serialize the result
 		// Safely process and recursively serialize the result