Browse Source

feat: model filter list env var

Timothy J. Baek 1 year ago
parent
commit
a4ca1fc5c4
4 changed files with 41 additions and 16 deletions
  1. 4 4
      backend/apps/ollama/main.py
  2. 10 4
      backend/apps/openai/main.py
  3. 5 0
      backend/config.py
  4. 22 8
      backend/main.py

+ 4 - 4
backend/apps/ollama/main.py

@@ -15,7 +15,7 @@ import asyncio
 from apps.web.models.users import Users
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_BASE_URLS
+from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
 
 
 from typing import Optional, List, Union
 from typing import Optional, List, Union
 
 
@@ -30,8 +30,8 @@ app.add_middleware(
 )
 )
 
 
 
 
-app.state.MODEL_FILTER_ENABLED = False
-app.state.MODEL_LIST = []
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 app.state.MODELS = {}
@@ -140,7 +140,7 @@ async def get_ollama_tags(
             if user.role == "user":
             if user.role == "user":
                 models["models"] = list(
                 models["models"] = list(
                     filter(
                     filter(
-                        lambda model: model["name"] in app.state.MODEL_LIST,
+                        lambda model: model["name"] in app.state.MODEL_FILTER_LIST,
                         models["models"],
                         models["models"],
                     )
                     )
                 )
                 )

+ 10 - 4
backend/apps/openai/main.py

@@ -18,7 +18,13 @@ from utils.utils import (
     get_verified_user,
     get_verified_user,
     get_admin_user,
     get_admin_user,
 )
 )
-from config import OPENAI_API_BASE_URLS, OPENAI_API_KEYS, CACHE_DIR
+from config import (
+    OPENAI_API_BASE_URLS,
+    OPENAI_API_KEYS,
+    CACHE_DIR,
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
 from typing import List, Optional
 from typing import List, Optional
 
 
 
 
@@ -34,8 +40,8 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
-app.state.MODEL_FILTER_ENABLED = False
-app.state.MODEL_LIST = []
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
 app.state.OPENAI_API_KEYS = OPENAI_API_KEYS
@@ -198,7 +204,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_use
             if user.role == "user":
             if user.role == "user":
                 models["data"] = list(
                 models["data"] = list(
                     filter(
                     filter(
-                        lambda model: model["id"] in app.state.MODEL_LIST,
+                        lambda model: model["id"] in app.state.MODEL_FILTER_LIST,
                         models["data"],
                         models["data"],
                     )
                     )
                 )
                 )

+ 5 - 0
backend/config.py

@@ -292,6 +292,11 @@ DEFAULT_USER_ROLE = os.getenv("DEFAULT_USER_ROLE", "pending")
 USER_PERMISSIONS = {"chat": {"deletion": True}}
 USER_PERMISSIONS = {"chat": {"deletion": True}}
 
 
 
 
+MODEL_FILTER_ENABLED = os.environ.get("MODEL_FILTER_ENABLED", False)
+MODEL_FILTER_LIST = os.environ.get("MODEL_FILTER_LIST", "")
+MODEL_FILTER_LIST = [model.strip() for model in MODEL_FILTER_LIST.split(";")]
+
+
 ####################################
 ####################################
 # WEBUI_VERSION
 # WEBUI_VERSION
 ####################################
 ####################################

+ 22 - 8
backend/main.py

@@ -30,7 +30,15 @@ from typing import List
 from utils.utils import get_admin_user
 from utils.utils import get_admin_user
 from apps.rag.utils import query_doc, query_collection, rag_template
 from apps.rag.utils import query_doc, query_collection, rag_template
 
 
-from config import WEBUI_NAME, ENV, VERSION, CHANGELOG, FRONTEND_BUILD_DIR
+from config import (
+    WEBUI_NAME,
+    ENV,
+    VERSION,
+    CHANGELOG,
+    FRONTEND_BUILD_DIR,
+    MODEL_FILTER_ENABLED,
+    MODEL_FILTER_LIST,
+)
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
 
 
 
 
@@ -47,8 +55,8 @@ class SPAStaticFiles(StaticFiles):
 
 
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 app = FastAPI(docs_url="/docs" if ENV == "dev" else None, redoc_url=None)
 
 
-app.state.MODEL_FILTER_ENABLED = False
-app.state.MODEL_LIST = []
+app.state.MODEL_FILTER_ENABLED = MODEL_FILTER_ENABLED
+app.state.MODEL_FILTER_LIST = MODEL_FILTER_LIST
 
 
 origins = ["*"]
 origins = ["*"]
 
 
@@ -222,7 +230,10 @@ async def get_app_config():
 
 
 @app.get("/api/config/model/filter")
 @app.get("/api/config/model/filter")
 async def get_model_filter_config(user=Depends(get_admin_user)):
 async def get_model_filter_config(user=Depends(get_admin_user)):
-    return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
+    return {
+        "enabled": app.state.MODEL_FILTER_ENABLED,
+        "models": app.state.MODEL_FILTER_LIST,
+    }
 
 
 
 
 class ModelFilterConfigForm(BaseModel):
 class ModelFilterConfigForm(BaseModel):
@@ -236,15 +247,18 @@ async def get_model_filter_config(
 ):
 ):
 
 
     app.state.MODEL_FILTER_ENABLED = form_data.enabled
     app.state.MODEL_FILTER_ENABLED = form_data.enabled
-    app.state.MODEL_LIST = form_data.models
+    app.state.MODEL_FILTER_LIST = form_data.models
 
 
     ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
     ollama_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
-    ollama_app.state.MODEL_LIST = app.state.MODEL_LIST
+    ollama_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
 
 
     openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
     openai_app.state.MODEL_FILTER_ENABLED = app.state.MODEL_FILTER_ENABLED
-    openai_app.state.MODEL_LIST = app.state.MODEL_LIST
+    openai_app.state.MODEL_FILTER_LIST = app.state.MODEL_FILTER_LIST
 
 
-    return {"enabled": app.state.MODEL_FILTER_ENABLED, "models": app.state.MODEL_LIST}
+    return {
+        "enabled": app.state.MODEL_FILTER_ENABLED,
+        "models": app.state.MODEL_FILTER_LIST,
+    }
 
 
 
 
 @app.get("/api/version")
 @app.get("/api/version")