@@ -218,25 +218,6 @@ origins = ["*"]
 ##################################


-async def get_body_and_model_and_user(request):
-    # Read the original request body
-    body = await request.body()
-    body_str = body.decode("utf-8")
-    body = json.loads(body_str) if body_str else {}
-
-    model_id = body["model"]
-    if model_id not in app.state.MODELS:
-        raise Exception("Model not found")
-    model = app.state.MODELS[model_id]
-
-    user = get_current_user(
-        request,
-        get_http_authorization_cred(request.headers.get("Authorization")),
-    )
-
-    return body, model, user
-
-
 def get_task_model_id(default_model_id):
     # Set the task model
     task_model_id = default_model_id
@@ -283,26 +264,6 @@ def get_filter_function_ids(model):
     return filter_ids


-def get_tools_function_calling_payload(messages, task_model_id, content):
-    user_message = get_last_user_message(messages)
-    history = "\n".join(
-        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-        for message in messages[::-1][:4]
-    )
-
-    prompt = f"History:\n{history}\nQuery: {user_message}"
-
-    return {
-        "model": task_model_id,
-        "messages": [
-            {"role": "system", "content": content},
-            {"role": "user", "content": f"Query: {prompt}"},
-        ],
-        "stream": False,
-        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
-    }
-
-
 async def chat_completion_filter_functions_handler(body, model, extra_params):
     skip_files = None

@@ -369,12 +330,32 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
     return body, {}


+def get_tools_function_calling_payload(messages, task_model_id, content):
+    user_message = get_last_user_message(messages)
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
+    )
+
+    prompt = f"History:\n{history}\nQuery: {user_message}"
+
+    return {
+        "model": task_model_id,
+        "messages": [
+            {"role": "system", "content": content},
+            {"role": "user", "content": f"Query: {prompt}"},
+        ],
+        "stream": False,
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
+    }
+
+
 def apply_extra_params_to_tool_function(
-    function: Callable, custom_params: dict
+    function: Callable, extra_params: dict
 ) -> Callable[..., Awaitable]:
     sig = inspect.signature(function)
     extra_params = {
-        key: value for key, value in custom_params.items() if key in sig.parameters
+        key: value for key, value in extra_params.items() if key in sig.parameters
     }
     is_coroutine = inspect.iscoroutinefunction(function)

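For context on the `extra_params` rename in the hunk above: `apply_extra_params_to_tool_function` forwards only the keyword arguments that actually appear in the wrapped tool function's signature, and awaits the call when the tool is a coroutine. A minimal standalone sketch of that filtering idea, with made-up names (`bind_known_params`, `demo_tool`) rather than the project's own code:

```python
import asyncio
import inspect
from typing import Awaitable, Callable


def bind_known_params(function: Callable, extra_params: dict) -> Callable[..., Awaitable]:
    # Keep only the extra parameters that the tool function actually declares.
    sig = inspect.signature(function)
    accepted = {k: v for k, v in extra_params.items() if k in sig.parameters}
    is_coroutine = inspect.iscoroutinefunction(function)

    async def wrapped(**kwargs):
        result = function(**kwargs, **accepted)
        return await result if is_coroutine else result

    return wrapped


async def demo_tool(query: str, __user__: dict | None = None) -> str:
    return f"{query} (user={__user__})"


async def main():
    fn = bind_known_params(demo_tool, {"__user__": {"id": "u1"}, "__event_emitter__": None})
    # __event_emitter__ is dropped because demo_tool does not declare it.
    print(await fn(query="hello"))


asyncio.run(main())
```

Filtering against the signature lets the caller pass a superset of shared context without requiring every tool to declare every parameter.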
@@ -511,27 +492,27 @@ async def chat_completion_tools_handler(
             return body, {}

         result = json.loads(content)
-        tool_name = result.get("name", None)
-        if tool_name not in tools:
+
+        tool_function_name = result.get("name", None)
+        if tool_function_name not in tools:
             return body, {}

-        tool_params = result.get("parameters", {})
-        toolkit_id = tools[tool_name]["toolkit_id"]
+        tool_function_params = result.get("parameters", {})

         try:
-            tool_output = await tools[tool_name]["callable"](**tool_params)
+            tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
         except Exception as e:
             tool_output = str(e)

-        if tools[tool_name]["citation"]:
+        if tools[tool_function_name]["citation"]:
             citations.append(
                 {
-                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "source": {"name": f"TOOL:{tools[tool_function_name]['toolkit_id']}/{tool_function_name}"},
                     "document": [tool_output],
-                    "metadata": [{"source": tool_name}],
+                    "metadata": [{"source": tool_function_name}],
                 }
             )
-        if tools[tool_name]["file_handler"]:
+        if tools[tool_function_name]["file_handler"]:
             skip_files = True

         if isinstance(tool_output, str):
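The renames in the hunk above all serve the same dispatch flow: the task model's function-calling reply is parsed as JSON, its `name` field selects an entry in the tools registry, its `parameters` are passed to that entry's callable, and a citation record is appended when the tool requests one. A rough, self-contained sketch of that flow with a made-up registry and reply (the real handler awaits an async callable; a plain function keeps the sketch short):

```python
import json

# Hypothetical registry shaped like the `tools` dict used above.
TOOLS = {
    "get_weather": {
        "toolkit_id": "demo_toolkit",
        "callable": lambda city: f"Sunny in {city}",
        "citation": True,
        "file_handler": False,
    }
}


def dispatch(content: str):
    citations = []
    result = json.loads(content)

    tool_function_name = result.get("name", None)
    if tool_function_name not in TOOLS:
        return None, citations

    tool_function_params = result.get("parameters", {})
    try:
        tool_output = TOOLS[tool_function_name]["callable"](**tool_function_params)
    except Exception as e:
        tool_output = str(e)

    if TOOLS[tool_function_name]["citation"]:
        citations.append(
            {
                "source": {
                    "name": f"TOOL:{TOOLS[tool_function_name]['toolkit_id']}/{tool_function_name}"
                },
                "document": [tool_output],
                "metadata": [{"source": tool_function_name}],
            }
        )
    return tool_output, citations


print(dispatch('{"name": "get_weather", "parameters": {"city": "Oslo"}}'))
```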
@@ -576,6 +557,25 @@ def is_chat_completion_request(request):
     )


+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise Exception("Model not found")
+    model = app.state.MODELS[model_id]
+
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
+
+
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if not is_chat_completion_request(request):