Timothy J. Baek 8 months ago
parent
commit
15f3ebba93
1 changed files with 50 additions and 50 deletions
  1. 50 50
      backend/main.py

+ 50 - 50
backend/main.py

@@ -218,25 +218,6 @@ origins = ["*"]
 ##################################
 
 
-async def get_body_and_model_and_user(request):
-    # Read the original request body
-    body = await request.body()
-    body_str = body.decode("utf-8")
-    body = json.loads(body_str) if body_str else {}
-
-    model_id = body["model"]
-    if model_id not in app.state.MODELS:
-        raise Exception("Model not found")
-    model = app.state.MODELS[model_id]
-
-    user = get_current_user(
-        request,
-        get_http_authorization_cred(request.headers.get("Authorization")),
-    )
-
-    return body, model, user
-
-
 def get_task_model_id(default_model_id):
     # Set the task model
     task_model_id = default_model_id
@@ -283,26 +264,6 @@ def get_filter_function_ids(model):
     return filter_ids
 
 
-def get_tools_function_calling_payload(messages, task_model_id, content):
-    user_message = get_last_user_message(messages)
-    history = "\n".join(
-        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
-        for message in messages[::-1][:4]
-    )
-
-    prompt = f"History:\n{history}\nQuery: {user_message}"
-
-    return {
-        "model": task_model_id,
-        "messages": [
-            {"role": "system", "content": content},
-            {"role": "user", "content": f"Query: {prompt}"},
-        ],
-        "stream": False,
-        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
-    }
-
-
 async def chat_completion_filter_functions_handler(body, model, extra_params):
     skip_files = None
 
@@ -369,12 +330,32 @@ async def chat_completion_filter_functions_handler(body, model, extra_params):
     return body, {}
 
 
+def get_tools_function_calling_payload(messages, task_model_id, content):
+    user_message = get_last_user_message(messages)
+    history = "\n".join(
+        f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
+        for message in messages[::-1][:4]
+    )
+
+    prompt = f"History:\n{history}\nQuery: {user_message}"
+
+    return {
+        "model": task_model_id,
+        "messages": [
+            {"role": "system", "content": content},
+            {"role": "user", "content": f"Query: {prompt}"},
+        ],
+        "stream": False,
+        "metadata": {"task": str(TASKS.FUNCTION_CALLING)},
+    }
+
+
 def apply_extra_params_to_tool_function(
-    function: Callable, custom_params: dict
+    function: Callable, extra_params: dict
 ) -> Callable[..., Awaitable]:
     sig = inspect.signature(function)
     extra_params = {
-        key: value for key, value in custom_params.items() if key in sig.parameters
+        key: value for key, value in extra_params.items() if key in sig.parameters
     }
     is_coroutine = inspect.iscoroutinefunction(function)
 
@@ -511,27 +492,27 @@ async def chat_completion_tools_handler(
             return body, {}
 
         result = json.loads(content)
-        tool_name = result.get("name", None)
-        if tool_name not in tools:
+
+        tool_function_name = result.get("name", None)
+        if tool_function_name not in tools:
             return body, {}
 
-        tool_params = result.get("parameters", {})
-        toolkit_id = tools[tool_name]["toolkit_id"]
+        tool_function_params = result.get("parameters", {})
 
         try:
-            tool_output = await tools[tool_name]["callable"](**tool_params)
+            tool_output = await tools[tool_function_name]["callable"](**tool_function_params)
         except Exception as e:
             tool_output = str(e)
 
-        if tools[tool_name]["citation"]:
+        if tools[tool_function_name]["citation"]:
             citations.append(
                 {
-                    "source": {"name": f"TOOL:{toolkit_id}/{tool_name}"},
+                    "source": {"name": f"TOOL:{tools[tool_function_name]["toolkit_id"]}/{tool_function_name}"},
                     "document": [tool_output],
-                    "metadata": [{"source": tool_name}],
+                    "metadata": [{"source": tool_function_name}],
                 }
             )
-        if tools[tool_name]["file_handler"]:
+        if tools[tool_function_name]["file_handler"]:
             skip_files = True
 
         if isinstance(tool_output, str):
@@ -576,6 +557,25 @@ def is_chat_completion_request(request):
     )
 
 
+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise Exception("Model not found")
+    model = app.state.MODELS[model_id]
+
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
+
+
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if not is_chat_completion_request(request):