Browse Source

minor refac

Michael Poluektov 8 months ago
parent
commit
4042219b3e
1 changed files with 13 additions and 13 deletions
  1. 13 13
      backend/main.py

+ 13 - 13
backend/main.py

@@ -378,9 +378,8 @@ async def chat_completion_inlets_handler(body, model, extra_params):
             print(f"Error: {e}")
             raise e
 
-    if skip_files:
-        if "files" in body:
-            del body["files"]
+    if skip_files and "files" in body:
+        del body["files"]
 
     return body, {}
 
@@ -431,12 +430,17 @@ def get_configured_tools(
             )
 
         for spec in toolkit.specs:
+            # TODO: Fix hack for OpenAI API
+            for val in spec.get("parameters", {}).get("properties", {}).values():
+                if val["type"] == "str":
+                    val["type"] = "string"
             name = spec["name"]
             callable = getattr(module, name)
 
             # convert to function that takes only model params and inserts custom params
             custom_callable = get_tool_with_custom_params(callable, extra_params)
 
+            # TODO: This needs to be a pydantic model
             tool_dict = {
                 "spec": spec,
                 "citation": has_citation,
@@ -444,6 +448,7 @@ def get_configured_tools(
                 "toolkit_id": tool_id,
                 "callable": custom_callable,
             }
+            # TODO: if collision, prepend toolkit name
             if name in tools:
                 log.warning(f"Tool {name} already exists in another toolkit!")
                 log.warning(f"Collision between {toolkit} and {tool_id}.")
@@ -533,9 +538,9 @@ async def chat_completion_tools_handler(
     return body, {"contexts": contexts, "citations": citations}
 
 
-async def chat_completion_files_handler(body):
+async def chat_completion_files_handler(body) -> tuple[dict, dict[str, list]]:
     contexts = []
-    citations = None
+    citations = []
 
     if files := body.pop("files", None):
         contexts, citations = get_rag_context(
@@ -550,10 +555,7 @@ async def chat_completion_files_handler(body):
 
         log.debug(f"rag_contexts: {contexts}, citations: {citations}")
 
-    return body, {
-        **({"contexts": contexts} if contexts is not None else {}),
-        **({"citations": citations} if citations is not None else {}),
-    }
+    return body, {"contexts": contexts, "citations": citations}
 
 
 def is_chat_completion_request(request):
@@ -618,16 +620,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
-            print(e)
-            pass
+            log.exception(e)
 
         try:
             body, flags = await chat_completion_files_handler(body)
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
-            print(e)
-            pass
+            log.exception(e)
 
         # If context is not empty, insert it into the messages
         if len(contexts) > 0: