@@ -185,39 +185,48 @@ async def get_function_call_response(prompt, tool_id, template, task_model_id, u
     model = app.state.MODELS[task_model_id]

     response = None
-    if model["owned_by"] == "ollama":
-        response = await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**payload), user=user
-        )
-    else:
-        response = await generate_openai_chat_completion(payload, user=user)
-
-    print(response)
-    content = response["choices"][0]["message"]["content"]
-
-    # Parse the function response
-    if content != "":
-        result = json.loads(content)
-        print(result)
-
-        # Call the function
-        if "name" in result:
-            if tool_id in webui_app.state.TOOLS:
-                toolkit_module = webui_app.state.TOOLS[tool_id]
-            else:
-                toolkit_module = load_toolkit_module_by_id(tool_id)
-                webui_app.state.TOOLS[tool_id] = toolkit_module
-
-            function = getattr(toolkit_module, result["name"])
-            function_result = None
-            try:
-                function_result = function(**result["parameters"])
-            except Exception as e:
-                print(e)
+    try:
+        if model["owned_by"] == "ollama":
+            response = await generate_ollama_chat_completion(
+                OpenAIChatCompletionForm(**payload), user=user
+            )
+        else:
+            response = await generate_openai_chat_completion(payload, user=user)
+
+        content = None
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+
+        # Parse the function response
+        if content is not None:
+            result = json.loads(content)
+            print(result)
+
+            # Call the function
+            if "name" in result:
+                if tool_id in webui_app.state.TOOLS:
+                    toolkit_module = webui_app.state.TOOLS[tool_id]
+                else:
+                    toolkit_module = load_toolkit_module_by_id(tool_id)
+                    webui_app.state.TOOLS[tool_id] = toolkit_module
+
+                function = getattr(toolkit_module, result["name"])
+                function_result = None
+                try:
+                    function_result = function(**result["parameters"])
+                except Exception as e:
+                    print(e)

-            # Add the function result to the system prompt
-            if function_result:
-                return function_result
+                # Add the function result to the system prompt
+                if function_result:
+                    return function_result
+    except Exception as e:
+        print(f"Error: {e}")

     return None

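Note on the hunk above: the patched code treats the value returned by `generate_ollama_chat_completion` / `generate_openai_chat_completion` as a Starlette streaming response (it reads `response.body_iterator` and `response.background`) instead of indexing it like a dict as the old code did. Below is a minimal, self-contained sketch of that drain-and-cleanup pattern; the fake generator, payload, and `get_weather` name are illustrative only, not taken from the project.

```python
import asyncio
import json

from starlette.responses import StreamingResponse


async def fake_body():
    # Illustrative stand-in for a non-streaming completion: one JSON chunk.
    yield json.dumps(
        {"choices": [{"message": {"content": '{"name": "get_weather", "parameters": {}}'}}]}
    ).encode("utf-8")


async def main():
    response = StreamingResponse(fake_body(), media_type="application/json")

    content = None
    # Drain the body the same way the patched code does.
    async for chunk in response.body_iterator:
        data = json.loads(chunk.decode("utf-8"))
        content = data["choices"][0]["message"]["content"]

    # Run any BackgroundTask attached to the response (none in this sketch).
    if response.background is not None:
        await response.background()

    print(content)  # -> {"name": "get_weather", "parameters": {}}


asyncio.run(main())
```

Since the body is consumed by hand here rather than sent over ASGI, a `BackgroundTask` attached by the proxying route would never fire on its own, which is presumably why the patch awaits `response.background()` explicitly.
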
@@ -285,15 +294,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     print(response)

                     if response:
-                        context += f"\n{response}"
+                        context = ("\n" if context != "" else "") + response

-                system_prompt = rag_template(
-                    rag_app.state.config.RAG_TEMPLATE, context, prompt
-                )
+                if context != "":
+                    system_prompt = rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context, prompt
+                    )

-                data["messages"] = add_or_update_system_message(
-                    system_prompt, data["messages"]
-                )
+                    print(system_prompt)
+
+                    data["messages"] = add_or_update_system_message(
+                        f"\n{system_prompt}", data["messages"]
+                    )

                 del data["tool_ids"]

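Note on the hunk above: the RAG-style system prompt is now only built and injected when at least one tool actually returned context, so an empty `[context]` block no longer gets prepended to the conversation. The sketch below illustrates that guard; `rag_template`, `add_or_update_system_message`, and the template string here are simplified stand-ins written for this example, not the project's implementations.

```python
def rag_template(template: str, context: str, query: str) -> str:
    # Simplified stand-in: substitute the template placeholders.
    return template.replace("[context]", context).replace("[query]", query)


def add_or_update_system_message(content: str, messages: list) -> list:
    # Simplified stand-in: prepend a system message if none exists yet.
    if messages and messages[0].get("role") == "system":
        messages[0]["content"] = content + messages[0]["content"]
    else:
        messages.insert(0, {"role": "system", "content": content})
    return messages


TEMPLATE = "Use the following context to answer:\n[context]\n\nQuery: [query]"

prompt = "What's the weather in Berlin?"
messages = [{"role": "user", "content": prompt}]
context = "get_weather -> Berlin: 21°C, sunny"  # tool output; "" if no tool responded

if context != "":
    system_prompt = rag_template(TEMPLATE, context, prompt)
    messages = add_or_update_system_message(f"\n{system_prompt}", messages)

print(messages[0]["role"])  # "system" only when a tool returned context
```

One thing worth double-checking: if the `context = ("\n" if context != "" else "") + response` assignment sits inside the per-tool loop, switching from `+=` to `=` means only the last responding tool's output survives in `context`.
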