Przeglądaj źródła

put tool_ids and files in metadata

Michael Poluektov 8 miesięcy temu
rodzic
commit
2e3146263c

+ 3 - 5
backend/apps/ollama/main.py

@@ -731,12 +731,10 @@ async def generate_chat_completion(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    log.debug(f"{form_data.model_dump_json(exclude_none=True).encode()}=")
-
     payload = {**form_data.model_dump(exclude_none=True)}
-    for key in ["metadata", "files", "tool_ids"]:
-        if key in payload:
-            del payload[key]
+    log.debug(f"{payload = }")
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)

+ 4 - 2
backend/apps/webui/main.py

@@ -273,10 +273,12 @@ def get_function_params(function_module, form_data, user, extra_params={}):
     return params
 
 
-async def generate_function_chat_completion(form_data, user, files, tool_ids):
+async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
-    metadata = form_data.pop("metadata", None)
+    metadata = form_data.pop("metadata", {})
+    files = metadata.get("files", [])
+    tool_ids = metadata.get("tool_ids", [])
 
     __event_emitter__ = None
     __event_call__ = None

+ 9 - 11
backend/main.py

@@ -326,8 +326,8 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
             print(f"Error: {e}")
             raise e
 
-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]
 
     return body, {}
 
@@ -371,7 +371,8 @@ async def chat_completion_tools_handler(
     body: dict, user: UserModel, extra_params: dict
 ) -> tuple[dict, dict]:
     # If tool_ids field is present, call the functions
-    tool_ids = body.get("tool_ids", None)
+    metadata = body.get("metadata", {})
+    tool_ids = metadata.get("tool_ids", None)
     if not tool_ids:
         return body, {}
 
@@ -387,7 +388,7 @@ async def chat_completion_tools_handler(
         **extra_params,
         "__model__": app.state.MODELS[task_model_id],
         "__messages__": body["messages"],
-        "__files__": body.get("files", []),
+        "__files__": metadata.get("files", []),
     }
     tools = get_tools(webui_app, tool_ids, user, custom_params)
     log.info(f"{tools=}")
@@ -454,8 +455,8 @@ async def chat_completion_tools_handler(
 
     log.debug(f"tool_contexts: {contexts}")
 
-    if skip_files and "files" in body:
-        del body["files"]
+    if skip_files and "files" in body.get("metadata", {}):
+        del body["metadata"]["files"]
 
     return body, {"contexts": contexts, "citations": citations}
 
@@ -464,7 +465,7 @@ async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
     contexts = []
     citations = []
 
-    if files := body.get("files", None):
+    if files := body.get("metadata", {}).get("files", None):
         contexts, citations = get_rag_context(
             files=files,
             messages=body["messages"],
@@ -986,11 +987,8 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             detail="Model not found",
         )
     model = app.state.MODELS[model_id]
-    files = form_data.pop("files", [])
-    tool_ids = form_data.pop("tool_ids", [])
-
     if model.get("pipe"):
-        return await generate_function_chat_completion(form_data, user, files, tool_ids)
+        return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else:

+ 8 - 4
src/lib/components/chat/Chat.svelte

@@ -844,8 +844,10 @@
 			},
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
-			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
-			files: files.length > 0 ? files : undefined,
+			metadata: {
+				tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
+				files: files.length > 0 ? files : undefined
+			},
 			session_id: $socket?.id,
 			chat_id: $chatId,
 			id: responseMessageId
@@ -1136,8 +1138,10 @@
 					frequency_penalty:
 						params?.frequency_penalty ?? $settings?.params?.frequency_penalty ?? undefined,
 					max_tokens: params?.max_tokens ?? $settings?.params?.max_tokens ?? undefined,
-					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
-					files: files.length > 0 ? files : undefined,
+					metadata: {
+						tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
+						files: files.length > 0 ? files : undefined
+					},
 					session_id: $socket?.id,
 					chat_id: $chatId,
 					id: responseMessageId