Timothy J. Baek 1 рік тому
батько
коміт
713934edb6
2 змінених файлів з 46 додано та 14 видалено
  1. 32 0
      backend/apps/audio/main.py
  2. 14 14
      backend/apps/rag/main.py

+ 32 - 0
backend/apps/audio/main.py

@@ -15,6 +15,8 @@ from fastapi.responses import StreamingResponse, JSONResponse, FileResponse
 
 from fastapi.middleware.cors import CORSMiddleware
 from faster_whisper import WhisperModel
+from pydantic import BaseModel
+
 
 import requests
 import hashlib
@@ -67,6 +69,36 @@ SPEECH_CACHE_DIR = Path(CACHE_DIR).joinpath("./audio/speech/")
 SPEECH_CACHE_DIR.mkdir(parents=True, exist_ok=True)
 
 
+class OpenAIConfigUpdateForm(BaseModel):
+    url: str
+    key: str
+
+
+@app.get("/config")
+async def get_openai_config(user=Depends(get_admin_user)):
+    return {
+        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
+        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+    }
+
+
+@app.post("/config/update")
+async def update_openai_config(
+    form_data: OpenAIConfigUpdateForm, user=Depends(get_admin_user)
+):
+    if form_data.key == "":
+        raise HTTPException(status_code=400, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    app.state.OPENAI_API_BASE_URL = form_data.url
+    app.state.OPENAI_API_KEY = form_data.key
+
+    return {
+        "status": True,
+        "OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL,
+        "OPENAI_API_KEY": app.state.OPENAI_API_KEY,
+    }
+
+
 @app.post("/speech")
 async def speech(request: Request, user=Depends(get_verified_user)):
     idx = None

+ 14 - 14
backend/apps/rag/main.py

@@ -96,8 +96,8 @@ app.state.RAG_EMBEDDING_ENGINE = RAG_EMBEDDING_ENGINE
 app.state.RAG_EMBEDDING_MODEL = RAG_EMBEDDING_MODEL
 app.state.RAG_TEMPLATE = RAG_TEMPLATE
 
-app.state.RAG_OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
-app.state.RAG_OPENAI_API_KEY = RAG_OPENAI_API_KEY
+app.state.OPENAI_API_BASE_URL = RAG_OPENAI_API_BASE_URL
+app.state.OPENAI_API_KEY = RAG_OPENAI_API_KEY
 
 app.state.PDF_EXTRACT_IMAGES = False
 
@@ -150,8 +150,8 @@ async def get_embedding_config(user=Depends(get_admin_user)):
         "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
         "embedding_model": app.state.RAG_EMBEDDING_MODEL,
         "openai_config": {
-            "url": app.state.RAG_OPENAI_API_BASE_URL,
-            "key": app.state.RAG_OPENAI_API_KEY,
+            "url": app.state.OPENAI_API_BASE_URL,
+            "key": app.state.OPENAI_API_KEY,
         },
     }
 
@@ -182,8 +182,8 @@ async def update_embedding_config(
             app.state.sentence_transformer_ef = None
 
             if form_data.openai_config != None:
-                app.state.RAG_OPENAI_API_BASE_URL = form_data.openai_config.url
-                app.state.RAG_OPENAI_API_KEY = form_data.openai_config.key
+                app.state.OPENAI_API_BASE_URL = form_data.openai_config.url
+                app.state.OPENAI_API_KEY = form_data.openai_config.key
         else:
             sentence_transformer_ef = (
                 embedding_functions.SentenceTransformerEmbeddingFunction(
@@ -201,8 +201,8 @@ async def update_embedding_config(
             "embedding_engine": app.state.RAG_EMBEDDING_ENGINE,
             "embedding_model": app.state.RAG_EMBEDDING_MODEL,
             "openai_config": {
-                "url": app.state.RAG_OPENAI_API_BASE_URL,
-                "key": app.state.RAG_OPENAI_API_KEY,
+                "url": app.state.OPENAI_API_BASE_URL,
+                "key": app.state.OPENAI_API_KEY,
             },
         }
 
@@ -317,8 +317,8 @@ def query_doc_handler(
                 query_embeddings = generate_openai_embeddings(
                     model=app.state.RAG_EMBEDDING_MODEL,
                     text=form_data.query,
-                    key=app.state.RAG_OPENAI_API_KEY,
-                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                    key=app.state.OPENAI_API_KEY,
+                    url=app.state.OPENAI_API_BASE_URL,
                 )
 
             return query_embeddings_doc(
@@ -369,8 +369,8 @@ def query_collection_handler(
                 query_embeddings = generate_openai_embeddings(
                     model=app.state.RAG_EMBEDDING_MODEL,
                     text=form_data.query,
-                    key=app.state.RAG_OPENAI_API_KEY,
-                    url=app.state.RAG_OPENAI_API_BASE_URL,
+                    key=app.state.OPENAI_API_KEY,
+                    url=app.state.OPENAI_API_BASE_URL,
                 )
 
             return query_embeddings_collection(
@@ -486,8 +486,8 @@ def store_docs_in_vector_db(docs, collection_name, overwrite: bool = False) -> b
                     generate_openai_embeddings(
                         model=app.state.RAG_EMBEDDING_MODEL,
                         text=text,
-                        key=app.state.RAG_OPENAI_API_KEY,
-                        url=app.state.RAG_OPENAI_API_BASE_URL,
+                        key=app.state.OPENAI_API_KEY,
+                        url=app.state.OPENAI_API_BASE_URL,
                     )
                     for text in texts
                 ]