Selaa lähdekoodia

feat: chromadb vector store api

Timothy J. Baek 1 vuosi sitten
vanhempi
commit
784b369cc9
4 muutettua tiedostoa jossa 119 lisäystä ja 11 poistoa
  1. 2 1
      backend/.gitignore
  2. 97 3
      backend/apps/rag/main.py
  3. 16 5
      backend/config.py
  4. 4 2
      backend/constants.py

+ 2 - 1
backend/.gitignore

@@ -5,4 +5,5 @@ uploads
 .ipynb_checkpoints
 *.db
 _test
-Pipfile
+Pipfile
+data/*

+ 97 - 3
backend/apps/rag/main.py

@@ -1,9 +1,25 @@
-from fastapi import FastAPI, Request, Depends, HTTPException
+from fastapi import FastAPI, Request, Depends, HTTPException, status, UploadFile, File
 from fastapi.middleware.cors import CORSMiddleware
 
-from apps.web.routers import auths, users, chats, modelfiles, utils
-from config import WEBUI_VERSION, WEBUI_AUTH
+from chromadb.utils import embedding_functions
 
+from langchain.document_loaders import WebBaseLoader, TextLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
+from langchain.chains import RetrievalQA
+
+
+from pydantic import BaseModel
+from typing import Optional
+
+import uuid
+
+from config import EMBED_MODEL, CHROMA_CLIENT, CHUNK_SIZE, CHUNK_OVERLAP
+from constants import ERROR_MESSAGES
+
+EMBEDDING_FUNC = embedding_functions.SentenceTransformerEmbeddingFunction(
+    model_name=EMBED_MODEL  # model name is imported from config
+)  # handed to Chroma as the collection embedding function
 
 app = FastAPI()
 
@@ -18,6 +34,84 @@ app.add_middleware(
 )
 
 
+class StoreWebForm(BaseModel):  # request body of POST /web
+    url: str  # page address handed to WebBaseLoader
+    collection_name: Optional[str] = "test"  # Chroma collection to store into
+
+
+def store_data_in_vector_db(data, collection_name):
+    """Chunk `data` and add the chunks to the named Chroma collection."""
+    docs = RecursiveCharacterTextSplitter(
+        chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP
+    ).split_documents(data)
+
+    # create_collection raises if the name already exists; reuse it instead
+    # so repeated stores into one collection append rather than fail.
+    collection = CHROMA_CLIENT.get_or_create_collection(
+        name=collection_name, embedding_function=EMBEDDING_FUNC
+    )
+    collection.add(
+        documents=[doc.page_content for doc in docs],
+        metadatas=[doc.metadata for doc in docs],
+        ids=[str(uuid.uuid1()) for _ in docs],
+    )
+
+
 @app.get("/")
 async def get_status():
     return {"status": True}
+
+
+@app.get("/query/{collection_name}")
+def query_collection(collection_name: str, query: str, k: Optional[int] = 4):
+    """Return Chroma's result for the `k` nearest chunks to `query`."""
+    # Embed the query with the same function used when the chunks were stored.
+    collection = CHROMA_CLIENT.get_collection(
+        name=collection_name, embedding_function=EMBEDDING_FUNC
+    )
+    return collection.query(query_texts=[query], n_results=k)
+
+
+@app.post("/web")
+def store_web(form_data: StoreWebForm):
+    """Load the page at `form_data.url` and index it; 400 on any failure."""
+    # Example url: "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
+    try:
+        pages = WebBaseLoader(form_data.url).load()
+        store_data_in_vector_db(pages, form_data.collection_name)
+    except Exception as err:
+        print(err)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(err),
+        )
+    return {"status": True}
+
+
+@app.post("/doc")
+def store_doc(file: UploadFile = File(...)):
+    """Save the upload as ./data/<uuid4>-<name>; respond 400 on failure."""
+    import os  # local import: only the basename() call below needs it
+
+    try:
+        print(file)
+        # The client controls filename: keep only its last path component so
+        # separators or ".." in it cannot break the write or leave ./data.
+        name = os.path.basename((file.filename or "").replace("\\", "/"))
+        file.filename = f"{uuid.uuid4()}-{name}"
+        with open(f"./data/{file.filename}", "wb") as f:
+            f.write(file.file.read())  # the with-block closes f
+
+        # TODO: load the saved file and index it via store_data_in_vector_db
+        return {"status": True}
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+def reset_vector_db():  # plain helper: no route decorator is attached here
+    CHROMA_CLIENT.reset()  # NOTE(review): presumably wipes every collection — confirm
+    return {"status": True}  # same success payload the endpoints return

+ 16 - 5
backend/config.py

@@ -1,11 +1,11 @@
 from dotenv import load_dotenv, find_dotenv
-
-from constants import ERROR_MESSAGES
+import os
+import chromadb
 
 from secrets import token_bytes
 from base64 import b64encode
 
-import os
+from constants import ERROR_MESSAGES
 
 load_dotenv(find_dotenv("../.env"))
 
@@ -19,8 +19,9 @@ ENV = os.environ.get("ENV", "dev")
 # OLLAMA_API_BASE_URL
 ####################################
 
-OLLAMA_API_BASE_URL = os.environ.get("OLLAMA_API_BASE_URL",
-                                     "http://localhost:11434/api")
+OLLAMA_API_BASE_URL = os.environ.get(
+    "OLLAMA_API_BASE_URL", "http://localhost:11434/api"
+)
 
 if ENV == "prod":
     if OLLAMA_API_BASE_URL == "/ollama/api":
@@ -56,3 +57,13 @@ WEBUI_JWT_SECRET_KEY = os.environ.get("WEBUI_JWT_SECRET_KEY", "t0p-s3cr3t")
 
 if WEBUI_AUTH and WEBUI_JWT_SECRET_KEY == "":
     raise ValueError(ERROR_MESSAGES.ENV_VAR_NOT_FOUND)
+
+####################################
+# RAG
+####################################
+
+CHROMA_DATA_PATH = "./data/vector_db"  # relative path given to the Chroma client
+EMBED_MODEL = "all-MiniLM-L6-v2"  # model_name for the sentence-transformer embedder
+CHROMA_CLIENT = chromadb.PersistentClient(path=CHROMA_DATA_PATH)  # built at import
+CHUNK_SIZE = 1500  # chunk_size for RecursiveCharacterTextSplitter
+CHUNK_OVERLAP = 100  # chunk_overlap for RecursiveCharacterTextSplitter

+ 4 - 2
backend/constants.py

@@ -6,7 +6,6 @@ class MESSAGES(str, Enum):
 
 
 class ERROR_MESSAGES(str, Enum):
-
     def __str__(self) -> str:
         return super().__str__()
 
@@ -30,7 +29,10 @@ class ERROR_MESSAGES(str, Enum):
     UNAUTHORIZED = "401 Unauthorized"
     ACCESS_PROHIBITED = "You do not have permission to access this resource. Please contact your administrator for assistance."
     ACTION_PROHIBITED = (
-        "The requested action has been restricted as a security measure.")
+        "The requested action has been restricted as a security measure."
+    )
+
+    FILE_NOT_SENT = "FILE_NOT_SENT"
     NOT_FOUND = "We could not find what you're looking for :/"
     USER_NOT_FOUND = "We could not find what you're looking for :/"
     API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."