Timothy J. Baek 10 月之前
父節點
當前提交
e82027310d
共有 1 個文件被更改,包括 43 次插入15 次删除
  1. 43 15
      backend/main.py

+ 43 - 15
backend/main.py

@@ -196,7 +196,11 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
         "stream": False,
         "stream": False,
     }
     }
 
 
-    payload = filter_pipeline(payload, user)
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        raise e
+
     model = app.state.MODELS[task_model_id]
     model = app.state.MODELS[task_model_id]
 
 
     response = None
     response = None
@@ -326,16 +330,19 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 print(data["tool_ids"])
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
                 for tool_id in data["tool_ids"]:
                     print(tool_id)
                     print(tool_id)
-                    response = await get_function_call_response(
-                        messages=data["messages"],
-                        tool_id=tool_id,
-                        template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                        task_model_id=task_model_id,
-                        user=user,
-                    )
+                    try:
+                        response = await get_function_call_response(
+                            messages=data["messages"],
+                            tool_id=tool_id,
+                            template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                            task_model_id=task_model_id,
+                            user=user,
+                        )
 
 
-                    if response:
-                        context += ("\n" if context != "" else "") + response
+                        if response:
+                            context += ("\n" if context != "" else "") + response
+                    except Exception as e:
+                        print(f"Error: {e}")
                 del data["tool_ids"]
                 del data["tool_ids"]
 
 
                 print(f"tool_context: {context}")
                 print(f"tool_context: {context}")
@@ -767,7 +774,14 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     }
     }
 
 
     print(payload)
     print(payload)
-    payload = filter_pipeline(payload, user)
+
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        return JSONResponse(
+            status_code=e.args[0],
+            content={"detail": e.args[1]},
+        )
 
 
     if model["owned_by"] == "ollama":
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(
         return await generate_ollama_chat_completion(
@@ -824,7 +838,14 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
     }
     }
 
 
     print(payload)
     print(payload)
-    payload = filter_pipeline(payload, user)
+
+    try:
+        payload = filter_pipeline(payload, user)
+    except Exception as e:
+        return JSONResponse(
+            status_code=e.args[0],
+            content={"detail": e.args[1]},
+        )
 
 
     if model["owned_by"] == "ollama":
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(
         return await generate_ollama_chat_completion(
@@ -861,9 +882,16 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
     print(model_id)
     print(model_id)
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 
 
-    return await get_function_call_response(
-        form_data["messages"], form_data["tool_id"], template, model_id, user
-    )
+    try:
+        context = await get_function_call_response(
+            form_data["messages"], form_data["tool_id"], template, model_id, user
+        )
+        return context
+    except Exception as e:
+        return JSONResponse(
+            status_code=e.args[0],
+            content={"detail": e.args[1]},
+        )
 
 
 
 
 @app.post("/api/chat/completions")
 @app.post("/api/chat/completions")