Переглянути джерело

feat: dynamic embedding model load

Timothy J. Baek 1 рік тому
батько
коміт
7c127c35fc
1 змінений файл: 56 додано та 36 видалено
  1. 56 36
      backend/apps/rag/main.py

+ 56 - 36
backend/apps/rag/main.py

@@ -35,6 +35,8 @@ from pydantic import BaseModel
 from typing import Optional
 import mimetypes
 import uuid
+import json
+
 
 from apps.web.models.documents import (
     Documents,
@@ -63,24 +65,26 @@ from config import (
 from constants import ERROR_MESSAGES
 
 #
-#if RAG_EMBEDDING_MODEL:
+# if RAG_EMBEDDING_MODEL:
 #    sentence_transformer_ef = SentenceTransformer(
 #        model_name_or_path=RAG_EMBEDDING_MODEL,
 #        cache_folder=RAG_EMBEDDING_MODEL_DIR,
 #        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
 #    )
 
-if RAG_EMBEDDING_MODEL:
-    sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
-        model_name=RAG_EMBEDDING_MODEL,
-        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
-    )
 
 app = FastAPI()
 
 app.state.CHUNK_SIZE = CHUNK_SIZE
 app.state.CHUNK_OVERLAP = CHUNK_OVERLAP
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
+app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
+app.state.sentence_transformer_ef = (
+    embedding_functions.SentenceTransformerEmbeddingFunction(
+        model_name=app.state.RAG_EMBEDDING_MODEL,
+        device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+    )
+)
 
 
 origins = ["*"]
@@ -112,14 +116,10 @@ def store_data_in_vector_db(data, collection_name) -> bool:
     metadatas = [doc.metadata for doc in docs]
 
     try:
-        if RAG_EMBEDDING_MODEL:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.create_collection(
-                name=collection_name, embedding_function=sentence_transformer_ef
-            )
-        else:
-            # for local development use the default model
-            collection = CHROMA_CLIENT.create_collection(name=collection_name)
+        collection = CHROMA_CLIENT.create_collection(
+            name=collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
 
         collection.add(
             documents=texts, metadatas=metadatas, ids=[str(uuid.uuid1()) for _ in texts]
@@ -139,6 +139,38 @@ async def get_status():
         "status": True,
         "chunk_size": app.state.CHUNK_SIZE,
         "chunk_overlap": app.state.CHUNK_OVERLAP,
+        "template": app.state.RAG_TEMPLATE,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+@app.get("/embedding/model")
+async def get_embedding_model(user=Depends(get_admin_user)):
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
+    }
+
+
+class EmbeddingModelUpdateForm(BaseModel):
+    embedding_model: str
+
+
+@app.post("/embedding/model/update")
+async def update_embedding_model(
+    form_data: EmbeddingModelUpdateForm, user=Depends(get_admin_user)
+):
+    app.state.RAG_EMBEDDING_MODEL = form_data.embedding_model
+    app.state.sentence_transformer_ef = (
+        embedding_functions.SentenceTransformerEmbeddingFunction(
+            model_name=app.state.RAG_EMBEDDING_MODEL,
+            device=RAG_EMBEDDING_MODEL_DEVICE_TYPE,
+        )
+    )
+
+    return {
+        "status": True,
+        "embedding_model": app.state.RAG_EMBEDDING_MODEL,
     }
 
 
@@ -203,17 +235,11 @@ def query_doc(
     user=Depends(get_current_user),
 ):
     try:
-        if RAG_EMBEDDING_MODEL:
-            # if you use docker use the model from the environment variable
-            collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-                embedding_function=sentence_transformer_ef,
-            )
-        else:
-            # for local development use the default model
-            collection = CHROMA_CLIENT.get_collection(
-                name=form_data.collection_name,
-            )
+        # use the embedding function currently held in app.state (initialised from RAG_EMBEDDING_MODEL, replaceable at runtime)
+        collection = CHROMA_CLIENT.get_collection(
+            name=form_data.collection_name,
+            embedding_function=app.state.sentence_transformer_ef,
+        )
         result = collection.query(query_texts=[form_data.query], n_results=form_data.k)
         return result
     except Exception as e:
@@ -284,17 +310,11 @@ def query_collection(
 
     for collection_name in form_data.collection_names:
         try:
-            if RAG_EMBEDDING_MODEL:
-                # if you use docker use the model from the environment variable
-                collection = CHROMA_CLIENT.get_collection(
-                    name=collection_name,
-                    embedding_function=sentence_transformer_ef,
-                )
-            else:
-                # for local development use the default model
-                collection = CHROMA_CLIENT.get_collection(
-                    name=collection_name,
-                )
+            # use the embedding function currently held in app.state (initialised from RAG_EMBEDDING_MODEL, replaceable at runtime)
+            collection = CHROMA_CLIENT.get_collection(
+                name=collection_name,
+                embedding_function=app.state.sentence_transformer_ef,
+            )
 
             result = collection.query(
                 query_texts=[form_data.query], n_results=form_data.k