Timothy Jaeryang Baek 5 meses atrás
pai
commit
932de8f1e2

+ 46 - 37
backend/open_webui/apps/ollama/main.py

@@ -68,25 +68,12 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
 app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.config.OLLAMA_API_CONFIGS = OLLAMA_API_CONFIGS
 
-app.state.MODELS = {}
-
 
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # least connections, or least response time for better resource utilization and performance optimization.
 
 
-@app.middleware("http")
-async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-    else:
-        pass
-
-    response = await call_next(request)
-    return response
-
-
 @app.head("/")
 @app.get("/")
 async def get_status():
@@ -321,8 +308,6 @@ async def get_all_models():
     else:
         models = {"models": []}
 
-    app.state.MODELS = {model["model"]: model for model in models["models"]}
-
     return models
 
 
@@ -470,8 +455,11 @@ async def push_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        if form_data.name in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.name]["urls"][0]
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if form_data.name in models:
+            url_idx = models[form_data.name]["urls"][0]
         else:
             raise HTTPException(
                 status_code=400,
@@ -520,8 +508,11 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        if form_data.source in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.source]["urls"][0]
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if form_data.source in models:
+            url_idx = models[form_data.source]["urls"][0]
         else:
             raise HTTPException(
                 status_code=400,
@@ -576,8 +567,11 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        if form_data.name in app.state.MODELS:
-            url_idx = app.state.MODELS[form_data.name]["urls"][0]
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if form_data.name in models:
+            url_idx = models[form_data.name]["urls"][0]
         else:
             raise HTTPException(
                 status_code=400,
@@ -625,13 +619,16 @@ async def delete_model(
 
 @app.post("/api/show")
 async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_user)):
-    if form_data.name not in app.state.MODELS:
+    model_list = await get_all_models()
+    models = {model["model"]: model for model in model_list["models"]}
+
+    if form_data.name not in models:
         raise HTTPException(
             status_code=400,
             detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.name),
         )
 
-    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
+    url_idx = random.choice(models[form_data.name]["urls"])
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
@@ -701,23 +698,26 @@ async def generate_embeddings(
     url_idx: Optional[int] = None,
     user=Depends(get_verified_user),
 ):
-    return generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
+    return await generate_ollama_embeddings(form_data=form_data, url_idx=url_idx)
 
 
-def generate_ollama_embeddings(
+async def generate_ollama_embeddings(
     form_data: GenerateEmbeddingsForm,
     url_idx: Optional[int] = None,
 ):
     log.info(f"generate_ollama_embeddings {form_data}")
 
     if url_idx is None:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
         model = form_data.model
 
         if ":" not in model:
             model = f"{model}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if model in models:
+            url_idx = random.choice(models[model]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -768,20 +768,23 @@ def generate_ollama_embeddings(
         )
 
 
-def generate_ollama_batch_embeddings(
+async def generate_ollama_batch_embeddings(
     form_data: GenerateEmbedForm,
     url_idx: Optional[int] = None,
 ):
     log.info(f"generate_ollama_batch_embeddings {form_data}")
 
     if url_idx is None:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
         model = form_data.model
 
         if ":" not in model:
             model = f"{model}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if model in models:
+            url_idx = random.choice(models[model]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -851,13 +854,16 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     if url_idx is None:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
         model = form_data.model
 
         if ":" not in model:
             model = f"{model}:latest"
 
-        if model in app.state.MODELS:
-            url_idx = random.choice(app.state.MODELS[model]["urls"])
+        if model in models:
+            url_idx = random.choice(models[model]["urls"])
         else:
             raise HTTPException(
                 status_code=400,
@@ -892,14 +898,17 @@ class GenerateChatCompletionForm(BaseModel):
     keep_alive: Optional[Union[int, str]] = None
 
 
-def get_ollama_url(url_idx: Optional[int], model: str):
+async def get_ollama_url(url_idx: Optional[int], model: str):
     if url_idx is None:
-        if model not in app.state.MODELS:
+        model_list = await get_all_models()
+        models = {model["model"]: model for model in model_list["models"]}
+
+        if model not in models:
             raise HTTPException(
                 status_code=400,
                 detail=ERROR_MESSAGES.MODEL_NOT_FOUND(model),
             )
-        url_idx = random.choice(app.state.MODELS[model]["urls"])
+        url_idx = random.choice(models[model]["urls"])
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     return url
 
@@ -948,7 +957,7 @@ async def generate_chat_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-    url = get_ollama_url(url_idx, payload["model"])
+    url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 
@@ -1030,7 +1039,7 @@ async def generate_openai_chat_completion(
     if ":" not in payload["model"]:
         payload["model"] = f"{payload['model']}:latest"
 
-    url = get_ollama_url(url_idx, payload["model"])
+    url = await get_ollama_url(url_idx, payload["model"])
     log.info(f"url: {url}")
 
     api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})

+ 57 - 58
backend/open_webui/apps/openai/main.py

@@ -36,7 +36,7 @@ from open_webui.utils.payload import (
     apply_model_system_prompt_to_body,
 )
 
-from open_webui.utils.utils import get_admin_user, get_verified_user
+from open_webui.utils.utils import get_admin_user, get_verified_user, has_access
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OPENAI"])
@@ -64,17 +64,6 @@ app.state.config.OPENAI_API_BASE_URLS = OPENAI_API_BASE_URLS
 app.state.config.OPENAI_API_KEYS = OPENAI_API_KEYS
 app.state.config.OPENAI_API_CONFIGS = OPENAI_API_CONFIGS
 
-app.state.MODELS = {}
-
-
-@app.middleware("http")
-async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-
-    response = await call_next(request)
-    return response
-
 
 @app.get("/config")
 async def get_config(user=Depends(get_admin_user)):
@@ -259,7 +248,7 @@ def merge_models_lists(model_lists):
     return merged_list
 
 
-async def get_all_models_raw() -> list:
+async def get_all_models_responses() -> list:
     if not app.state.config.ENABLE_OPENAI_API:
         return []
 
@@ -330,22 +319,13 @@ async def get_all_models_raw() -> list:
     return responses
 
 
-@overload
-async def get_all_models(raw: Literal[True]) -> list: ...
-
-
-@overload
-async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
-
-
-async def get_all_models(raw=False) -> dict[str, list] | list:
+async def get_all_models() -> dict[str, list]:
     log.info("get_all_models()")
+
     if not app.state.config.ENABLE_OPENAI_API:
-        return [] if raw else {"data": []}
+        return {"data": []}
 
-    responses = await get_all_models_raw()
-    if raw:
-        return responses
+    responses = await get_all_models_responses()
 
     def extract_data(response):
         if response and "data" in response:
@@ -355,9 +335,7 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
         return None
 
     models = {"data": merge_models_lists(map(extract_data, responses))}
-
     log.debug(f"models: {models}")
-    app.state.MODELS = {model["id"]: model for model in models["data"]}
 
     return models
 
@@ -365,21 +343,12 @@ async def get_all_models(raw=False) -> dict[str, list] | list:
 @app.get("/models")
 @app.get("/models/{url_idx}")
 async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
+    models = {
+        "data": [],
+    }
+
     if url_idx is None:
         models = await get_all_models()
-
-        # TODO: Check User Group and Filter Models
-        # if app.state.config.ENABLE_MODEL_FILTER:
-        #     if user.role == "user":
-        #         models["data"] = list(
-        #             filter(
-        #                 lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
-        #                 models["data"],
-        #             )
-        #         )
-        #         return models
-
-        return models
     else:
         url = app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = app.state.config.OPENAI_API_KEYS[url_idx]
@@ -387,6 +356,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
         headers = {}
         headers["Authorization"] = f"Bearer {key}"
         headers["Content-Type"] = "application/json"
+
         if ENABLE_FORWARD_USER_INFO_HEADERS:
             headers["X-OpenWebUI-User-Name"] = user.name
             headers["X-OpenWebUI-User-Id"] = user.id
@@ -428,8 +398,7 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
                             )
                         ]
 
-                    return response_data
-
+                    models = response_data
             except aiohttp.ClientError as e:
                 # ClientError covers all aiohttp requests issues
                 log.exception(f"Client error: {str(e)}")
@@ -443,6 +412,22 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
                 error_detail = f"Unexpected error: {str(e)}"
                 raise HTTPException(status_code=500, detail=error_detail)
 
+    if user.role == "user":
+        # Filter models based on user access control
+        filtered_models = []
+        for model in models.get("data", []):
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+            else:
+                filtered_models.append(model)
+        models["data"] = filtered_models
+
+    return models
+
 
 class ConnectionVerificationForm(BaseModel):
     url: str
@@ -502,18 +487,9 @@ async def generate_chat_completion(
         del payload["metadata"]
 
     model_id = form_data.get("model")
-
-    # TODO: Check User Group and Filter Models
-    # if not bypass_filter:
-    #     if app.state.config.ENABLE_MODEL_FILTER:
-    #         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-    #             raise HTTPException(
-    #                 status_code=403,
-    #                 detail="Model not found",
-    #             )
-
     model_info = Models.get_model_by_id(model_id)
 
+    # Check model info and override the payload
     if model_info:
         if model_info.base_model_id:
             payload["model"] = model_info.base_model_id
@@ -522,12 +498,33 @@ async def generate_chat_completion(
         payload = apply_model_params_to_body_openai(params, payload)
         payload = apply_model_system_prompt_to_body(params, payload, user)
 
-    try:
-        model = app.state.MODELS[payload.get("model")]
+        # Check if user has access to the model
+        if user.role == "user" and not has_access(
+            user.id, type="read", access_control=model_info.access_control
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
+
+    # Attemp to get urlIdx from the model
+    models = await get_all_models()
+
+    # Find the model from the list
+    model = next(
+        (model for model in models["data"] if model["id"] == payload.get("model")),
+        None,
+    )
+
+    if model:
         idx = model["urlIdx"]
-    except Exception as e:
-        raise HTTPException(status_code=404, detail="Model not found")
+    else:
+        raise HTTPException(
+            status_code=404,
+            detail="Model not found",
+        )
 
+    # Get the API config for the model
     api_config = app.state.config.OPENAI_API_CONFIGS.get(
         app.state.config.OPENAI_API_BASE_URLS[idx], {}
     )
@@ -536,6 +533,7 @@ async def generate_chat_completion(
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
 
+    # Add user info to the payload if the model is a pipeline
     if "pipeline" in model and model.get("pipeline"):
         payload["user"] = {
             "name": user.name,
@@ -546,8 +544,9 @@ async def generate_chat_completion(
 
     url = app.state.config.OPENAI_API_BASE_URLS[idx]
     key = app.state.config.OPENAI_API_KEYS[idx]
-    is_o1 = payload["model"].lower().startswith("o1-")
 
+    # Fix: O1 does not support the "max_tokens" parameter, Modify "max_tokens" to "max_completion_tokens"
+    is_o1 = payload["model"].lower().startswith("o1-")
     # Change max_completion_tokens to max_tokens (Backward compatible)
     if "api.openai.com" not in url and not is_o1:
         if "max_completion_tokens" in payload:

+ 16 - 7
backend/open_webui/apps/retrieval/utils.py

@@ -3,6 +3,7 @@ import os
 import uuid
 from typing import Optional, Union
 
+import asyncio
 import requests
 
 from huggingface_hub import snapshot_download
@@ -291,7 +292,13 @@ def get_embedding_function(
     if embedding_engine == "":
         return lambda query: embedding_function.encode(query).tolist()
     elif embedding_engine in ["ollama", "openai"]:
-        func = lambda query: generate_embeddings(
+
+        # Wrapper to run the async generate_embeddings synchronously.
+        def sync_generate_embeddings(*args, **kwargs):
+            return asyncio.run(generate_embeddings(*args, **kwargs))
+
+        # Semantic expectation from the original version (using sync wrapper).
+        func = lambda query: sync_generate_embeddings(
             engine=embedding_engine,
             model=embedding_model,
             text=query,
@@ -469,7 +476,7 @@ def get_model_path(model: str, update_model: bool = False):
         return model
 
 
-def generate_openai_batch_embeddings(
+async def generate_openai_batch_embeddings(
     model: str, texts: list[str], key: str, url: str = "https://api.openai.com/v1"
 ) -> Optional[list[list[float]]]:
     try:
@@ -492,14 +499,16 @@ def generate_openai_batch_embeddings(
         return None
 
 
-def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **kwargs):
+async def generate_embeddings(
+    engine: str, model: str, text: Union[str, list[str]], **kwargs
+):
     if engine == "ollama":
         if isinstance(text, list):
-            embeddings = generate_ollama_batch_embeddings(
+            embeddings = await generate_ollama_batch_embeddings(
                 GenerateEmbedForm(**{"model": model, "input": text})
             )
         else:
-            embeddings = generate_ollama_batch_embeddings(
+            embeddings = await generate_ollama_batch_embeddings(
                 GenerateEmbedForm(**{"model": model, "input": [text]})
             )
         return (
@@ -512,9 +521,9 @@ def generate_embeddings(engine: str, model: str, text: Union[str, list[str]], **
         url = kwargs.get("url", "https://api.openai.com/v1")
 
         if isinstance(text, list):
-            embeddings = generate_openai_batch_embeddings(model, text, key, url)
+            embeddings = await generate_openai_batch_embeddings(model, text, key, url)
         else:
-            embeddings = generate_openai_batch_embeddings(model, [text], key, url)
+            embeddings = await generate_openai_batch_embeddings(model, [text], key, url)
 
         return embeddings[0] if isinstance(text, str) else embeddings
 

+ 2 - 3
backend/open_webui/apps/webui/main.py

@@ -142,7 +142,6 @@ app.state.config.LDAP_USE_TLS = LDAP_USE_TLS
 app.state.config.LDAP_CA_CERT_FILE = LDAP_CA_CERT_FILE
 app.state.config.LDAP_CIPHERS = LDAP_CIPHERS
 
-app.state.MODELS = {}
 app.state.TOOLS = {}
 app.state.FUNCTIONS = {}
 
@@ -369,7 +368,7 @@ def get_function_params(function_module, form_data, user, extra_params=None):
     return params
 
 
-async def generate_function_chat_completion(form_data, user):
+async def generate_function_chat_completion(form_data, user, models: dict = {}):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
 
@@ -412,7 +411,7 @@ async def generate_function_chat_completion(form_data, user):
         user,
         {
             **extra_params,
-            "__model__": app.state.MODELS[form_data["model"]],
+            "__model__": models.get(form_data["model"], None),
             "__messages__": form_data["messages"],
             "__files__": files,
         },

+ 179 - 103
backend/open_webui/main.py

@@ -11,6 +11,7 @@ import random
 from contextlib import asynccontextmanager
 from typing import Optional
 
+from aiocache import cached
 import aiohttp
 import requests
 from fastapi import (
@@ -45,6 +46,7 @@ from open_webui.apps.openai.main import (
     app as openai_app,
     generate_chat_completion as generate_openai_chat_completion,
     get_all_models as get_openai_models,
+    get_all_models_responses as get_openai_models_responses,
 )
 from open_webui.apps.retrieval.main import app as retrieval_app
 from open_webui.apps.retrieval.utils import get_rag_context, rag_template
@@ -132,6 +134,7 @@ from open_webui.utils.utils import (
     get_current_user,
     get_http_authorization_cred,
     get_verified_user,
+    has_access,
 )
 
 if SAFE_MODE:
@@ -196,20 +199,22 @@ app.state.config.WEBHOOK_URL = WEBHOOK_URL
 
 app.state.config.TASK_MODEL = TASK_MODEL
 app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
+
 app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
-app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
+
 app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
+app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
+
+
+app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
 app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
 )
-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
+
 app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
     TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
 )
 
-app.state.MODELS = {}
-
-
 ##################################
 #
 # ChatCompletion Middleware
@@ -217,26 +222,6 @@ app.state.MODELS = {}
 ##################################
 
 
-def get_task_model_id(default_model_id):
-    # Set the task model
-    task_model_id = default_model_id
-    # Check if the user has a custom task model and use that model
-    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-        if (
-            app.state.config.TASK_MODEL
-            and app.state.config.TASK_MODEL in app.state.MODELS
-        ):
-            task_model_id = app.state.config.TASK_MODEL
-    else:
-        if (
-            app.state.config.TASK_MODEL_EXTERNAL
-            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-        ):
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-    return task_model_id
-
-
 def get_filter_function_ids(model):
     def get_priority(function_id):
         function = Functions.get_function_by_id(function_id)
@@ -366,8 +351,24 @@ async def get_content_from_response(response) -> Optional[str]:
     return content
 
 
+def get_task_model_id(
+    default_model_id: str, task_model: str, task_model_external: str, models
+) -> str:
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if models[task_model_id]["owned_by"] == "ollama":
+        if task_model and task_model in models:
+            task_model_id = task_model
+    else:
+        if task_model_external and task_model_external in models:
+            task_model_id = task_model_external
+
+    return task_model_id
+
+
 async def chat_completion_tools_handler(
-    body: dict, user: UserModel, extra_params: dict
+    body: dict, user: UserModel, models, extra_params: dict
 ) -> tuple[dict, dict]:
     # If tool_ids field is present, call the functions
     metadata = body.get("metadata", {})
@@ -381,14 +382,19 @@ async def chat_completion_tools_handler(
     contexts = []
     citations = []
 
-    task_model_id = get_task_model_id(body["model"])
+    task_model_id = get_task_model_id(
+        body["model"],
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     tools = get_tools(
         webui_app,
         tool_ids,
         user,
         {
             **extra_params,
-            "__model__": app.state.MODELS[task_model_id],
+            "__model__": models[task_model_id],
             "__messages__": body["messages"],
             "__files__": metadata.get("files", []),
         },
@@ -412,7 +418,7 @@ async def chat_completion_tools_handler(
     )
 
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         raise e
 
@@ -513,16 +519,16 @@ def is_chat_completion_request(request):
     )
 
 
-async def get_body_and_model_and_user(request):
+async def get_body_and_model_and_user(request, models):
     # Read the original request body
     body = await request.body()
     body_str = body.decode("utf-8")
     body = json.loads(body_str) if body_str else {}
 
     model_id = body["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise Exception("Model not found")
-    model = app.state.MODELS[model_id]
+    model = models[model_id]
 
     user = get_current_user(
         request,
@@ -538,14 +544,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")
 
+        model_list = await get_all_models()
+        models = {model["id"]: model for model in model_list}
+
         try:
-            body, model, user = await get_body_and_model_and_user(request)
+            body, model, user = await get_body_and_model_and_user(request, models)
         except Exception as e:
             return JSONResponse(
                 status_code=status.HTTP_400_BAD_REQUEST,
                 content={"detail": str(e)},
             )
 
+        model_info = Models.get_model_by_id(model["id"])
+        if user.role == "user":
+            if model_info and not has_access(
+                user.id, type="read", access_control=model_info.access_control
+            ):
+                return JSONResponse(
+                    status_code=status.HTTP_403_FORBIDDEN,
+                    content={"detail": "User does not have access to the model"},
+                )
+
         metadata = {
             "chat_id": body.pop("chat_id", None),
             "message_id": body.pop("id", None),
@@ -582,15 +601,20 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 content={"detail": str(e)},
             )
 
+        tool_ids = body.pop("tool_ids", None)
+        files = body.pop("files", None)
+
         metadata = {
             **metadata,
-            "tool_ids": body.pop("tool_ids", None),
-            "files": body.pop("files", None),
+            "tool_ids": tool_ids,
+            "files": files,
         }
         body["metadata"] = metadata
 
         try:
-            body, flags = await chat_completion_tools_handler(body, user, extra_params)
+            body, flags = await chat_completion_tools_handler(
+                body, user, models, extra_params
+            )
             contexts.extend(flags.get("contexts", []))
             citations.extend(flags.get("citations", []))
         except Exception as e:
@@ -687,10 +711,10 @@ app.add_middleware(ChatCompletionMiddleware)
 ##################################
 
 
-def get_sorted_filters(model_id):
+def get_sorted_filters(model_id, models):
     filters = [
         model
-        for model in app.state.MODELS.values()
+        for model in models.values()
         if "pipeline" in model
         and "type" in model["pipeline"]
         and model["pipeline"]["type"] == "filter"
@@ -706,12 +730,12 @@ def get_sorted_filters(model_id):
     return sorted_filters
 
 
-def filter_pipeline(payload, user):
+def filter_pipeline(payload, user, models):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     model_id = payload["model"]
-    sorted_filters = get_sorted_filters(model_id)
 
-    model = app.state.MODELS[model_id]
+    sorted_filters = get_sorted_filters(model_id, models)
+    model = models[model_id]
 
     if "pipeline" in model:
         sorted_filters.append(model)
@@ -782,8 +806,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                     content={"detail": "Not authenticated"},
                 )
 
+        model_list = await get_all_models()
+        models = {model["id"]: model for model in model_list}
+
         try:
-            data = filter_pipeline(data, user)
+            data = filter_pipeline(data, user, models)
         except Exception as e:
             if len(e.args) > 1:
                 return JSONResponse(
@@ -862,16 +889,10 @@ async def commit_session_after_request(request: Request, call_next):
 
 @app.middleware("http")
 async def check_url(request: Request, call_next):
-    if len(app.state.MODELS) == 0:
-        await get_all_models()
-    else:
-        pass
-
     start_time = int(time.time())
     response = await call_next(request)
     process_time = int(time.time()) - start_time
     response.headers["X-Process-Time"] = str(process_time)
-
     return response
 
 
@@ -911,10 +932,10 @@ app.mount("/retrieval/api/v1", retrieval_app)
 
 app.mount("/api/v1", webui_app)
 
-
 webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
 
 
+@cached(ttl=1)
 async def get_all_base_models():
     open_webui_models = []
     openai_models = []
@@ -944,6 +965,7 @@ async def get_all_base_models():
     return models
 
 
+@cached(ttl=1)
 async def get_all_models():
     models = await get_all_base_models()
 
@@ -1065,9 +1087,6 @@ async def get_all_models():
 
             function_module = get_function_module_by_id(action_id)
             model["actions"].extend(get_action_items_from_module(function_module))
-
-    app.state.MODELS = {model["id"]: model for model in models}
-    webui_app.state.MODELS = app.state.MODELS
     return models
 
 
@@ -1082,16 +1101,19 @@ async def get_models(user=Depends(get_verified_user)):
         if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
     ]
 
-    # TODO: Check User Group and Filter Models
-    # if app.state.config.ENABLE_MODEL_FILTER:
-    #     if user.role == "user":
-    #         models = list(
-    #             filter(
-    #                 lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
-    #                 models,
-    #             )
-    #         )
-    #         return {"data": models}
+    # Filter out models that the user does not have access to
+    if user.role == "user":
+        filtered_models = []
+        for model in models:
+            model_info = Models.get_model_by_id(model["id"])
+            if model_info:
+                if has_access(
+                    user.id, type="read", access_control=model_info.access_control
+                ):
+                    filtered_models.append(model)
+            else:
+                filtered_models.append(model)
+        models = filtered_models
 
     return {"data": models}
 
@@ -1106,24 +1128,27 @@ async def get_base_models(user=Depends(get_admin_user)):
 async def generate_chat_completions(
     form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
 ):
-    model_id = form_data["model"]
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
 
-    if model_id not in app.state.MODELS:
+    model_id = form_data["model"]
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
 
-    # TODO: Check User Group and Filter Models
-    # if not bypass_filter:
-    #     if app.state.config.ENABLE_MODEL_FILTER:
-    #         if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
-    #             raise HTTPException(
-    #                 status_code=status.HTTP_403_FORBIDDEN,
-    #                 detail="Model not found",
-    #             )
-
-    model = app.state.MODELS[model_id]
+    model = models[model_id]
+    # Check if user has access to the model
+    if user.role == "user":
+        model_info = Models.get_model_by_id(model_id)
+        if not has_access(
+            user.id, type="read", access_control=model_info.access_control
+        ):
+            raise HTTPException(
+                status_code=403,
+                detail="Model not found",
+            )
 
     if model["owned_by"] == "arena":
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
@@ -1174,7 +1199,9 @@ async def generate_chat_completions(
 
     if model.get("pipe"):
         # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
-        return await generate_function_chat_completion(form_data, user=user)
+        return await generate_function_chat_completion(
+            form_data, user=user, models=models
+        )
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
@@ -1198,16 +1225,20 @@ async def generate_chat_completions(
 
 @app.post("/api/chat/completed")
 async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     data = form_data
     model_id = data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
-    model = app.state.MODELS[model_id]
 
-    sorted_filters = get_sorted_filters(model_id)
+    model = models[model_id]
+    sorted_filters = get_sorted_filters(model_id, models)
     if "pipeline" in model:
         sorted_filters = [model] + sorted_filters
 
@@ -1382,14 +1413,18 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
             detail="Action not found",
         )
 
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     data = form_data
     model_id = data["model"]
-    if model_id not in app.state.MODELS:
+
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
-    model = app.state.MODELS[model_id]
+    model = models[model_id]
 
     __event_emitter__ = get_event_emitter(
         {
@@ -1543,8 +1578,11 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
 async def generate_title(form_data: dict, user=Depends(get_verified_user)):
     print("generate_title")
 
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
@@ -1552,10 +1590,16 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
+
     print(task_model_id)
 
-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]
 
     if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@@ -1589,7 +1633,7 @@ Artificial Intelligence in Healthcare
         "stream": False,
         **(
             {"max_tokens": 50}
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id]["owned_by"] == "ollama"
             else {
                 "max_completion_tokens": 50,
             }
@@ -1601,7 +1645,7 @@ Artificial Intelligence in Healthcare
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1628,8 +1672,11 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": "Tags generation is disabled"},
         )
 
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
@@ -1637,7 +1684,12 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)
 
     if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
@@ -1675,7 +1727,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1702,8 +1754,11 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             detail=f"Search query generation is disabled",
         )
 
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
@@ -1711,10 +1766,15 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)
 
-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]
 
     if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@@ -1741,7 +1801,7 @@ Search Query:"""
         "stream": False,
         **(
             {"max_tokens": 30}
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id]["owned_by"] == "ollama"
             else {
                 "max_completion_tokens": 30,
             }
@@ -1752,7 +1812,7 @@ Search Query:"""
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1774,8 +1834,11 @@ Search Query:"""
 async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
     print("generate_emoji")
 
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
@@ -1783,10 +1846,15 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)
 
-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]
 
     template = '''
 Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@@ -1808,7 +1876,7 @@ Message: """{{prompt}}"""
         "stream": False,
         **(
             {"max_tokens": 4}
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            if models[task_model_id]["owned_by"] == "ollama"
             else {
                 "max_completion_tokens": 4,
             }
@@ -1820,7 +1888,7 @@ Message: """{{prompt}}"""
 
     # Handle pipeline filters
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1842,8 +1910,11 @@ Message: """{{prompt}}"""
 async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
     print("generate_moa_response")
 
+    model_list = await get_all_models()
+    models = {model["id"]: model for model in model_list}
+
     model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
+    if model_id not in models:
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
@@ -1851,10 +1922,15 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    task_model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(
+        model_id,
+        app.state.config.TASK_MODEL,
+        app.state.config.TASK_MODEL_EXTERNAL,
+        models,
+    )
     print(task_model_id)
 
-    model = app.state.MODELS[task_model_id]
+    model = models[task_model_id]
 
     template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
 
@@ -1881,7 +1957,7 @@ Responses from models: {{responses}}"""
     log.debug(payload)
 
     try:
-        payload = filter_pipeline(payload, user)
+        payload = filter_pipeline(payload, user, models)
     except Exception as e:
         if len(e.args) > 1:
             return JSONResponse(
@@ -1911,7 +1987,7 @@ Responses from models: {{responses}}"""
 
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
-    responses = await get_openai_models(raw=True)
+    responses = await get_openai_models_responses()
 
     print(responses)
     urlIdxs = [

+ 4 - 3
backend/open_webui/utils/utils.py

@@ -192,15 +192,16 @@ def has_permission(
 
 def has_access(
     user_id: str,
-    action: str = "write",
+    type: str = "write",
     access_control: Optional[dict] = None,
 ) -> bool:
+    print("user_id", user_id, "type", type, "access_control", access_control)
     if access_control is None:
-        return action == "read"
+        return type == "read"
 
     user_groups = Groups.get_groups_by_member_id(user_id)
     user_group_ids = [group.id for group in user_groups]
-    permission_access = access_control.get(action, {})
+    permission_access = access_control.get(type, {})
     permitted_group_ids = permission_access.get("group_ids", [])
     permitted_user_ids = permission_access.get("user_ids", [])
 

+ 1 - 0
backend/requirements.txt

@@ -13,6 +13,7 @@ passlib[bcrypt]==1.7.4
 requests==2.32.3
 aiohttp==3.10.8
 async-timeout
+aiocache
 
 sqlalchemy==2.0.32
 alembic==1.13.2

+ 1 - 0
pyproject.toml

@@ -21,6 +21,7 @@ dependencies = [
     "requests==2.32.3",
     "aiohttp==3.10.8",
     "async-timeout",
+    "aiocache",
 
     "sqlalchemy==2.0.32",
     "alembic==1.13.2",

+ 1 - 1
src/lib/components/admin/Settings/Models.svelte

@@ -71,7 +71,7 @@
 	const upsertModelHandler = async (model) => {
 		model.base_model_id = null;
 
-		if (models.find((m) => m.id === model.id)) {
+		if (workspaceModels.find((m) => m.id === model.id)) {
 			await updateModelById(localStorage.token, model.id, model).catch((error) => {
 				return null;
 			});