Преглед изворног кода

Merge pull request #1130 from open-webui/dev

fix: rag
Timothy Jaeryang Baek пре 1 година
родитељ
комит
11ca2703b0
2 измењених фајлова са 121 додато и 99 уклоњено
  1. 86 0
      backend/apps/rag/utils.py
  2. 35 99
      backend/main.py

+ 86 - 0
backend/apps/rag/utils.py

@@ -95,3 +95,89 @@ def rag_template(template: str, context: str, query: str):
     template = re.sub(r"\[query\]", query, template)
 
     return template
+
+
+def rag_messages(docs, messages, template, k, embedding_function):
+    """Inject retrieved document context into the last user message.
+
+    The text of the most recent ``"user"`` entry in ``messages`` is used as
+    the query against every entry in ``docs``. The retrieved text is combined
+    with the query through ``rag_template`` and written back as that
+    message's content.
+
+    Args:
+        docs: Iterable of dicts. Entries whose ``"type"`` is ``"collection"``
+            are looked up with ``query_collection`` using their
+            ``"collection_names"``; all others with ``query_doc`` using their
+            ``"collection_name"``.
+        messages: List of chat message dicts with ``"role"`` and ``"content"``
+            keys. ``"content"`` may be a string or a list of typed items
+            (dicts with a ``"type"`` key; text items also carry ``"text"``).
+        template: Template string forwarded to ``rag_template``.
+        k: Forwarded as ``k`` to the query helpers (presumably the number of
+            results to retrieve -- confirm against those helpers).
+        embedding_function: Forwarded unchanged to the query helpers.
+
+    Returns:
+        The same ``messages`` list object, with the last user message replaced
+        in place. The list is returned untouched when it holds no user
+        message.
+    """
+    print(docs)
+
+    # Walk backwards so the most recent user turn is the one augmented.
+    last_user_message_idx = next(
+        (
+            i
+            for i in reversed(range(len(messages)))
+            if messages[i]["role"] == "user"
+        ),
+        None,
+    )
+    if last_user_message_idx is None:
+        # Nothing to augment; indexing the list with None would raise
+        # TypeError, so hand the conversation back as-is.
+        return messages
+
+    user_message = messages[last_user_message_idx]
+    content = user_message["content"]
+
+    if isinstance(content, list):
+        # Multi-part content: the first text part supplies the query.
+        content_type = "list"
+        query = next(
+            (item["text"] for item in content if item["type"] == "text"), ""
+        )
+    elif isinstance(content, str):
+        content_type = "text"
+        query = content
+    else:
+        # Unexpected content shape: fall back to an empty query.
+        content_type = None
+        query = ""
+
+    relevant_contexts = []
+    for doc in docs:
+        try:
+            if doc["type"] == "collection":
+                context = query_collection(
+                    collection_names=doc["collection_names"],
+                    query=query,
+                    k=k,
+                    embedding_function=embedding_function,
+                )
+            else:
+                context = query_doc(
+                    collection_name=doc["collection_name"],
+                    query=query,
+                    k=k,
+                    embedding_function=embedding_function,
+                )
+        except Exception as e:
+            # Best effort: a failing lookup is reported and contributes no
+            # context instead of aborting the whole request.
+            print(e)
+            context = None
+        relevant_contexts.append(context)
+
+    # One newline-terminated line of joined document text per lookup that
+    # returned something truthy.
+    context_string = "".join(
+        " ".join(context["documents"][0]) + "\n"
+        for context in relevant_contexts
+        if context
+    )
+
+    ra_content = rag_template(
+        template=template,
+        context=context_string,
+        query=query,
+    )
+
+    if content_type == "list":
+        # Every text part carries the templated prompt; other parts (e.g.
+        # non-text items) are kept exactly as they were.
+        new_content = [
+            {"type": "text", "text": ra_content} if item["type"] == "text" else item
+            for item in content
+        ]
+    else:
+        new_content = ra_content
+
+    messages[last_user_message_idx] = {**user_message, "content": new_content}
+
+    return messages

+ 35 - 99
backend/main.py

@@ -28,7 +28,7 @@ from typing import List
 
 
 from utils.utils import get_admin_user
-from apps.rag.utils import query_doc, query_collection, rag_template
+from apps.rag.utils import rag_messages
 
 from config import (
     WEBUI_NAME,
@@ -60,19 +60,6 @@ app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 origins = ["*"]
 
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=origins,
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-@app.on_event("startup")
-async def on_startup():
-    await litellm_app_startup()
-
 
 class RAGMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
@@ -91,98 +78,33 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Example: Add a new key-value pair or modify existing ones
             # data["modified"] = True  # Example modification
             if "docs" in data:
-                docs = data["docs"]
-                print(docs)
-
-                last_user_message_idx = None
-                for i in range(len(data["messages"]) - 1, -1, -1):
-                    if data["messages"][i]["role"] == "user":
-                        last_user_message_idx = i
-                        break
-
-                user_message = data["messages"][last_user_message_idx]
-
-                if isinstance(user_message["content"], list):
-                    # Handle list content input
-                    content_type = "list"
-                    query = ""
-                    for content_item in user_message["content"]:
-                        if content_item["type"] == "text":
-                            query = content_item["text"]
-                            break
-                elif isinstance(user_message["content"], str):
-                    # Handle text content input
-                    content_type = "text"
-                    query = user_message["content"]
-                else:
-                    # Fallback in case the input does not match expected types
-                    content_type = None
-                    query = ""
-
-                relevant_contexts = []
-
-                for doc in docs:
-                    context = None
-
-                    try:
-                        if doc["type"] == "collection":
-                            context = query_collection(
-                                collection_names=doc["collection_names"],
-                                query=query,
-                                k=rag_app.state.TOP_K,
-                                embedding_function=rag_app.state.sentence_transformer_ef,
-                            )
-                        else:
-                            context = query_doc(
-                                collection_name=doc["collection_name"],
-                                query=query,
-                                k=rag_app.state.TOP_K,
-                                embedding_function=rag_app.state.sentence_transformer_ef,
-                            )
-                    except Exception as e:
-                        print(e)
-                        context = None
-
-                    relevant_contexts.append(context)
-
-                context_string = ""
-                for context in relevant_contexts:
-                    if context:
-                        context_string += " ".join(context["documents"][0]) + "\n"
-
-                ra_content = rag_template(
-                    template=rag_app.state.RAG_TEMPLATE,
-                    context=context_string,
-                    query=query,
-                )
 
-                if content_type == "list":
-                    new_content = []
-                    for content_item in user_message["content"]:
-                        if content_item["type"] == "text":
-                            # Update the text item's content with ra_content
-                            new_content.append({"type": "text", "text": ra_content})
-                        else:
-                            # Keep other types of content as they are
-                            new_content.append(content_item)
-                    new_user_message = {**user_message, "content": new_content}
-                else:
-                    new_user_message = {
-                        **user_message,
-                        "content": ra_content,
-                    }
-
-                data["messages"][last_user_message_idx] = new_user_message
+                data = {**data}
+                data["messages"] = rag_messages(
+                    data["docs"],
+                    data["messages"],
+                    rag_app.state.RAG_TEMPLATE,
+                    rag_app.state.TOP_K,
+                    rag_app.state.sentence_transformer_ef,
+                )
                 del data["docs"]
 
                 print(data["messages"])
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
 
-            # Create a new request with the modified body
-            scope = request.scope
-            scope["body"] = modified_body_bytes
-            request = Request(scope, receive=lambda: self._receive(modified_body_bytes))
+            # Replace the request body with the modified one
+            request._body = modified_body_bytes
+
+            # Set custom header to ensure content-length matches new body length
+            request.headers.__dict__["_list"] = [
+                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+                *[
+                    (k, v)
+                    for k, v in request.headers.raw
+                    if k.lower() != b"content-length"
+                ],
+            ]
 
         response = await call_next(request)
         return response
@@ -194,6 +116,15 @@ class RAGMiddleware(BaseHTTPMiddleware):
 app.add_middleware(RAGMiddleware)
 
 
+# CORS middleware is registered here, after the RAGMiddleware registration
+# directly above.
+# NOTE(review): Starlette is understood to wrap later-added middleware
+# outermost, which would make CORS handling run before RAGMiddleware on
+# incoming requests -- confirm against the Starlette middleware docs.
+# NOTE(review): origins is ["*"] while allow_credentials=True; confirm this
+# permissive policy is intended for deployment.
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
 @app.middleware("http")
 async def check_url(request: Request, call_next):
     start_time = int(time.time())
@@ -204,6 +135,11 @@ async def check_url(request: Request, call_next):
     return response
 
 
+@app.on_event("startup")
+async def on_startup():
+    """Await ``litellm_app_startup()`` when the app's "startup" event fires."""
+    await litellm_app_startup()
+
+
 app.mount("/api/v1", webui_app)
 app.mount("/litellm/api", litellm_app)