Browse Source

feat: openai embeddings support

Timothy J. Baek 1 year ago
parent
commit
b48e73fa43
2 changed files with 121 additions and 48 deletions
  1. 98 48
      backend/apps/rag/main.py
  2. 23 0
      backend/apps/rag/utils.py

+ 98 - 48
backend/apps/rag/main.py

@@ -53,6 +53,7 @@ from apps.rag.utils import (
     query_collection,
     query_collection,
     query_embeddings_collection,
     query_embeddings_collection,
     get_embedding_model_path,
     get_embedding_model_path,
+    generate_openai_embeddings,
 )
 )
 
 
 from utils.misc import (
 from utils.misc import (
@@ -93,6 +94,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 
 
+app.state.RAG_OPENAI_API_BASE_URL = "https://api.openai.com"
+app.state.RAG_OPENAI_API_KEY = ""
 
 
 app.state.PDF_EXTRACT_IMAGES = False
 app.state.PDF_EXTRACT_IMAGES = False
 
 
@@ -144,10 +147,20 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "status": True,
         "status": True,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+        "openai_config": {
+            "url": app.state.RAG_OPENAI_API_BASE_URL,
+            "key": app.state.RAG_OPENAI_API_KEY,
+        },
     }
     }
 
 
 
 
class OpenAIConfigForm(BaseModel):
    """Connection settings for an OpenAI-compatible embeddings API."""

    # Base URL of the API server, e.g. "https://api.openai.com".
    url: str
    # API key; sent as a Bearer token on embedding requests.
    key: str
+
class EmbeddingModelUpdateForm(BaseModel):
    """Request payload for updating the RAG embedding configuration."""

    # OpenAI connection details; presumably only consulted when
    # embedding_engine is "openai" — confirm against the endpoint handler.
    openai_config: Optional[OpenAIConfigForm] = None
    # "" (local sentence-transformers), "ollama", or "openai".
    embedding_engine: str
    embedding_model: str
 
 
@@ -156,17 +169,19 @@ class EmbeddingModelUpdateForm(BaseModel):
 async def update_embedding_config(
 async def update_embedding_config(
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
     form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
 ):
 ):
-
     log.info(
     log.info(
         f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
         f"Updating embedding model: {app.state.RAG_EMBEDDING_MODEL} to {form_data.embedding_model}"
     )
     )
-
     try:
     try:
         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
         app.state.RAG_EMBEDDING_ENGINE = form_data.embedding_engine
 
 
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+        if app.state.RAG_EMBEDDING_ENGINE in ["ollama", "openai"]:
             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
             app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
             app.state.sentence_transformer_ef = None
             app.state.sentence_transformer_ef = None
+
+            if form_data.openai_config != None:
+                app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
+                app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
         else:
         else:
             sentence_transformer_ef = (
             sentence_transformer_ef = (
                 embedding_functions.SentenceTransformerEmbeddingFunction(
                 embedding_functions.SentenceTransformerEmbeddingFunction(
@@ -183,6 +198,10 @@ async def update_embedding_config(
             "status": True,
             "status": True,
             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+            "openai_config": {
+                "url": app.state.RAG_OPENAI_API_BASE_URL,
+                "key": app.state.RAG_OPENAI_API_KEY,
+            },
         }
         }
 
 
     except Exception as e:
     except Exception as e:
@@ -275,28 +294,37 @@ def query_doc_handler(
 ):
 ):
 
 
     try:
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-
-            return query_embeddings_doc(
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            return query_doc(
                 collection_name=form_data.collection_name,
                 collection_name=form_data.collection_name,
-                query_embeddings=query_embeddings,
+                query=form_data.query,
                 k=form_data.k if form_data.k else app.state.TOP_K,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                embedding_function=app.state.sentence_transformer_ef,
             )
             )
         else:
         else:
-            return query_doc(
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                query_embeddings = generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{
+                            "model": app.state.RAG_EMBEDDING_MODEL,
+                            "prompt": form_data.query,
+                        }
+                    )
+                )
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                query_embeddings = generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=form_data.query,
+                    key=app.state.RAG_OPENAI_API_KEY,
+                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                )
+
+            return query_embeddings_doc(
                 collection_name=form_data.collection_name,
                 collection_name=form_data.collection_name,
-                query=form_data.query,
+                query_embeddings=query_embeddings,
                 k=form_data.k if form_data.k else app.state.TOP_K,
                 k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
             )
             )
+
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
         raise HTTPException(
         raise HTTPException(
@@ -317,28 +345,38 @@ def query_collection_handler(
     user=Depends(get_current_user),
     user=Depends(get_current_user),
 ):
 ):
     try:
     try:
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            query_embeddings = generate_ollama_embeddings(
-                GenerateEmbeddingsForm(
-                    **{
-                        "model": app.state.RAG_EMBEDDING_MODEL,
-                        "prompt": form_data.query,
-                    }
-                )
-            )
-
-            return query_embeddings_collection(
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+            return query_collection(
                 collection_names=form_data.collection_names,
                 collection_names=form_data.collection_names,
-                query_embeddings=query_embeddings,
+                query=form_data.query,
                 k=form_data.k if form_data.k else app.state.TOP_K,
                 k=form_data.k if form_data.k else app.state.TOP_K,
+                embedding_function=app.state.sentence_transformer_ef,
             )
             )
         else:
         else:
-            return query_collection(
+
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                query_embeddings = generate_ollama_embeddings(
+                    GenerateEmbeddingsForm(
+                        **{
+                            "model": app.state.RAG_EMBEDDING_MODEL,
+                            "prompt": form_data.query,
+                        }
+                    )
+                )
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                query_embeddings = generate_openai_embeddings(
+                    model=app.state.RAG_EMBEDDING_MODEL,
+                    text=form_data.query,
+                    key=app.state.RAG_OPENAI_API_KEY,
+                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                )
+
+            return query_embeddings_collection(
                 collection_names=form_data.collection_names,
                 collection_names=form_data.collection_names,
-                query=form_data.query,
+                query_embeddings=query_embeddings,
                 k=form_data.k if form_data.k else app.state.TOP_K,
                 k=form_data.k if form_data.k else app.state.TOP_K,
-                embedding_function=app.state.sentence_transformer_ef,
             )
             )
+
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
         raise HTTPException(
         raise HTTPException(
@@ -414,39 +452,51 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                     log.info(f"deleting existing collection {collection_name}")
                     log.info(f"deleting existing collection {collection_name}")
                     CHROMA_CLIENT.delete_collection(name=collection_name)
                     CHROMA_CLIENT.delete_collection(name=collection_name)
 
 
-        if app.state.RAG_EMBEDDING_ENGINE == "ollama":
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        if app.state.RAG_EMBEDDING_ENGINE == "":
+
+            collection = CHROMA_CLIENT.create_collection(
+                name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
+            )
 
 
             for batch in create_batches(
             for batch in create_batches(
                 api=CHROMA_CLIENT,
                 api=CHROMA_CLIENT,
                 ids=[str(uuid.uuid1()) for _ in texts],
                 ids=[str(uuid.uuid1()) for _ in texts],
                 metadatas=metadatas,
                 metadatas=metadatas,
-                embeddings=[
+                documents=texts,
+            ):
+                collection.add(*batch)
+
+        else:
+            if app.state.RAG_EMBEDDING_ENGINE == "ollama":
+                embeddings = [
                     generate_ollama_embeddings(
                     generate_ollama_embeddings(
                         GenerateEmbeddingsForm(
                         GenerateEmbeddingsForm(
-                            **{"model": RAG_EMBEDDING_MODEL, "prompt": text}
+                            **{"model": app.state.RAG_EMBEDDING_MODEL, "prompt": text}
                         )
                         )
                     )
                     )
                     for text in texts
                     for text in texts
-                ],
-            ):
-                collection.add(*batch)
-        else:
-
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name,
-                embedding_function=app.state.sentence_transformer_ef,
-            )
+                ]
+            elif app.state.RAG_EMBEDDING_ENGINE == "openai":
+                embeddings = [
+                    generate_openai_embeddings(
+                        model=app.state.RAG_EMBEDDING_MODEL,
+                        text=text,
+                        key=app.state.RAG_OPENAI_API_KEY,
+                        url=app.state.RAG_OPENAI_API_BASE_URL,
+                    )
+                    for text in texts
+                ]
 
 
             for batch in create_batches(
             for batch in create_batches(
                 api=CHROMA_CLIENT,
                 api=CHROMA_CLIENT,
                 ids=[str(uuid.uuid1()) for _ in texts],
                 ids=[str(uuid.uuid1()) for _ in texts],
                 metadatas=metadatas,
                 metadatas=metadatas,
-                documents=texts,
+                embeddings=embeddings,
             ):
             ):
                 collection.add(*batch)
                 collection.add(*batch)
 
 
-            return True
+        return True
     except Exception as e:
     except Exception as e:
         log.exception(e)
         log.exception(e)
         if e.__class__.__name__ == "UniqueConstraintError":
         if e.__class__.__name__ == "UniqueConstraintError":

+ 23 - 0
backend/apps/rag/utils.py

@@ -269,3 +269,26 @@ def get_embedding_model_path(
     except Exception as e:
     except Exception as e:
         log.exception(f"Cannot determine embedding model snapshot path: {e}")
         log.exception(f"Cannot determine embedding model snapshot path: {e}")
         return embedding_model
         return embedding_model
+
+
def generate_openai_embeddings(
    model: str, text: str, key: str, url: str = "https://api.openai.com"
):
    """Fetch an embedding vector for *text* from an OpenAI-compatible API.

    Args:
        model: Name of the embedding model to request.
        text: Input string to embed.
        key: API key, sent as a Bearer token.
        url: Base URL of the API server; "/v1/embeddings" is appended.

    Returns:
        The embedding (list of floats) from the first result, or ``None``
        when the request fails — callers are expected to handle ``None``.
    """
    try:
        r = requests.post(
            f"{url}/v1/embeddings",
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {key}",
            },
            json={"input": text, "model": model},
            # Bound the request so a stalled server cannot hang the caller.
            timeout=30,
        )
        r.raise_for_status()
        data = r.json()
        if "data" in data:
            return data["data"][0]["embedding"]
        # Raising a bare string is a TypeError in Python 3; raise a real
        # exception so the handler below reports the malformed response.
        raise ValueError(f"Unexpected embeddings response from {url}")
    except Exception as e:
        # Log with traceback (module convention) instead of print().
        log.exception(e)
        return None