|
@@ -282,6 +282,21 @@ def get_filter_function_ids(model):
|
|
|
return filter_ids
|
|
|
|
|
|
|
|
|
+async def get_content_from_response(response) -> Optional[str]:
|
|
|
+ content = None
|
|
|
+ if hasattr(response, "body_iterator"):
|
|
|
+ async for chunk in response.body_iterator:
|
|
|
+ data = json.loads(chunk.decode("utf-8"))
|
|
|
+ content = data["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+ # Cleanup any remaining background tasks if necessary
|
|
|
+ if response.background is not None:
|
|
|
+ await response.background()
|
|
|
+ else:
|
|
|
+ content = response["choices"][0]["message"]["content"]
|
|
|
+ return content
|
|
|
+
|
|
|
+
|
|
|
async def get_function_call_response(
|
|
|
messages,
|
|
|
files,
|
|
@@ -293,6 +308,9 @@ async def get_function_call_response(
|
|
|
__event_call__=None,
|
|
|
):
|
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
|
+ if tool is None:
|
|
|
+ return None, None, False
|
|
|
+
|
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
|
content = tools_function_calling_generation_template(template, tools_specs)
|
|
|
|
|
@@ -327,21 +345,9 @@ async def get_function_call_response(
|
|
|
|
|
|
model = app.state.MODELS[task_model_id]
|
|
|
|
|
|
- response = None
|
|
|
try:
|
|
|
response = await generate_chat_completions(form_data=payload, user=user)
|
|
|
- content = None
|
|
|
-
|
|
|
- if hasattr(response, "body_iterator"):
|
|
|
- async for chunk in response.body_iterator:
|
|
|
- data = json.loads(chunk.decode("utf-8"))
|
|
|
- content = data["choices"][0]["message"]["content"]
|
|
|
-
|
|
|
- # Cleanup any remaining background tasks if necessary
|
|
|
- if response.background is not None:
|
|
|
- await response.background()
|
|
|
- else:
|
|
|
- content = response["choices"][0]["message"]["content"]
|
|
|
+ content = await get_content_from_response(response)
|
|
|
|
|
|
if content is None:
|
|
|
return None, None, False
|
|
@@ -351,8 +357,6 @@ async def get_function_call_response(
|
|
|
result = json.loads(content)
|
|
|
print(result)
|
|
|
|
|
|
- citation = None
|
|
|
-
|
|
|
if "name" not in result:
|
|
|
return None, None, False
|
|
|
|
|
@@ -375,6 +379,7 @@ async def get_function_call_response(
|
|
|
|
|
|
function = getattr(toolkit_module, result["name"])
|
|
|
function_result = None
|
|
|
+ citation = None
|
|
|
try:
|
|
|
# Get the signature of the function
|
|
|
sig = inspect.signature(function)
|
|
@@ -1091,7 +1096,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
|
|
if model.get("pipe"):
|
|
|
return await generate_function_chat_completion(form_data, user=user)
|
|
|
if model["owned_by"] == "ollama":
|
|
|
- print("generate_ollama_chat_completion")
|
|
|
return await generate_ollama_chat_completion(form_data, user=user)
|
|
|
else:
|
|
|
return await generate_openai_chat_completion(form_data, user=user)
|