Explorar o código

Merge pull request #9068 from df-cgdm/main

**feat** Add user related headers when calling an external embedding api
Timothy Jaeryang Baek hai 2 meses
pai
achega
f6f8c08cb0

+ 40 - 14
backend/open_webui/retrieval/utils.py

@@ -15,8 +15,9 @@ from langchain_core.documents import Document
 from open_webui.config import VECTOR_DB
 from open_webui.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
+from open_webui.models.users import UserModel
 
-from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE
+from open_webui.env import SRC_LOG_LEVELS, OFFLINE_MODE, ENABLE_FORWARD_USER_INFO_HEADERS
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
@@ -64,6 +65,7 @@ def query_doc(
     collection_name: str,
     query_embedding: list[float],
     k: int,
+    user: UserModel=None
 ):
     try:
         result = VECTOR_DB_CLIENT.search(
@@ -256,29 +258,32 @@ def get_embedding_function(
     embedding_function,
     url,
     key,
-    embedding_batch_size,
+    embedding_batch_size
 ):
     if embedding_engine == "":
-        return lambda query: embedding_function.encode(query).tolist()
+        return lambda query, user=None: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+        func = lambda query, user=None: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
             url=url,
             key=key,
+            user=user
         )
 
-        def generate_multiple(query, func):
+        def generate_multiple(query, user, func):
             if isinstance(query, list):
                 embeddings = []
                 for i in range(0, len(query), embedding_batch_size):
-                    embeddings.extend(func(query[i : i + embedding_batch_size]))
+                    embeddings.extend(func(query[i : i + embedding_batch_size], user=user))
                 return embeddings
             else:
-                return func(query)
+                return func(query, user)
 
-        return lambda query: generate_multiple(query, func)
+        return lambda query, user=None: generate_multiple(query, user, func)
+    else:
+        raise ValueError(f"Unknown embedding engine: {embedding_engine}")
 
 
 def get_sources_from_files(
@@ -423,7 +428,7 @@ def get_model_path(model: str, update_model: bool = False):
 
 
 def generate_openai_batch_embeddings(
-    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -431,6 +436,16 @@ def generate_openai_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS and user
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
@@ -446,7 +461,7 @@ def generate_openai_batch_embeddings(
 
 
 def generate_ollama_batch_embeddings(
-    model: str, texts: list[str], url: str, key: str = ""
+    model: str, texts: list[str], url: str, key: str = "", user: UserModel = None
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -454,6 +469,16 @@ def generate_ollama_batch_embeddings(
             headers={
                 "Content-Type": "application/json",
                 "Authorization": f"Bearer {key}",
+                **(
+                    {
+                        "X-OpenWebUI-User-Name": user.name,
+                        "X-OpenWebUI-User-Id": user.id,
+                        "X-OpenWebUI-User-Email": user.email,
+                        "X-OpenWebUI-User-Role": user.role,
+                    }
+                    if ENABLE_FORWARD_USER_INFO_HEADERS
+                    else {}
+                ),
             },
             json={"input": texts, "model": model},
         )
@@ -472,22 +497,23 @@ def generate_ollama_batch_embeddings(
 def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
     url = kwargs.get("url", "")
     key = kwargs.get("key", "")
+    user = kwargs.get("user")
 
     if engine == "ollama":
         if isinstance(text, list):
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": text, "url": url, "key": key}
+                **{"model": model, "texts": text, "url": url, "key": key, "user": user}
             )
         else:
             embeddings = generate_ollama_batch_embeddings(
-                **{"model": model, "texts": [text], "url": url, "key": key}
+                **{"model": model, "texts": [text], "url": url, "key": key, "user": user}
             )
         return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, url, key)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key, user)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key, user)
 
         return embeddings[0] if isinstance(text, str) else embeddings
 

+ 4 - 2
backend/open_webui/routers/files.py

@@ -71,7 +71,7 @@ def upload_file(
         )
 
         try:
-            process_file(request, ProcessFileForm(file_id=id))
+            process_file(request, ProcessFileForm(file_id=id), user=user)
             file_item = Files.get_file_by_id(id=id)
         except Exception as e:
             log.exception(e)
@@ -193,7 +193,9 @@ async def update_file_data_content_by_id(
     if file and (file.user_id == user.id or user.role == "admin"):
         try:
             process_file(
-                request, ProcessFileForm(file_id=id, content=form_data.content)
+                request,
+                ProcessFileForm(file_id=id, content=form_data.content),
+                user=user
             )
             file = Files.get_file_by_id(id=id)
         except Exception as e:

+ 6 - 2
backend/open_webui/routers/knowledge.py

@@ -289,7 +289,9 @@ def add_file_to_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user
         )
     except Exception as e:
         log.debug(e)
@@ -372,7 +374,9 @@ def update_file_from_knowledge_by_id(
     # Add content to the vector database
     try:
         process_file(
-            request, ProcessFileForm(file_id=form_data.file_id, collection_name=id)
+            request,
+            ProcessFileForm(file_id=form_data.file_id, collection_name=id),
+            user=user
         )
     except Exception as e:
         raise HTTPException(

+ 4 - 4
backend/open_webui/routers/memories.py

@@ -57,7 +57,7 @@ async def add_memory(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {"created_at": memory.created_at},
             }
         ],
@@ -82,7 +82,7 @@ async def query_memory(
 ):
     results = VECTOR_DB_CLIENT.search(
         collection_name=f"user-memory-{user.id}",
-        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content)],
+        vectors=[request.app.state.EMBEDDING_FUNCTION(form_data.content, user)],
         limit=form_data.k,
     )
 
@@ -105,7 +105,7 @@ async def reset_memory_from_vector_db(
             {
                 "id": memory.id,
                 "text": memory.content,
-                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                 "metadata": {
                     "created_at": memory.created_at,
                     "updated_at": memory.updated_at,
@@ -160,7 +160,7 @@ async def update_memory_by_id(
                 {
                     "id": memory.id,
                     "text": memory.content,
-                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content),
+                    "vector": request.app.state.EMBEDDING_FUNCTION(memory.content, user),
                     "metadata": {
                         "created_at": memory.created_at,
                         "updated_at": memory.updated_at,

+ 15 - 9
backend/open_webui/routers/retrieval.py

@@ -666,6 +666,7 @@ def save_docs_to_vector_db(
     overwrite: bool = False,
     split: bool = True,
     add: bool = False,
+    user = None,
 ) -> bool:
     def _get_docs_info(docs: list[Document]) -> str:
         docs_info = set()
@@ -781,7 +782,8 @@ def save_docs_to_vector_db(
         )
 
         embeddings = embedding_function(
-            list(map(lambda x: x.replace("\n", " "), texts))
+            list(map(lambda x: x.replace("\n", " "), texts)),
+            user = user
         )
 
         items = [
@@ -939,6 +941,7 @@ def process_file(
                     "hash": hash,
                 },
                 add=(True if form_data.collection_name else False),
+                user=user
             )
 
             if result:
@@ -996,7 +999,7 @@ def process_text(
     text_content = form_data.content
     log.debug(f"text_content: {text_content}")
 
-    result = save_docs_to_vector_db(request, docs, collection_name)
+    result = save_docs_to_vector_db(request, docs, collection_name, user=user)
     if result:
         return {
             "status": True,
@@ -1029,7 +1032,7 @@ def process_youtube_video(
         content = " ".join([doc.page_content for doc in docs])
         log.debug(f"text_content: {content}")
 
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1070,7 +1073,7 @@ def process_web(
         content = " ".join([doc.page_content for doc in docs])
 
         log.debug(f"text_content: {content}")
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1286,7 +1289,7 @@ def process_web_search(
             requests_per_second=request.app.state.config.RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
         )
         docs = loader.load()
-        save_docs_to_vector_db(request, docs, collection_name, overwrite=True)
+        save_docs_to_vector_db(request, docs, collection_name, overwrite=True, user=user)
 
         return {
             "status": True,
@@ -1320,7 +1323,7 @@ def query_doc_handler(
             return query_doc_with_hybrid_search(
                 collection_name=form_data.collection_name,
                 query=form_data.query,
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1328,12 +1331,14 @@ def query_doc_handler(
                     if form_data.r
                     else request.app.state.config.RELEVANCE_THRESHOLD
                 ),
+                user=user
             )
         else:
             return query_doc(
                 collection_name=form_data.collection_name,
-                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query),
+                query_embedding=request.app.state.EMBEDDING_FUNCTION(form_data.query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
+                user=user
             )
     except Exception as e:
         log.exception(e)
@@ -1362,7 +1367,7 @@ def query_collection_handler(
             return query_collection_with_hybrid_search(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query, user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
                 reranking_function=request.app.state.rf,
                 r=(
@@ -1375,7 +1380,7 @@ def query_collection_handler(
             return query_collection(
                 collection_names=form_data.collection_names,
                 queries=[form_data.query],
-                embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query,user=user),
                 k=form_data.k if form_data.k else request.app.state.config.TOP_K,
             )
 
@@ -1523,6 +1528,7 @@ def process_files_batch(
                 docs=all_docs,
                 collection_name=collection_name,
                 add=True,
+                user=user,
             )
 
             # Update all files with collection name

+ 1 - 1
backend/open_webui/utils/middleware.py

@@ -634,7 +634,7 @@ async def chat_completion_files_handler(
                     lambda: get_sources_from_files(
                         files=files,
                         queries=queries,
-                        embedding_function=request.app.state.EMBEDDING_FUNCTION,
+                        embedding_function=lambda query: request.app.state.EMBEDDING_FUNCTION(query,user=user),
                         k=request.app.state.config.TOP_K,
                         reranking_function=request.app.state.rf,
                         r=request.app.state.config.RELEVANCE_THRESHOLD,