فهرست منبع

Merge pull request #1554 from open-webui/external-embeddings

feat: external embeddings
Timothy Jaeryang Baek 1 سال پیش
والد
کامیت
54a4b7db14

+ 2 - 2
backend/apps/ollama/main.py

@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
     url_idx: Optional[int] = None,
 ):
 
-    log.info("generate_ollama_embeddings", form_data)
+    log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx == None:
         model = form_data.model
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
-        log.info("generate_ollama_embeddings", data)
+        log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
             return data["embedding"]

+ 102 - 49
backend/apps/rag/main.py

@@ -53,6 +53,7 @@ from apps.rag.utils import (
     query_collection,
     query_embeddings_collection,
     get_embedding_model_path,
+    generate_openai_embeddings,
 )
 
 from utils.misc import (
@@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 
+app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
+app.state.RAG_OPENAI_API_KEY = ""
 
 app.state.PDF_EXTRACT_IMAGES = False
 
@@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "openai_config": {
+            "url": app.state.RAG_OPENAI_API_BASE_URL,
+            "key": app.state.RAG_OPENAI_API_KEY,
+        },
     }
 
 
+class OpenAIConfigForm(BaseModel):
+    url: str
+    key: str
+
+
 class EmbeddingModelUpdateForm(BaseModel):
+    openai_config: Optional[OpenAIConfigForm] = None
     embedding_engine: str
     embedding_model: str
 
@@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
-
     log.info(
         f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
-
     try:
         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
 
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
             app.state.sentence_transformer_ef = None
+
+            if form_data.openai_config != None:
+                app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
+                app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
         else:
             sentence_transformer_ef = (
                 embedding_functions.SentenceTransformerEmbeddingFunction(
@@ -183,6 +198,10 @@ async def update_embedding_config(
             "status": True,
             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+            "openai_config": {
+                "url": app.state.RAG_OPENAI_API_BASE_URL,
+                "key": app.state.RAG_OPENAI_API_KEY,
+            },
         }
 
     except Exception as e:
@@ -275,28 +294,37 @@ def query_doc_handler(
 ):
 
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-
-            return query_embeddings_doc(
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            return query_doc(
                 collection_name=form_data.collection_name,
-                query_embeddings=query_embeddings,
+                query=form_data.query,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                embedding_function=app.state.sentence_transformer_ef,
             )
         else:
-            return query_doc(
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                query_embeddings = generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{
+                            "model": app.state.RAG_EMBEDDING_MODEL,
+                            "prompt": form_data.query,
+                        }
+                    )
+                )
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                query_embeddings = generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=form_data.query,
+                    key=app.state.RAG_OPENAI_API_KEY,
+                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                )
+
+            return query_embeddings_doc(
                 collection_name=form_data.collection_name,
-                query=form_data.query,
+                query_embeddings=query_embeddings,
                 k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
             )
+
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -317,28 +345,38 @@ def query_collection_handler(
     user=Depends(get_current_user),
 ):
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-
-            return query_embeddings_collection(
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            return query_collection(
                 collection_names=form_data.collection_names,
-                query_embeddings=query_embeddings,
+                query=form_data.query,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                embedding_function=app.state.sentence_transformer_ef,
             )
         else:
-            return query_collection(
+
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                query_embeddings = generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{
+                            "model": app.state.RAG_EMBEDDING_MODEL,
+                            "prompt": form_data.query,
+                        }
+                    )
+                )
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                query_embeddings = generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=form_data.query,
+                    key=app.state.RAG_OPENAI_API_KEY,
+                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                )
+
+            return query_embeddings_collection(
                 collection_names=form_data.collection_names,
-                query=form_data.query,
+                query_embeddings=query_embeddings,
                 k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
             )
+
     except Exception as e:
         log.exception(e)
         raise HTTPException(
@@ -383,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
-        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
+        log.info(f"store_data_in_vector_db {docs}")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -402,7 +440,7 @@ def store_text_in_vector_db(
 
 
 def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
-    log.info("store_docs_in_vector_db", docs, collection_name)
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
@@ -414,39 +452,54 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                     log.info(f"deleting existing collection {collection_name}")
                     CHROMA_CLIENT.delete_collection(name=collection_name)
 
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+
+            collection = CHROMA_CLIENT.create_collection(
+                name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
+            )
 
             for batch in create_batches(
                 api=CHROMA_CLIENT,
                 ids=[str(uuid.uuid1()) for _ in texts],
                 metadatas=metadatas,
-                embeddings=[
+                documents=texts,
+            ):
+                collection.add(*batch)
+
+        else:
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                embeddings = [
                     generate_ollama_embeddings(
                         GenerateEmbeddingsForm(
-                            **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                         )
                     )
                     for text in texts
-                ],
-            ):
-                collection.add(*batch)
-        else:
-
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
+                ]
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                embeddings = [
+                    generate_openai_embeddings(
+                        model=app.state.RAG_EMBEDDING_MODEL,
+                        text=text,
+                        key=app.state.RAG_OPENAI_API_KEY,
+                        url=app.state.RAG_OPENAI_API_BASE_URL,
+                    )
+                    for text in texts
+                ]
 
             for batch in create_batches(
                 api=CHROMA_CLIENT,
                 ids=[str(uuid.uuid1()) for _ in texts],
                 metadatas=metadatas,
+                embeddings=embeddings,
                 documents=texts,
             ):
                 collection.add(*batch)
 
-            return True
+        return True
     except Exception as e:
         log.exception(e)
         if e.__class__.__name__ == "UniqueConstraintError":

+ 93 - 18
backend/apps/rag/utils.py

@@ -6,9 +6,12 @@ import requests
 
 
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
@@ -40,6 +43,8 @@ def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
             query_embeddings=[query_embeddings],
             n_results=k,
         )
+
+        log.info(f"query_embeddings_doc:result {result}")
         return result
     except Exception as e:
         raise e
@@ -118,7 +123,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
@@ -141,8 +146,20 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def rag_messages(docs, messages, template, k, embedding_function):
-    log.debug(f"docs: {docs}")
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
+    log.debug(
+        f"docs: {docs} {messages} {embedding_engine} {embedding_model} {embedding_function} {openai_key} {openai_url}"
+    )
 
     last_user_message_idx = None
     for i in range(len(messages) - 1, -1, -1):
@@ -175,22 +192,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
         context = None
 
         try:
-            if doc["type"] == "collection":
-                context = query_collection(
-                    collection_names=doc["collection_names"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
-            elif doc["type"] == "text":
+
+            if doc["type"] == "text":
                 context = doc["content"]
             else:
-                context = query_doc(
-                    collection_name=doc["collection_name"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
+                if embedding_engine == "":
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+
+                else:
+                    if embedding_engine == "ollama":
+                        query_embeddings = generate_ollama_embeddings(
+                            GenerateEmbeddingsForm(
+                                **{
+                                    "model": embedding_model,
+                                    "prompt": query,
+                                }
+                            )
+                        )
+                    elif embedding_engine == "openai":
+                        query_embeddings = generate_openai_embeddings(
+                            model=embedding_model,
+                            text=query,
+                            key=openai_key,
+                            url=openai_url,
+                        )
+
+                    if doc["type"] == "collection":
+                        context = query_embeddings_collection(
+                            collection_names=doc["collection_names"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+                    else:
+                        context = query_embeddings_doc(
+                            collection_name=doc["collection_name"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+
         except Exception as e:
             log.exception(e)
             context = None
@@ -269,3 +321,26 @@ def get_embedding_model_path(
     except Exception as e:
         log.exception(f"Cannot determine embedding model snapshot path: {e}")
         return embedding_model
+
+
+def generate_openai_embeddings(
+    model: str, text: str, key: str, url: str = "https://api.openai.com"
+):
+    try:
+        r = requests.post(
+            f"{url}/v1/embeddings",
+            headers={
+                "Content-Type": "application/json",
+                "Authorization": f"Bearer {key}",
+            },
+            json={"input": text, "model": model},
+        )
+        r.raise_for_status()
+        data = r.json()
+        if "data" in data:
+            return data["data"][0]["embedding"]
+        else:
+            raise "Something went wrong :/"
+    except Exception as e:
+        print(e)
+        return None

+ 4 - 0
backend/main.py

@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     data["messages"],
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
+                    rag_app.state.RAG_EMBEDDING_ENGINE,
+                    rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.RAG_OPENAI_API_KEY,
+                    rag_app.state.RAG_OPENAI_API_BASE_URL,
                 )
                 del data["docs"]
 

+ 6 - 0
src/lib/apis/rag/index.ts

@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
 	return res;
 };
 
+type OpenAIConfigForm = {
+	key: string;
+	url: string;
+};
+
 type EmbeddingModelUpdateForm = {
+	openai_config?: OpenAIConfigForm;
 	embedding_engine: string;
 	embedding_model: string;
 };

+ 67 - 18
src/lib/components/documents/Settings/General.svelte

@@ -29,6 +29,9 @@
 	let embeddingEngine = '';
 	let embeddingModel = '';
 
+	let openAIKey = '';
+	let openAIUrl = '';
+
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
@@ -50,7 +53,15 @@
 	};
 
 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel === '') {
+		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+		if (embeddingEngine === 'ollama' && embeddingModel === '') {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -59,7 +70,7 @@
 			return;
 		}
 
-		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
+		if (embeddingEngine === 'openai' && embeddingModel === '') {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -68,20 +79,28 @@
 			return;
 		}
 
+		if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
+
 		console.log('Update embedding model attempt:', embeddingModel);
 
 		updateEmbeddingModelLoading = true;
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
-			embedding_model: embeddingModel
+			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai'
+				? {
+						openai_config: {
+							key: openAIKey,
+							url: openAIUrl
+						}
+				  }
+				: {})
 		}).catch(async (error) => {
 			toast.error(error);
-
-			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-			if (embeddingConfig) {
-				embeddingEngine = embeddingConfig.embedding_engine;
-				embeddingModel = embeddingConfig.embedding_model;
-			}
+			await setEmbeddingConfig();
 			return null;
 		});
 		updateEmbeddingModelLoading = false;
@@ -89,7 +108,7 @@
 		if (res) {
 			console.log('embeddingModelUpdateHandler:', res);
 			if (res.status === true) {
-				toast.success($i18n.t('Model {{embedding_model}} update complete!', res), {
+				toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
 					duration: 1000 * 10
 				});
 			}
@@ -107,6 +126,18 @@
 		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};
 
+	const setEmbeddingConfig = async () => {
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+
+			openAIKey = embeddingConfig.openai_config.key;
+			openAIUrl = embeddingConfig.openai_config.url;
+		}
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);
 
@@ -117,12 +148,7 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
-		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-
-		if (embeddingConfig) {
-			embeddingEngine = embeddingConfig.embedding_engine;
-			embeddingModel = embeddingConfig.embedding_model;
-		}
+		await setEmbeddingConfig();
 
 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -146,15 +172,38 @@
 						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 						bind:value={embeddingEngine}
 						placeholder="Select an embedding engine"
-						on:change={() => {
-							embeddingModel = '';
+						on:change={(e) => {
+							if (e.target.value === 'ollama') {
+								embeddingModel = '';
+							} else if (e.target.value === 'openai') {
+								embeddingModel = 'text-embedding-3-small';
+							}
 						}}
 					>
 						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
 						<option value="ollama">{$i18n.t('Ollama')}</option>
+						<option value="openai">{$i18n.t('OpenAI')}</option>
 					</select>
 				</div>
 			</div>
+
+			{#if embeddingEngine === 'openai'}
+				<div class="mt-1 flex gap-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Base URL')}
+						bind:value={openAIUrl}
+						required
+					/>
+
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Key')}
+						bind:value={openAIKey}
+						required
+					/>
+				</div>
+			{/if}
 		</div>
 
 		<div class="space-y-2">