Timothy J. Baek 10 月之前
父节点
当前提交
4dd77b785a
共有 2 个文件被更改,包括 11 次插入9 次删除
  1. 1 1
      backend/apps/ollama/main.py
  2. 10 8
      backend/main.py

+ 1 - 1
backend/apps/ollama/main.py

@@ -895,8 +895,8 @@ async def generate_openai_chat_completion(
     user=Depends(get_verified_user),
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
+    payload = {**form_data.model_dump(exclude_none=True, exclude=["metadata"])}
 
-    payload = {**form_data}
     if "metadata" in payload:
         del payload["metadata"]
 

+ 10 - 8
backend/main.py

@@ -317,7 +317,7 @@ async def get_function_call_response(
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
-        "task": TASKS.FUNCTION_CALLING,
+        "task": str(TASKS.FUNCTION_CALLING),
     }
 
     try:
@@ -632,8 +632,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 message_id = body["id"]
                 del body["id"]
 
-           
-
             __event_emitter__ = await get_event_emitter(
                 {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
             )
@@ -1037,12 +1035,16 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
         task = form_data["task"]
         del form_data["task"]
 
-    if "metadata" in form_data:
-        form_data["metadata"]['task'] = task
+    if task:
+        if "metadata" in form_data:
+            form_data["metadata"]["task"] = task
+        else:
+            form_data["metadata"] = {"task": task}
 
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
+        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)
@@ -1311,7 +1313,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.TITLE_GENERATION,
+        "task": str(TASKS.TITLE_GENERATION),
     }
 
     log.debug(payload)
@@ -1364,7 +1366,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": TASKS.QUERY_GENERATION,
+        "task": str(TASKS.QUERY_GENERATION),
     }
 
     print(payload)
@@ -1421,7 +1423,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": TASKS.EMOJI_GENERATION,
+        "task": str(TASKS.EMOJI_GENERATION),
     }
 
     log.debug(payload)