Kaynağa Gözat

feat: tools full integration

Timothy J. Baek 10 ay önce
ebeveyn
işleme
3d6f5f418d

+ 51 - 39
backend/main.py

@@ -185,39 +185,48 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
     model = app.state.MODELS[task_model_id]
 
     response = None
-    if model["owned_by"] == "ollama":
-        response = await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**payload), user=user
-        )
-    else:
-        response = await generate_openai_chat_completion(payload, user=user)
-
-    print(response)
-    content = response["choices"][0]["message"]["content"]
-
-    # Parse the function response
-    if content != "":
-        result = json.loads(content)
-        print(result)
-
-        # Call the function
-        if "name" in result:
-            if tool_id in webui_app.state.TOOLS:
-                toolkit_module = webui_app.state.TOOLS[tool_id]
-            else:
-                toolkit_module = load_toolkit_module_by_id(tool_id)
-                webui_app.state.TOOLS[tool_id] = toolkit_module
-
-            function = getattr(toolkit_module, result["name"])
-            function_result = None
-            try:
-                function_result = function(**result["parameters"])
-            except Exception as e:
-                print(e)
+    try:
+        if model["owned_by"] == "ollama":
+            response = await generate_ollama_chat_completion(
+                OpenAIChatCompletionForm(**payload), user=user
+            )
+        else:
+            response = await generate_openai_chat_completion(payload, user=user)
+
+        content = None
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+
+        # Parse the function response
+        if content is not None:
+            result = json.loads(content)
+            print(result)
+
+            # Call the function
+            if "name" in result:
+                if tool_id in webui_app.state.TOOLS:
+                    toolkit_module = webui_app.state.TOOLS[tool_id]
+                else:
+                    toolkit_module = load_toolkit_module_by_id(tool_id)
+                    webui_app.state.TOOLS[tool_id] = toolkit_module
+
+                function = getattr(toolkit_module, result["name"])
+                function_result = None
+                try:
+                    function_result = function(**result["parameters"])
+                except Exception as e:
+                    print(e)
 
-            # Add the function result to the system prompt
-            if function_result:
-                return function_result
+                # Add the function result to the system prompt
+                if function_result:
+                    return function_result
+    except Exception as e:
+        print(f"Error: {e}")
 
     return None
 
@@ -285,15 +294,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     print(response)
 
                     if response:
-                        context += f"\n{response}"
+                        context = ("\n" if context != "" else "") + response
 
-                system_prompt = rag_template(
-                    rag_app.state.config.RAG_TEMPLATE, context, prompt
-                )
+                if context != "":
+                    system_prompt = rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context, prompt
+                    )
 
-                data["messages"] = add_or_update_system_message(
-                    system_prompt, data["messages"]
-                )
+                    print(system_prompt)
+
+                    data["messages"] = add_or_update_system_message(
+                       f"\n{system_prompt}", data["messages"]
+                    )
 
                 del data["tool_ids"]
 

+ 4 - 0
src/lib/components/chat/Chat.svelte

@@ -73,6 +73,7 @@
 	let selectedModels = [''];
 	let atSelectedModel: Model | undefined;
 
+	let selectedToolIds = [];
 	let webSearchEnabled = false;
 
 	let chat = null;
@@ -687,6 +688,7 @@
 			},
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
+			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			docs: docs.length > 0 ? docs : undefined,
 			citations: docs.length > 0,
 			chat_id: $chatId
@@ -948,6 +950,7 @@
 					top_p: $settings?.params?.top_p ?? undefined,
 					frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
+					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					docs: docs.length > 0 ? docs : undefined,
 					citations: docs.length > 0,
 					chat_id: $chatId
@@ -1274,6 +1277,7 @@
 				bind:files
 				bind:prompt
 				bind:autoScroll
+				bind:selectedToolIds
 				bind:webSearchEnabled
 				bind:atSelectedModel
 				{selectedModels}

+ 12 - 1
src/lib/components/chat/MessageInput.svelte

@@ -8,7 +8,8 @@
 		showSidebar,
 		models,
 		config,
-		showCallOverlay
+		showCallOverlay,
+		tools
 	} from '$lib/stores';
 	import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
 
@@ -57,6 +58,7 @@
 	let chatInputPlaceholder = '';
 
 	export let files = [];
+	export let selectedToolIds = [];
 
 	export let webSearchEnabled = false;
 
@@ -653,6 +655,15 @@
 								<div class=" ml-0.5 self-end mb-1.5 flex space-x-1">
 									<InputMenu
 										bind:webSearchEnabled
+										bind:selectedToolIds
+										tools={$tools.reduce((a, e, i, arr) => {
+											a[e.id] = {
+												name: e.name,
+												enabled: false
+											};
+
+											return a;
+										}, {})}
 										uploadFilesHandler={() => {
 											filesInputElement.click();
 										}}

+ 12 - 3
src/lib/components/chat/MessageInput/InputMenu.svelte

@@ -14,6 +14,8 @@
 	const i18n = getContext('i18n');
 
 	export let uploadFilesHandler: Function;
+
+	export let selectedToolIds: string[] = [];
 	export let webSearchEnabled: boolean;
 
 	export let tools = {};
@@ -44,16 +46,23 @@
 			transition={flyAndScale}
 		>
 			{#if Object.keys(tools).length > 0}
-				{#each Object.keys(tools) as tool}
+				{#each Object.keys(tools) as toolId}
 					<div
 						class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
 					>
 						<div class="flex-1 flex items-center gap-2">
 							<WrenchSolid />
-							<div class="flex items-center">{tool}</div>
+							<div class="flex items-center">{tools[toolId].name}</div>
 						</div>
 
-						<Switch bind:state={tools[tool]} />
+						<Switch
+							bind:state={tools[toolId].enabled}
+							on:change={(e) => {
+								selectedToolIds = e.detail
+									? [...selectedToolIds, toolId]
+									: selectedToolIds.filter((id) => id !== toolId);
+							}}
+						/>
 					</div>
 				{/each}
 				<hr class="border-gray-100 dark:border-gray-800 my-1" />