Sfoglia il codice sorgente

refac: ollama setting for rag

Timothy Jaeryang Baek 5 mesi fa
parent
commit
20321e5271

+ 54 - 6
backend/open_webui/apps/retrieval/main.py

@@ -75,6 +75,8 @@ from open_webui.config import (
     RAG_FILE_MAX_SIZE,
     RAG_OPENAI_API_BASE_URL,
     RAG_OPENAI_API_KEY,
+    RAG_OLLAMA_BASE_URL,
+    RAG_OLLAMA_API_KEY,
     RAG_RELEVANCE_THRESHOLD,
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
@@ -163,6 +165,9 @@ app.state.config.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.config.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
 app.state.config.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
+app.state.config.OLLAMA_BASE_URL = RAG_OLLAMA_BASE_URL
+app.state.config.OLLAMA_API_KEY = RAG_OLLAMA_API_KEY
+
 app.state.config.PDF_EXTRACT_IMAGES = PDF_EXTRACT_IMAGES
 
 app.state.config.YOUTUBE_LOADER_LANGUAGE = YOUTUBE_LOADER_LANGUAGE
@@ -261,8 +266,16 @@ app.state.EMBEDDING_FUNCTION = get_embedding_function(
     app.state.config.RAG_EMBEDDING_ENGINE,
     app.state.config.RAG_EMBEDDING_MODEL,
     app.state.sentence_transformer_ef,
-    app.state.config.OPENAI_API_KEY,
-    app.state.config.OPENAI_API_BASE_URL,
+    (
+        app.state.config.OPENAI_API_BASE_URL
+        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+        else app.state.config.OLLAMA_BASE_URL
+    ),
+    (
+        app.state.config.OPENAI_API_KEY
+        if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+        else app.state.config.OLLAMA_API_KEY
+    ),
     app.state.config.RAG_EMBEDDING_BATCH_SIZE,
 )
 
@@ -312,6 +325,10 @@ async def get_embedding_config(user=Depends(get_admin_user)):
             "url": app.state.config.OPENAI_API_BASE_URL,
             "key": app.state.config.OPENAI_API_KEY,
         },
+        "ollama_config": {
+            "url": app.state.config.OLLAMA_BASE_URL,
+            "key": app.state.config.OLLAMA_API_KEY,
+        },
     }
 
 
@@ -328,8 +345,14 @@ class OpenAIConfigForm(BaseModel):
     key: str
 
 
+class OllamaConfigForm(BaseModel):
+    url: str
+    key: str
+
+
 class EmbeddingModelUpdateForm(BaseModel):
     openai_config: Optional[OpenAIConfigForm] = None
+    ollama_config: Optional[OllamaConfigForm] = None
     embedding_engine: str
     embedding_model: str
     embedding_batch_size: Optional[int] = 1
@@ -350,6 +373,11 @@ async def update_embedding_config(
             if form_data.openai_config is not None:
                 app.state.config.OPENAI_API_BASE_URL = form_data.openai_config.url
                 app.state.config.OPENAI_API_KEY = form_data.openai_config.key
+
+            if form_data.ollama_config is not None:
+                app.state.config.OLLAMA_BASE_URL = form_data.ollama_config.url
+                app.state.config.OLLAMA_API_KEY = form_data.ollama_config.key
+
             app.state.config.RAG_EMBEDDING_BATCH_SIZE = form_data.embedding_batch_size
 
         update_embedding_model(app.state.config.RAG_EMBEDDING_MODEL)
@@ -358,8 +386,16 @@ async def update_embedding_config(
             app.state.config.RAG_EMBEDDING_ENGINE,
             app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
-            app.state.config.OPENAI_API_KEY,
-            app.state.config.OPENAI_API_BASE_URL,
+            (
+                app.state.config.OPENAI_API_BASE_URL
+                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+                else app.state.config.OLLAMA_BASE_URL
+            ),
+            (
+                app.state.config.OPENAI_API_KEY
+                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+                else app.state.config.OLLAMA_API_KEY
+            ),
             app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )
 
@@ -372,6 +408,10 @@ async def update_embedding_config(
                 "url": app.state.config.OPENAI_API_BASE_URL,
                 "key": app.state.config.OPENAI_API_KEY,
             },
+            "ollama_config": {
+                "url": app.state.config.OLLAMA_BASE_URL,
+                "key": app.state.config.OLLAMA_API_KEY,
+            },
         }
     except Exception as e:
         log.exception(f"Problem updating embedding model: {e}")
@@ -785,8 +825,16 @@ def save_docs_to_vector_db(
             app.state.config.RAG_EMBEDDING_ENGINE,
             app.state.config.RAG_EMBEDDING_MODEL,
             app.state.sentence_transformer_ef,
-            app.state.config.OPENAI_API_KEY,
-            app.state.config.OPENAI_API_BASE_URL,
+            (
+                app.state.config.OPENAI_API_BASE_URL
+                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+                else app.state.config.OLLAMA_BASE_URL
+            ),
+            (
+                app.state.config.OPENAI_API_KEY
+                if app.state.config.RAG_EMBEDDING_ENGINE == "openai"
+                else app.state.config.OLLAMA_API_KEY
+            ),
             app.state.config.RAG_EMBEDDING_BATCH_SIZE,
         )
 

+ 43 - 35
backend/open_webui/apps/retrieval/utils.py

@@ -11,11 +11,6 @@ from langchain.retrievers import ContextualCompressionRetriever, EnsembleRetriev
 from langchain_community.retrievers import BM25Retriever
 from langchain_core.documents import Document
 
-
-from open_webui.apps.ollama.main import (
-    GenerateEmbedForm,
-    generate_ollama_batch_embeddings,
-)
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.utils.misc import get_last_user_message
 
@@ -285,25 +280,19 @@ def get_embedding_function(
     embedding_engine,
     embedding_model,
     embedding_function,
-    openai_key,
-    openai_url,
+    url,
+    key,
     embedding_batch_size,
 ):
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-
-        # Wrapper to run the async generate_embeddings synchronously.
-        def sync_generate_embeddings(*args, **kwargs):
-            return asyncio.run(generate_embeddings(*args, **kwargs))
-
-        # Semantic expectation from the original version (using sync wrapper).
-        func = lambda query: sync_generate_embeddings(
+        func = lambda query: generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
-            key=openai_key if embedding_engine == "openai" else "",
-            url=openai_url if embedding_engine == "openai" else "",
+            url=url,
+            key=key,
         )
 
         def generate_multiple(query, func):
@@ -476,8 +465,8 @@ def get_model_path(model: str, update_model: bool = False):
         return model
 
 
-async def generate_openai_batch_embeddings(
-    model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
+def generate_openai_batch_embeddings(
+    model: str, texts: list[str], url: str = "https://api.openai.com/v1", key: str = ""
 ) -> Optional[list[list[float]]]:
     try:
         r = requests.post(
@@ -499,31 +488,50 @@ async def generate_openai_batch_embeddings(
         return None
 
 
-async def generate_embeddings(
-    engine: str, model: str, text: Union[str, list[str]], **kwargs
-):
+def generate_ollama_batch_embeddings(
+    model: str, texts: list[str], url: str, key: str
+) -> Optional[list[list[float]]]:
+    try:
+        r = requests.post(
+            f"{url}/api/embed",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": texts, "model": model},
+        )
+        r.raise_for_status()
+        data = r.json()
+
+        print(data)
+        if "embeddings" in data:
+            return data["embeddings"]
+        else:
+            raise "Something went wrong :/"
+    except Exception as e:
+        print(e)
+        return None
+
+
+def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+    url = kwargs.get("url", "")
+    key = kwargs.get("key", "")
+
     if engine == "ollama":
         if isinstance(text, list):
-            embeddings = await generate_ollama_batch_embeddings(
-                GenerateEmbedForm(**{"model": model, "input": text})
+            embeddings = generate_ollama_batch_embeddings(
+                **{"model": model, "texts": text, "url": url, "key": key}
             )
         else:
-            embeddings = await generate_ollama_batch_embeddings(
-                GenerateEmbedForm(**{"model": model, "input": [text]})
+            embeddings = generate_ollama_batch_embeddings(
+                **{"model": model, "texts": [text], "url": url, "key": key}
             )
-        return (
-            embeddings["embeddings"][0]
-            if isinstance(text, str)
-            else embeddings["embeddings"]
-        )
+        return embeddings[0] if isinstance(text, str) else embeddings
     elif engine == "openai":
-        key = kwargs.get("key", "")
-        url = kwargs.get("url", "https://api.openai.com/v1")
-
         if isinstance(text, list):
-            embeddings = await generate_openai_batch_embeddings(model, text, key, url)
+            embeddings = generate_openai_batch_embeddings(model, text, url, key)
         else:
-            embeddings = await generate_openai_batch_embeddings(model, [text], key, url)
+            embeddings = generate_openai_batch_embeddings(model, [text], url, key)
 
         return embeddings[0] if isinstance(text, str) else embeddings
 

+ 33 - 15
src/lib/components/admin/Settings/Documents.svelte

@@ -56,8 +56,11 @@
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
 
-	let OpenAIKey = '';
 	let OpenAIUrl = '';
+	let OpenAIKey = '';
+
+	let OllamaUrl = '';
+	let OllamaKey = '';
 
 	let querySettings = {
 		template: '',
@@ -104,19 +107,15 @@
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
 			embedding_model: embeddingModel,
-			...(embeddingEngine === 'openai' || embeddingEngine === 'ollama'
-				? {
-						embedding_batch_size: embeddingBatchSize
-					}
-				: {}),
-			...(embeddingEngine === 'openai'
-				? {
-						openai_config: {
-							key: OpenAIKey,
-							url: OpenAIUrl
-						}
-					}
-				: {})
+			embedding_batch_size: embeddingBatchSize,
+			ollama_config: {
+				key: OllamaKey,
+				url: OllamaUrl
+			},
+			openai_config: {
+				key: OpenAIKey,
+				url: OpenAIUrl
+			}
 		}).catch(async (error) => {
 			toast.error(error);
 			await setEmbeddingConfig();
@@ -206,6 +205,9 @@
 
 			OpenAIKey = embeddingConfig.openai_config.key;
 			OpenAIUrl = embeddingConfig.openai_config.url;
+
+			OllamaKey = embeddingConfig.ollama_config.key;
+			OllamaUrl = embeddingConfig.ollama_config.url;
 		}
 	};
 
@@ -310,7 +312,7 @@
 			</div>
 
 			{#if embeddingEngine === 'openai'}
-				<div class="my-0.5 flex gap-2">
+				<div class="my-0.5 flex gap-2 pr-2">
 					<input
 						class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
 						placeholder={$i18n.t('API Base URL')}
@@ -320,7 +322,23 @@
 
 					<SensitiveInput placeholder={$i18n.t('API Key')} bind:value={OpenAIKey} />
 				</div>
+			{:else if embeddingEngine === 'ollama'}
+				<div class="my-0.5 flex gap-2 pr-2">
+					<input
+						class="flex-1 w-full rounded-lg text-sm bg-transparent outline-none"
+						placeholder={$i18n.t('API Base URL')}
+						bind:value={OllamaUrl}
+						required
+					/>
+
+					<SensitiveInput
+						placeholder={$i18n.t('API Key')}
+						bind:value={OllamaKey}
+						required={false}
+					/>
+				</div>
 			{/if}
+
 			{#if embeddingEngine === 'ollama' || embeddingEngine === 'openai'}
 				<div class="flex mt-0.5 space-x-2">
 					<div class=" self-center text-xs font-medium">{$i18n.t('Embedding Batch Size')}</div>