Ver código fonte

factor out get_content_from_response

Michael Poluektov 8 meses atrás
pai
commit
9fb70969d7
1 arquivos alterados com 20 adições e 16 exclusões
  1. 20 16
      backend/main.py

+ 20 - 16
backend/main.py

@@ -282,6 +282,21 @@ def get_filter_function_ids(model):
     return filter_ids
 
 
+async def get_content_from_response(response) -> Optional[str]:
+    content = None
+    if hasattr(response, "body_iterator"):
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+    else:
+        content = response["choices"][0]["message"]["content"]
+    return content
+
+
 async def get_function_call_response(
     messages,
     files,
@@ -293,6 +308,9 @@ async def get_function_call_response(
     __event_call__=None,
 ):
     tool = Tools.get_tool_by_id(tool_id)
+    if tool is None:
+        return None, None, False
+
     tools_specs = json.dumps(tool.specs, indent=2)
     content = tools_function_calling_generation_template(template, tools_specs)
 
@@ -327,21 +345,9 @@ async def get_function_call_response(
 
     model = app.state.MODELS[task_model_id]
 
-    response = None
     try:
         response = await generate_chat_completions(form_data=payload, user=user)
-        content = None
-
-        if hasattr(response, "body_iterator"):
-            async for chunk in response.body_iterator:
-                data = json.loads(chunk.decode("utf-8"))
-                content = data["choices"][0]["message"]["content"]
-
-            # Cleanup any remaining background tasks if necessary
-            if response.background is not None:
-                await response.background()
-        else:
-            content = response["choices"][0]["message"]["content"]
+        content = await get_content_from_response(response)
 
         if content is None:
             return None, None, False
@@ -351,8 +357,6 @@ async def get_function_call_response(
         result = json.loads(content)
         print(result)
 
-        citation = None
-
         if "name" not in result:
             return None, None, False
 
@@ -375,6 +379,7 @@ async def get_function_call_response(
 
         function = getattr(toolkit_module, result["name"])
         function_result = None
+        citation = None
         try:
             # Get the signature of the function
             sig = inspect.signature(function)
@@ -1091,7 +1096,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
-        print("generate_ollama_chat_completion")
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
         return await generate_openai_chat_completion(form_data, user=user)