Browse Source

feat: openai embeddings integration

Timothy J. Baek 1 năm trước cách đây
mục cha
commit
b1b72441bb

+ 2 - 2
backend/apps/ollama/main.py

@@ -659,7 +659,7 @@ def generate_ollama_embeddings(
     url_idx: Optional[int] = None,
 ):
 
-    log.info("generate_ollama_embeddings", form_data)
+    log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx == None:
         model = form_data.model
@@ -688,7 +688,7 @@ def generate_ollama_embeddings(
 
         data = r.json()
 
-        log.info("generate_ollama_embeddings", data)
+        log.info(f"generate_ollama_embeddings {data}")
 
         if "embedding" in data:
             return data["embedding"]

+ 4 - 2
backend/apps/rag/main.py

@@ -421,7 +421,7 @@ def store_data_in_vector_db(data, collection_name, overwrite: bool = False) -> b
     docs = text_splitter.split_documents(data)
 
     if len(docs) > 0:
-        log.info("store_data_in_vector_db", "store_docs_in_vector_db")
+        log.info(f"store_data_in_vector_db {docs}")
         return store_docs_in_vector_db(docs, collection_name, overwrite), None
     else:
         raise ValueError(ERROR_MESSAGES.EMPTY_CONTENT)
@@ -440,7 +440,7 @@ def store_text_in_vector_db(
 
 
 def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> bool:
-    log.info("store_docs_in_vector_db", docs, collection_name)
+    log.info(f"store_docs_in_vector_db {docs} {collection_name}")
 
     texts = [doc.page_content for doc in docs]
     metadatas = [doc.metadata for doc in docs]
@@ -468,6 +468,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                 collection.add(*batch)
 
         else:
+            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+
             if app.state.RAG_EMBEDDING_ENGINE == "ollama":
                 embeddings = [
                     generate_ollama_embeddings(

+ 65 - 17
backend/apps/rag/utils.py

@@ -6,9 +6,12 @@ import requests
 
 
 from huggingface_hub import snapshot_download
+from apps.ollama.main import generate_ollama_embeddings, GenerateEmbeddingsForm
+
 
 from config import SRC_LOG_LEVELS, CHROMA_CLIENT
 
+
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["RAG"])
 
@@ -32,7 +35,7 @@ def query_doc(collection_name: str, query: str, k: int, embedding_function):
 def query_embeddings_doc(collection_name: str, query_embeddings, k: int):
     try:
         # if you use docker use the model from the environment variable
-        log.info("query_embeddings_doc", query_embeddings)
+        log.info(f"query_embeddings_doc {query_embeddings}")
         collection = CHROMA_CLIENT.get_collection(
             name=collection_name,
         )
@@ -118,7 +121,7 @@ def query_collection(
 def query_embeddings_collection(collection_names: List[str], query_embeddings, k: int):
 
     results = []
-    log.info("query_embeddings_collection", query_embeddings)
+    log.info(f"query_embeddings_collection {query_embeddings}")
 
     for collection_name in collection_names:
         try:
@@ -141,7 +144,17 @@ def rag_template(template: str, context: str, query: str):
     return template
 
 
-def rag_messages(docs, messages, template, k, embedding_function):
+def rag_messages(
+    docs,
+    messages,
+    template,
+    k,
+    embedding_engine,
+    embedding_model,
+    embedding_function,
+    openai_key,
+    openai_url,
+):
     log.debug(f"docs: {docs}")
 
     last_user_message_idx = None
@@ -175,22 +188,57 @@ def rag_messages(docs, messages, template, k, embedding_function):
         context = None
 
         try:
-            if doc["type"] == "collection":
-                context = query_collection(
-                    collection_names=doc["collection_names"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
-            elif doc["type"] == "text":
+
+            if doc["type"] == "text":
                 context = doc["content"]
             else:
-                context = query_doc(
-                    collection_name=doc["collection_name"],
-                    query=query,
-                    k=k,
-                    embedding_function=embedding_function,
-                )
+                if embedding_engine == "":
+                    if doc["type"] == "collection":
+                        context = query_collection(
+                            collection_names=doc["collection_names"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+                    else:
+                        context = query_doc(
+                            collection_name=doc["collection_name"],
+                            query=query,
+                            k=k,
+                            embedding_function=embedding_function,
+                        )
+
+                else:
+                    if embedding_engine == "ollama":
+                        query_embeddings = generate_ollama_embeddings(
+                            GenerateEmbeddingsForm(
+                                **{
+                                    "model": embedding_model,
+                                    "prompt": query,
+                                }
+                            )
+                        )
+                    elif embedding_engine == "openai":
+                        query_embeddings = generate_openai_embeddings(
+                            model=embedding_model,
+                            text=query,
+                            key=openai_key,
+                            url=openai_url,
+                        )
+
+                    if doc["type"] == "collection":
+                        context = query_embeddings_collection(
+                            collection_names=doc["collection_names"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+                    else:
+                        context = query_embeddings_doc(
+                            collection_name=doc["collection_name"],
+                            query_embeddings=query_embeddings,
+                            k=k,
+                        )
+
         except Exception as e:
             log.exception(e)
             context = None

+ 4 - 0
backend/main.py

@@ -114,7 +114,11 @@ class RAGMiddleware(BaseHTTPMiddleware):
                     data["messages"],
                     rag_app.state.RAG_TEMPLATE,
                     rag_app.state.TOP_K,
+                    rag_app.state.RAG_EMBEDDING_ENGINE,
+                    rag_app.state.RAG_EMBEDDING_MODEL,
                     rag_app.state.sentence_transformer_ef,
+                    rag_app.state.RAG_OPENAI_API_KEY,
+                    rag_app.state.RAG_OPENAI_API_BASE_URL,
                 )
                 del data["docs"]
 

+ 6 - 0
src/lib/apis/rag/index.ts

@@ -373,7 +373,13 @@ export const getEmbeddingConfig = async (token: string) => {
 	return res;
 };
 
+type OpenAIConfigForm = {
+	key: string;
+	url: string;
+};
+
 type EmbeddingModelUpdateForm = {
+	openai_config?: OpenAIConfigForm;
 	embedding_engine: string;
 	embedding_model: string;
 };

+ 67 - 18
src/lib/components/documents/Settings/General.svelte

@@ -29,6 +29,9 @@
 	let embeddingEngine = '';
 	let embeddingModel = '';
 
+	let openAIKey = '';
+	let openAIUrl = '';
+
 	let chunkSize = 0;
 	let chunkOverlap = 0;
 	let pdfExtractImages = true;
@@ -50,7 +53,15 @@
 	};
 
 	const embeddingModelUpdateHandler = async () => {
-		if (embeddingModel === '') {
+		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
+			toast.error(
+				$i18n.t(
+					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
+				)
+			);
+			return;
+		}
+		if (embeddingEngine === 'ollama' && embeddingModel === '') {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -59,7 +70,7 @@
 			return;
 		}
 
-		if (embeddingEngine === '' && embeddingModel.split('/').length - 1 > 1) {
+		if (embeddingEngine === 'openai' && embeddingModel === '') {
 			toast.error(
 				$i18n.t(
 					'Model filesystem path detected. Model shortname is required for update, cannot continue.'
@@ -68,20 +79,28 @@
 			return;
 		}
 
+		if ((embeddingEngine === 'openai' && openAIKey === '') || openAIUrl === '') {
+			toast.error($i18n.t('OpenAI URL/Key required.'));
+			return;
+		}
+
 		console.log('Update embedding model attempt:', embeddingModel);
 
 		updateEmbeddingModelLoading = true;
 		const res = await updateEmbeddingConfig(localStorage.token, {
 			embedding_engine: embeddingEngine,
-			embedding_model: embeddingModel
+			embedding_model: embeddingModel,
+			...(embeddingEngine === 'openai'
+				? {
+						openai_config: {
+							key: openAIKey,
+							url: openAIUrl
+						}
+				  }
+				: {})
 		}).catch(async (error) => {
 			toast.error(error);
-
-			const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-			if (embeddingConfig) {
-				embeddingEngine = embeddingConfig.embedding_engine;
-				embeddingModel = embeddingConfig.embedding_model;
-			}
+			await setEmbeddingConfig();
 			return null;
 		});
 		updateEmbeddingModelLoading = false;
@@ -89,7 +108,7 @@
 		if (res) {
 			console.log('embeddingModelUpdateHandler:', res);
 			if (res.status === true) {
-				toast.success($i18n.t('Model {{embedding_model}} update complete!', res), {
+				toast.success($i18n.t('Embedding model set to "{{embedding_model}}"', res), {
 					duration: 1000 * 10
 				});
 			}
@@ -107,6 +126,18 @@
 		querySettings = await updateQuerySettings(localStorage.token, querySettings);
 	};
 
+	const setEmbeddingConfig = async () => {
+		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
+
+		if (embeddingConfig) {
+			embeddingEngine = embeddingConfig.embedding_engine;
+			embeddingModel = embeddingConfig.embedding_model;
+
+			openAIKey = embeddingConfig.openai_config.key;
+			openAIUrl = embeddingConfig.openai_config.url;
+		}
+	};
+
 	onMount(async () => {
 		const res = await getRAGConfig(localStorage.token);
 
@@ -117,12 +148,7 @@
 			chunkOverlap = res.chunk.chunk_overlap;
 		}
 
-		const embeddingConfig = await getEmbeddingConfig(localStorage.token);
-
-		if (embeddingConfig) {
-			embeddingEngine = embeddingConfig.embedding_engine;
-			embeddingModel = embeddingConfig.embedding_model;
-		}
+		await setEmbeddingConfig();
 
 		querySettings = await getQuerySettings(localStorage.token);
 	});
@@ -146,15 +172,38 @@
 						class="dark:bg-gray-900 w-fit pr-8 rounded px-2 p-1 text-xs bg-transparent outline-none text-right"
 						bind:value={embeddingEngine}
 						placeholder="Select an embedding engine"
-						on:change={() => {
-							embeddingModel = '';
+						on:change={(e) => {
+							if (e.target.value === 'ollama') {
+								embeddingModel = '';
+							} else if (e.target.value === 'openai') {
+								embeddingModel = 'text-embedding-3-small';
+							}
 						}}
 					>
 						<option value="">{$i18n.t('Default (SentenceTransformer)')}</option>
 						<option value="ollama">{$i18n.t('Ollama')}</option>
+						<option value="openai">{$i18n.t('OpenAI')}</option>
 					</select>
 				</div>
 			</div>
+
+			{#if embeddingEngine === 'openai'}
+				<div class="mt-1 flex gap-2">
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Base URL')}
+						bind:value={openAIUrl}
+						required
+					/>
+
+					<input
+						class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+						placeholder={$i18n.t('API Key')}
+						bind:value={openAIKey}
+						required
+					/>
+				</div>
+			{/if}
 		</div>
 
 		<div class="space-y-2">