Timothy J. Baek, 10 months ago
Parent · Current commit 514c7f1520
3 files changed, 30 insertions(+), 26 deletions(-)
  1. backend/main.py (+19, -6)
  2. src/lib/components/chat/Chat.svelte (+10, -20)
  3. src/lib/components/chat/MessageInput.svelte (+1, -0)

backend/main.py (+19, -6)

@@ -170,7 +170,9 @@ app.state.MODELS = {}
 origins = ["*"]
 
 
-async def get_function_call_response(messages, tool_id, template, task_model_id, user):
+async def get_function_call_response(
+    messages, files, tool_id, template, task_model_id, user
+):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
     content = tools_function_calling_generation_template(template, tools_specs)
@@ -265,6 +267,13 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
                             "__messages__": messages,
                         }
 
+                    if "__files__" in sig.parameters:
+                        # Call the function with the '__files__' parameter included
+                        params = {
+                            **params,
+                            "__files__": files,
+                        }
+
                     function_result = function(**params)
                 except Exception as e:
                     print(e)
@@ -338,6 +347,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     try:
                         response = await get_function_call_response(
                             messages=data["messages"],
+                            files=data.get("files", []),
                             tool_id=tool_id,
                             template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
                             task_model_id=task_model_id,
@@ -353,7 +363,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
                 print(f"tool_context: {context}")
 
-            # If docs field is present, generate RAG completions
+            # TODO: Check if tools & functions have files support to skip this step to delegate file processing
+            # If files field is present, generate RAG completions
             if "files" in data:
                 data = {**data}
                 rag_context, citations = get_rag_context(
@@ -376,15 +387,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 system_prompt = rag_template(
                     rag_app.state.config.RAG_TEMPLATE, context, prompt
                 )
-
                 print(system_prompt)
-
                 data["messages"] = add_or_update_system_message(
                     f"\n{system_prompt}", data["messages"]
                 )
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
-
             # Replace the request body with the modified one
             request._body = modified_body_bytes
             # Set custom header to ensure content-length matches new body length
@@ -961,7 +969,12 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
 
     try:
         context = await get_function_call_response(
-            form_data["messages"], form_data["tool_id"], template, model_id, user
+            form_data["messages"],
+            form_data.get("files", []),
+            form_data["tool_id"],
+            template,
+            model_id,
+            user,
         )
         return context
     except Exception as e:
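
For context on the backend change: get_function_call_response now inspects each tool function's signature and injects the request's uploaded files as a __files__ argument only when the tool actually declares that parameter, so existing tools keep working unchanged. The sketch below reproduces that pattern in isolation; call_tool_function and list_attached_files are illustrative names (not part of this commit), and the shape of the file dicts is an assumption.

import inspect
import json


def call_tool_function(function, messages, files, base_params):
    # Mirrors the dispatch logic in get_function_call_response: optional
    # double-underscore parameters are passed only if the tool declares them.
    sig = inspect.signature(function)
    params = dict(base_params)
    if "__messages__" in sig.parameters:
        params["__messages__"] = messages
    if "__files__" in sig.parameters:
        # New in this commit: the request body's "files" list reaches the tool.
        params["__files__"] = files
    return function(**params)


def list_attached_files(query: str, __files__: list = []) -> str:
    # Hypothetical tool function that opts in to receiving the uploaded files.
    names = [f.get("name", "unnamed") for f in __files__]
    return json.dumps({"query": query, "files": names})


print(
    call_tool_function(
        list_attached_files,
        messages=[{"role": "user", "content": "What did I upload?"}],
        files=[{"type": "file", "name": "report.pdf"}],
        base_params={"query": "uploads"},
    )
)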

src/lib/components/chat/Chat.svelte (+10, -20)

@@ -587,22 +587,17 @@
 		});
 
 		let files = [];
-
 		if (model?.info?.meta?.knowledge ?? false) {
 			files = model.info.meta.knowledge;
 		}
-
+		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
 		files = [
 			...files,
-			...messages
-				.filter((message) => message?.files ?? null)
-				.map((message) =>
-					message.files.filter((item) =>
-						['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-					)
-				)
-				.flat(1)
+			...(lastUserMessage?.files?.filter((item) =>
+				['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
+			) ?? [])
 		].filter(
+			// Remove duplicates
 			(item, index, array) =>
 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
 		);
@@ -832,22 +827,17 @@
 		const responseMessage = history.messages[responseMessageId];
 
 		let files = [];
-
 		if (model?.info?.meta?.knowledge ?? false) {
 			files = model.info.meta.knowledge;
 		}
-
+		const lastUserMessage = messages.filter((message) => message.role === 'user').at(-1);
 		files = [
 			...files,
-			...messages
-				.filter((message) => message?.files ?? null)
-				.map((message) =>
-					message.files.filter((item) =>
-						['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
-					)
-				)
-				.flat(1)
+			...(lastUserMessage?.files?.filter((item) =>
+				['doc', 'file', 'collection', 'web_search_results'].includes(item.type)
+			) ?? [])
 		].filter(
+			// Remove duplicates
 			(item, index, array) =>
 				array.findIndex((i) => JSON.stringify(i) === JSON.stringify(item)) === index
 		);

src/lib/components/chat/MessageInput.svelte (+1, -0)

@@ -153,6 +153,7 @@
 
 			if (res) {
 				fileItem.status = 'processed';
+				fileItem.collection_name = res.collection_name;
 				files = files;
 			}
 		} catch (e) {