Forráskód Böngészése

fix: query generation

Timothy Jaeryang Baek 5 hónapja
szülő
commit
429fa2befa
1 módosított fájl, 22 hozzáadás és 21 törlés
  1. 22 21
      backend/open_webui/main.py

+ 22 - 21
backend/open_webui/main.py

@@ -515,32 +515,32 @@ async def chat_completion_files_handler(
 ) -> tuple[dict, dict[str, list]]:
     sources = []
 
-    try:
-        queries_response = await generate_queries(
-            {
-                "model": body["model"],
-                "messages": body["messages"],
-                "type": "retrieval",
-            },
-            user,
-        )
-        queries_response = queries_response["choices"][0]["message"]["content"]
-
+    if files := body.get("metadata", {}).get("files", None):
         try:
-            queries_response = json.loads(queries_response)
-        except Exception as e:
-            queries_response = {"queries": []}
+            queries_response = await generate_queries(
+                {
+                    "model": body["model"],
+                    "messages": body["messages"],
+                    "type": "retrieval",
+                },
+                user,
+            )
+            queries_response = queries_response["choices"][0]["message"]["content"]
 
-        queries = queries_response.get("queries", [])
-    except Exception as e:
-        queries = []
+            try:
+                queries_response = json.loads(queries_response)
+            except Exception as e:
+                queries_response = {"queries": []}
+
+            queries = queries_response.get("queries", [])
+        except Exception as e:
+            queries = []
 
-    if len(queries) == 0:
-        queries = [get_last_user_message(body["messages"])]
+        if len(queries) == 0:
+            queries = [get_last_user_message(body["messages"])]
 
-    print(f"{queries=}")
+        print(f"{queries=}")
 
-    if files := body.get("metadata", {}).get("files", None):
         sources = get_sources_from_files(
             files=files,
             queries=queries,
@@ -691,6 +691,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
         # If context is not empty, insert it into the messages
         if len(sources) > 0:
+            print("\n\n\n\n\n\n\nHI\n\n\n\n\n\n")
             context_string = ""
             for source_idx, source in enumerate(sources):
                 source_id = source.get("source", {}).get("name", "")