
refactor: Update GenerateEmbeddingsForm to support batch processing

refactor: Update embedding batch size handling in RAG configuration

refactor: add query_doc query caching

refactor: update logging statements in generate_chat_completion function

change embedding_batch_size to Optional
Peter De-Ath, 7 months ago
parent
commit 885b9f1ece
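
Taken together, these changes move the Ollama embedding path from the single-prompt /api/embeddings endpoint to the batched /api/embed endpoint, and promote the RAG batch-size setting from an OpenAI-only option to one shared by both engines. A minimal sketch of the two endpoint shapes involved (the base URL and model name are illustrative assumptions, not values from this commit):

# Sketch of the two Ollama endpoint shapes this change targets.
import requests

OLLAMA_URL = "http://localhost:11434"  # assumed local default

# Newer batched endpoint: one request carries many inputs.
r = requests.post(
    f"{OLLAMA_URL}/api/embed",
    json={"model": "nomic-embed-text", "input": ["chunk one", "chunk two"]},
)
r.raise_for_status()
print(len(r.json()["embeddings"]))  # 2 vectors, one per input

# Older single-prompt endpoint, still served for compatibility.
r = requests.post(
    f"{OLLAMA_URL}/api/embeddings",
    json={"model": "nomic-embed-text", "prompt": "chunk one"},
)
r.raise_for_status()
print(len(r.json()["embedding"]))  # dimensionality of the single vector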

+ 29 - 55
backend/open_webui/apps/ollama/main.py

@@ -547,8 +547,8 @@ class GenerateEmbeddingsForm(BaseModel):
 
 
 class GenerateEmbedForm(BaseModel):
     model: str
-    input: str
-    truncate: Optional[bool]
+    input: list[str]
+    truncate: Optional[bool] = None
     options: Optional[dict] = None
     keep_alive: Optional[Union[int, str]] = None
 
@@ -560,48 +560,7 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    if url_idx is None:
-        model = form_data.model
-
-        if ":" not in model:
-            model = f"{model}:latest"
-
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
-        else:
-            raise HTTPException(
-                status_code=400,
-                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
-            )
-
-    url = app.state.config.OLLAMA_BASE_URLS[url_idx]
-    log.info(f"url: {url}")
-
-    r = requests.request(
-        method="POST",
-        url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
-        data=form_data.model_dump_json(exclude_none=True).encode(),
-    )
-    try:
-        r.raise_for_status()
-
-        return r.json()
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except Exception:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return generate_ollama_batch_embeddings(form_data, url_idx)
 
 
 @app.post("/api/embeddings")
@@ -611,6 +570,15 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
+    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
+
+
+def generate_ollama_embeddings(
+    form_data: GenerateEmbeddingsForm,
+    url_idx: Optional[int] = None,
+):
+    log.info(f"generate_ollama_embeddings {form_data}")
+
     if url_idx is None:
         model = form_data.model
 
@@ -637,7 +605,14 @@ async def generate_embeddings(
     try:
         r.raise_for_status()
 
-        return r.json()
+        data = r.json()
+
+        log.info(f"generate_ollama_embeddings {data}")
+
+        if "embedding" in data:
+            return data
+        else:
+            raise Exception("Something went wrong :/")
     except Exception as e:
         log.exception(e)
         error_detail = "Open WebUI: Server Connection Error"
@@ -655,11 +630,11 @@ async def generate_embeddings(
         )
 
 
-def generate_ollama_embeddings(
-    form_data: GenerateEmbeddingsForm,
+def generate_ollama_batch_embeddings(
+    form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
 ):
-    log.info(f"generate_ollama_embeddings {form_data}")
+    log.info(f"generate_ollama_batch_embeddings {form_data}")
 
     if url_idx is None:
         model = form_data.model
@@ -680,7 +655,7 @@ def generate_ollama_embeddings(
 
     r = requests.request(
         method="POST",
-        url=f"{url}/api/embeddings",
+        url=f"{url}/api/embed",
        headers={"Content-Type": "application/json"},
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
@@ -689,10 +664,10 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
-        log.info(f"generate_ollama_embeddings {data}")
+        log.info(f"generate_ollama_batch_embeddings {data}")
 
-        if "embedding" in data:
-            return data["embedding"]
+        if "embeddings" in data:
+            return data
         else:
             raise Exception("Something went wrong :/")
     except Exception as e:
@@ -788,8 +763,7 @@ async def generate_chat_completion(
     user=Depends(get_verified_user),
 ):
     payload = {**form_data.model_dump(exclude_none=True)}
-    log.debug(f"{payload = }")
-
+    log.debug(f"generate_chat_completion() - 1.payload = {payload}")
     if "metadata" in payload:
         del payload["metadata"]
 
@@ -824,7 +798,7 @@ async def generate_chat_completion(
 
     url = get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
-    log.debug(payload)
+    log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 
     return await post_streaming_url(
         f"{url}/api/chat",

+ 10 - 14
backend/open_webui/apps/retrieval/main.py

@@ -63,7 +63,7 @@ from open_webui.config import (
     RAG_EMBEDDING_MODEL,
     RAG_EMBEDDING_MODEL_AUTO_UPDATE,
     RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE,
-    RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    RAG_EMBEDDING_BATCH_SIZE,
     RAG_FILE_MAX_COUNT,
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
@@ -134,7 +134,7 @@ app.state.config.CHUNK_OVERLAP = CHUNK_OVERLAP
 
 app.state.config.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.config.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
-app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = RAG_EMBEDDING_OPENAI_BATCH_SIZE
+app.state.config.RAG_EMBEDDING_BATCH_SIZE = RAG_EMBEDDING_BATCH_SIZE
 app.state.config.RAG_RERANKING_MODEL = RAG_RERANKING_MODEL
 app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 
@@ -233,7 +233,7 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.sentence_transformer_ef,
     app.state.config.OPENAI_API_KEY,
     app.state.config.OPENAI_API_BASE_URL,
-    app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+    app.state.config.RAG_EMBEDDING_BATCH_SIZE,
 )
 
 app.add_middleware(
@@ -267,7 +267,7 @@ async def get_status():
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
         "reranking_model": app.state.config.RAG_RERANKING_MODEL,
-        "openai_batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
     }
 
 
@@ -277,10 +277,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+        "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         "openai_config": {
             "url": app.state.config.OPENAI_API_BASE_URL,
             "key": app.state.config.OPENAI_API_KEY,
-            "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
         },
     }
 
@@ -296,13 +296,13 @@ async def get_reraanking_config(user=Depends(get_admin_user)):
 class OpenAIConfigForm(BaseModel):
     url: str
     key: str
-    batch_size: Optional[int] = None
 
 
 class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str
+    embedding_batch_size: Optional[int] = 1
 
 
 @app.post("/embedding/update")
@@ -320,11 +320,7 @@ async def update_embedding_config(
             if form_data.openai_config is not None:
                 app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.config.OPENAI_API_KEY = form_data.openai_config.key
-                app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE = (
-                    form_data.openai_config.batch_size
-                    if form_data.openai_config.batch_size
-                    else 1
-                )
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
 
         update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
 
@@ -334,17 +330,17 @@ async def update_embedding_config(
             app.state.sentence_transformer_ef,
             app.state.config.OPENAI_API_KEY,
             app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )
 
         return {
             "status": True,
             "embedding_engine": app.state.config.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.config.RAG_EMBEDDING_MODEL,
+            "embedding_batch_size": app.state.config.RAG_EMBEDDING_BATCH_SIZE,
             "openai_config": {
                 "url": app.state.config.OPENAI_API_BASE_URL,
                 "key": app.state.config.OPENAI_API_KEY,
-                "batch_size": app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
             },
         }
     except Exception as e:
@@ -690,7 +686,7 @@ def save_docs_to_vector_db(
             app.state.sentence_transformer_ef,
             app.state.config.OPENAI_API_KEY,
             app.state.config.OPENAI_API_BASE_URL,
-            app.state.config.RAG_EMBEDDING_OPENAI_BATCH_SIZE,
+            app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )
 
         embeddings = embedding_function(
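
The visible API change here is that the batch size moves from openai_config.batch_size to a top-level embedding_batch_size field honoured for both engines. A sketch of the before/after /embedding/update payloads (URLs, keys, and model names are placeholders):

# Before: batch size nested under openai_config, OpenAI-only.
old_payload = {
    "embedding_engine": "openai",
    "embedding_model": "text-embedding-3-small",
    "openai_config": {"url": "https://api.openai.com/v1", "key": "sk-...", "batch_size": 32},
}

# After: top-level field, honoured for both "openai" and "ollama" engines.
new_payload = {
    "embedding_engine": "ollama",
    "embedding_model": "nomic-embed-text:latest",
    "embedding_batch_size": 32,
}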

+ 25 - 17
backend/open_webui/apps/retrieval/utils.py

@@ -12,8 +12,8 @@ from langchain_core.documents import Document
 
 
 from open_webui.apps.ollama.main import (
-    GenerateEmbeddingsForm,
-    generate_ollama_embeddings,
+    GenerateEmbedForm,
+    generate_ollama_batch_embeddings,
 )
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
@@ -71,7 +71,7 @@ def query_doc(
     try:
         result = VECTOR_DB_CLIENT.search(
             collection_name=collection_name,
-            vectors=[query_embedding],
+            vectors=query_embedding,
             limit=k,
         )
 
@@ -265,19 +265,15 @@ def get_embedding_function(
     embedding_function,
     openai_key,
     openai_url,
-    batch_size,
+    embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
         if embedding_engine == "ollama":
             func = lambda query: generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": embedding_model,
-                        "prompt": query,
-                    }
-                )
+                model=embedding_model,
+                input=query,
             )
         elif embedding_engine == "openai":
             func = lambda query: generate_openai_embeddings(
@@ -289,13 +285,10 @@ def get_embedding_function(
 
         def generate_multiple(query, f):
             if isinstance(query, list):
-                if embedding_engine == "openai":
-                    embeddings = []
-                    for i in range(0, len(query), batch_size):
-                        embeddings.extend(f(query[i : i + batch_size]))
-                    return embeddings
-                else:
-                    return [f(q) for q in query]
+                embeddings = []
+                for i in range(0, len(query), embedding_batch_size):
+                    embeddings.extend(f(query[i : i + embedding_batch_size]))
+                return embeddings
             else:
                 return f(query)
 
@@ -481,6 +474,21 @@ def generate_openai_batch_embeddings(
         return None
 
 
+def generate_ollama_embeddings(
+    model: str, input: list[str]
+) -> Optional[list[list[float]]]:
+    if isinstance(input, list):
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": input})
+        )
+    else:
+        embeddings = generate_ollama_batch_embeddings(
+            GenerateEmbedForm(**{"model": model, "input": [input]})
+        )
+
+    return embeddings["embeddings"]
+
+
 import operator
 from typing import Optional, Sequence
 
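generate_multiple now slices the input for every engine into fixed-size batches instead of special-casing OpenAI. The same loop, extracted as a self-contained sketch with a dummy embedder standing in for the real one:

# Stand-alone version of the batching loop above; the fake embedder maps
# each string to a one-element vector holding its length.
def generate_multiple(query, f, embedding_batch_size=1):
    if isinstance(query, list):
        embeddings = []
        for i in range(0, len(query), embedding_batch_size):
            embeddings.extend(f(query[i : i + embedding_batch_size]))
        return embeddings
    return f(query)

fake_embedder = lambda batch: [[float(len(s))] for s in batch]
print(generate_multiple(["aa", "bbb", "c"], fake_embedder, embedding_batch_size=2))
# -> [[2.0], [3.0], [1.0]]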

+ 4 - 4
backend/open_webui/config.py

@@ -986,10 +986,10 @@ RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE = (
     os.environ.get("RAG_EMBEDDING_MODEL_TRUST_REMOTE_CODE", "").lower() == "true"
 )
 
-RAG_EMBEDDING_OPENAI_BATCH_SIZE = PersistentConfig(
-    "RAG_EMBEDDING_OPENAI_BATCH_SIZE",
-    "rag.embedding_openai_batch_size",
-    int(os.environ.get("RAG_EMBEDDING_OPENAI_BATCH_SIZE", "1")),
+RAG_EMBEDDING_BATCH_SIZE = PersistentConfig(
+    "RAG_EMBEDDING_BATCH_SIZE",
+    "rag.embedding_batch_size",
+    int(os.environ.get("RAG_EMBEDDING_BATCH_SIZE", "1")),
 )
 
 RAG_RERANKING_MODEL = PersistentConfig(
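
Because the PersistentConfig entry is renamed, the environment override is renamed with it: RAG_EMBEDDING_OPENAI_BATCH_SIZE is no longer read, and values persisted under the old rag.embedding_openai_batch_size key presumably do not carry over. A minimal illustration (the launch command is an assumption):

# Illustrative startup override; the reading logic mirrors the config above.
#   RAG_EMBEDDING_BATCH_SIZE=32 uvicorn open_webui.main:app
import os

batch_size = int(os.environ.get("RAG_EMBEDDING_BATCH_SIZE", "1"))  # defaults to 1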

+ 1 - 1
src/lib/apis/retrieval/index.ts

@@ -200,13 +200,13 @@ export const getEmbeddingConfig = async (token: string) => {
 type OpenAIConfigForm = {
 	key: string;
 	url: string;
-	batch_size: number;
 };
 
 type EmbeddingModelUpdateForm = {
 	openai_config?: OpenAIConfigForm;
 	embedding_engine: string;
 	embedding_model: string;
+	embedding_batch_size?: number;
 };
 
 export const updateEmbeddingConfig = async (token: string, payload: EmbeddingModelUpdateForm) => {

+ 12 - 6
src/lib/components/admin/Settings/Documents.svelte

@@ -38,6 +38,7 @@
 
 	let embeddingEngine = '';
 	let embeddingModel = '';
+	let embeddingBatchSize = 1;
 	let rerankingModel = '';
 
 	let fileMaxSize = null;
@@ -53,7 +54,6 @@
 
 	let OpenAIKey = '';
 	let OpenAIUrl = '';
-	let OpenAIBatchSize = 1;
 
 	let querySettings = {
 		template: '',
@@ -100,12 +100,16 @@
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
 			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
+				? {
+						embedding_batch_size: embeddingBatchSize
+					}
+				: {}),
 			...(embeddingEngine === 'openai'
 				? {
 						openai_config: {
 							key: OpenAIKey,
-							url: OpenAIUrl,
-							batch_size: OpenAIBatchSize
+							url: OpenAIUrl
 						}
 					}
 				: {})
@@ -193,10 +197,10 @@
 		if (embeddingConfig) {
 			embeddingEngine = embeddingConfig.embedding_engine;
 			embeddingModel = embeddingConfig.embedding_model;
+			embeddingBatchSize = embeddingConfig.embedding_batch_size ?? 1;
 
 			OpenAIKey = embeddingConfig.openai_config.key;
 			OpenAIUrl = embeddingConfig.openai_config.url;
-			OpenAIBatchSize = embeddingConfig.openai_config.batch_size ?? 1;
 		}
 	};
 
@@ -309,6 +313,8 @@
 
 					<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
 				</div>
+			{/if}
+			{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
 				<div class="flex mt-0.5 space-x-2">
 					<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>
 					<div class=" flex-1">
@@ -318,13 +324,13 @@
 							min="1"
 							max="2048"
 							step="1"
-							bind:value={OpenAIBatchSize}
+							bind:value={embeddingBatchSize}
 							class="w-full h-2 rounded-lg appearance-none cursor-pointer dark:bg-gray-700"
 						/>
 					</div>
 					<div class="">
 						<input
-							bind:value={OpenAIBatchSize}
+							bind:value={embeddingBatchSize}
 							type="number"
 							class=" bg-transparent text-center w-14"
 							min="-2"