Timothy Jaeryang Baek (4 months ago)
commit 866c3dff11

+ 26 - 12
backend/open_webui/main.py

@@ -70,6 +70,15 @@ from open_webui.routers import (
     users,
     utils,
 )
+
+from open_webui.routers.openai import (
+    generate_chat_completion as generate_openai_chat_completion,
+)
+
+from open_webui.routers.ollama import (
+    generate_chat_completion as generate_ollama_chat_completion,
+)
+
 from open_webui.routers.retrieval import (
     get_embedding_function,
     get_ef,
@@ -1019,8 +1028,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")

-        model_list = await get_all_models()
-        models = {model["id"]: model for model in model_list}
+        await get_all_models(request)
+        models = app.state.MODELS

         try:
             body, model, user = await get_body_and_model_and_user(request, models)
@@ -1257,7 +1266,7 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 content={"detail": e.detail},
             )

-        await get_all_models()
+        await get_all_models(request)
         models = app.state.MODELS

         try:
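
The two middleware hunks above share one pattern: get_all_models now receives the incoming Request, and callers read the refreshed registry back from app.state.MODELS (a dict keyed by model id) instead of rebuilding that dict from a returned list on every call. A minimal sketch of that call-site shape, assuming get_all_models caches onto app.state as the diff implies; the handler name below is illustrative only.

# Sketch of the new call-site pattern; get_all_models is the helper from main.py
# and is assumed to be in scope. Its caching behaviour is inferred from this diff.
from fastapi import Request

async def lookup_model(request: Request, model_id: str) -> dict:
    # Refresh the registry; the result is cached on app.state.MODELS
    # as {model_id: model_dict}, so no per-call dict comprehension is needed.
    await get_all_models(request)
    models = request.app.state.MODELS
    if model_id not in models:
        raise ValueError(f"Model not found: {model_id}")
    return models[model_id]

Threading the Request through also lets router code reach app-level state without importing the application object directly, which appears to be the point of the repeated signature change below.
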
@@ -1924,6 +1933,7 @@ async def get_base_models(request: Request, user=Depends(get_admin_user)):

 @app.post("/api/chat/completions")
 async def generate_chat_completions(
+    request: Request,
     form_data: dict,
     user=Depends(get_verified_user),
     bypass_filter: bool = False,
@@ -1931,8 +1941,7 @@ async def generate_chat_completions(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True

-    model_list = app.state.MODELS
-    models = {model["id"]: model for model in model_list}
+    models = app.state.MODELS

     model_id = form_data["model"]
     if model_id not in models:
@@ -1981,7 +1990,7 @@ async def generate_chat_completions(
         if model_ids and filter_mode == "exclude":
             model_ids = [
                 model["id"]
-                for model in await get_all_models()
+                for model in await get_all_models(request)
                 if model.get("owned_by") != "arena" and model["id"] not in model_ids
             ]
 
@@ -1991,7 +2000,7 @@ async def generate_chat_completions(
         else:
             model_ids = [
                 model["id"]
-                for model in await get_all_models()
+                for model in await get_all_models(request)
                 if model.get("owned_by") != "arena"
             ]
             selected_model_id = random.choice(model_ids)
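
The two hunks above adjust the arena candidate pool: with filter_mode set to "exclude" the listed ids (and all arena entries) are dropped from the pool, otherwise every non-arena model is eligible, and one id is then drawn with random.choice. A condensed sketch of that selection logic; the explicit include branch (model_ids given without "exclude") is assumed from context rather than shown in this diff.

# Condensed sketch of the arena model selection around these hunks.
import random

def pick_arena_model(all_models, model_ids, filter_mode):
    # all_models mirrors what get_all_models(request) yields: dicts with "id" and "owned_by".
    if model_ids and filter_mode == "exclude":
        candidates = [
            m["id"] for m in all_models
            if m.get("owned_by") != "arena" and m["id"] not in model_ids
        ]
    elif model_ids:
        # Explicit include list (assumed branch, not visible in the hunks above).
        candidates = model_ids
    else:
        candidates = [m["id"] for m in all_models if m.get("owned_by") != "arena"]
    return random.choice(candidates)
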
@@ -2028,6 +2037,7 @@ async def generate_chat_completions(
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
         response = await generate_ollama_chat_completion(
+            request=request,
             form_data=form_data, user=user, bypass_filter=bypass_filter
         )
         if form_data.stream:
@@ -2040,6 +2050,8 @@ async def generate_chat_completions(
             return convert_response_ollama_to_openai(response)
     else:
         return await generate_openai_chat_completion(
+            request=request,
+
             form_data, user=user, bypass_filter=bypass_filter
         )
 
@@ -2048,8 +2060,8 @@ async def generate_chat_completions(
 async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models(request)
-    models = {model["id"]: model for model in model_list}
+    await get_all_models(request)
+    models = app.state.MODELS

     data = form_data
     model_id = data["model"]
@@ -2183,7 +2195,9 @@ async def chat_completed(
 
 
 @app.post("/api/chat/actions/{action_id}")
-async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified_user)):
+async def chat_action(
+    request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
+):
     if "." in action_id:
         action_id, sub_action_id = action_id.split(".")
     else:
@@ -2196,8 +2210,8 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
             detail="Action not found",
         )

-    model_list = await get_all_models()
-    models = {model["id"]: model for model in model_list}
+    await get_all_models(request)
+    models = app.state.MODELS

     data = form_data
     model_id = data["model"]
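
Taken together, the main.py hunks give the chat endpoints a Request parameter and split completion routing by backend owner: Ollama-owned models get the OpenAI-style payload converted to the Ollama schema (and the response converted back), while everything else goes straight to the OpenAI-compatible path. A condensed, non-streaming sketch of that branch; the conversion helpers and the two completion functions are the aliases imported at the top of the file, but this wrapper itself is hypothetical.

# Non-streaming sketch of the backend split in generate_chat_completions.
async def route_chat_completion(request, model: dict, form_data: dict, user):
    if model.get("owned_by") == "ollama":
        # Ollama backend: convert the payload, call the Ollama chat endpoint,
        # then map the response back into the OpenAI response shape.
        ollama_payload = convert_payload_openai_to_ollama(form_data)
        response = await generate_ollama_chat_completion(
            request=request, form_data=ollama_payload, user=user
        )
        return convert_response_ollama_to_openai(response)
    # OpenAI-compatible backends take the payload as-is.
    return await generate_openai_chat_completion(
        request=request, form_data=form_data, user=user
    )
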

+ 9 - 9
backend/open_webui/routers/ollama.py

@@ -344,7 +344,7 @@ async def get_ollama_tags(
     models = []

     if url_idx is None:
-        models = await get_all_models()
+        models = await get_all_models(request)
     else:
         url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
         key = get_api_key(url, request.app.state.config.OLLAMA_API_CONFIGS)
@@ -565,7 +565,7 @@ async def copy_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS

         if form_data.source in models:
@@ -620,7 +620,7 @@ async def delete_model(
     user=Depends(get_admin_user),
 ):
     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS

         if form_data.name in models:
@@ -670,7 +670,7 @@ async def delete_model(
 async def show_model_info(
     request: Request, form_data: ModelNameForm, user=Depends(get_verified_user)
 ):
-    await get_all_models()
+    await get_all_models(request)
     models = request.app.state.OLLAMA_MODELS

     if form_data.name not in models:
@@ -734,7 +734,7 @@ async def embed(
     log.info(f"generate_ollama_batch_embeddings {form_data}")

     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS

         model = form_data.model
@@ -803,7 +803,7 @@ async def embeddings(
     log.info(f"generate_ollama_embeddings {form_data}")

     if url_idx is None:
-        await get_all_models()
+        await get_all_models(request)
         models = request.app.state.OLLAMA_MODELS

         model = form_data.model
@@ -878,8 +878,8 @@ async def generate_completion(
     user=Depends(get_verified_user),
 ):
     if url_idx is None:
-        model_list = await get_all_models()
-        models = {model["model"]: model for model in model_list["models"]}
+        await get_all_models(request)
+        models = request.app.state.OLLAMA_MODELS

         model = form_data.model
 
@@ -1200,7 +1200,7 @@ async def get_openai_models(
 
     models = []
     if url_idx is None:
-        model_list = await get_all_models()
+        model_list = await get_all_models(request)
         models = [
             {
                 "id": model["model"],

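
Throughout ollama.py the per-call dict built from model_list["models"] is replaced by the request.app.state.OLLAMA_MODELS cache that get_all_models(request) refreshes. A rough sketch of how a handler might resolve a model name to one of its backends under that assumption; the "urls" field on each cache entry is inferred from the surrounding router code and does not appear in this diff.

# Rough sketch only; get_all_models is the router-local helper, assumed in scope,
# and the shape of OLLAMA_MODELS entries (a "urls" list of base-URL indices) is assumed.
import random
from fastapi import HTTPException, Request

async def resolve_ollama_url(request: Request, model_name: str) -> str:
    await get_all_models(request)  # refresh the cache, as in the hunks above
    models = request.app.state.OLLAMA_MODELS
    if model_name not in models:
        raise HTTPException(status_code=400, detail=f"Model not found: {model_name}")
    url_idx = random.choice(models[model_name]["urls"])
    return request.app.state.config.OLLAMA_BASE_URLS[url_idx]
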
+ 1 - 1
backend/open_webui/routers/openai.py

@@ -404,7 +404,7 @@ async def get_models(
     }

     if url_idx is None:
-        models = await get_all_models()
+        models = await get_all_models(request)
     else:
         url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
         key = request.app.state.config.OPENAI_API_KEYS[url_idx]
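
openai.py follows the same convention: get_all_models takes the Request so it can reach app-level config and cached state. A reduced sketch of the url_idx branching in get_models; the real handler goes on to query the selected backend, which is omitted here, and get_all_models is assumed to be the router-local helper.

# Reduced sketch of the url_idx branch; the actual backend request is left out.
from typing import Optional
from fastapi import Request

async def list_models_sketch(request: Request, url_idx: Optional[int] = None):
    if url_idx is None:
        # Aggregate across every configured OpenAI-compatible backend.
        return await get_all_models(request)
    # Otherwise target a single backend by index, as in the context lines above.
    url = request.app.state.config.OPENAI_API_BASE_URLS[url_idx]
    key = request.app.state.config.OPENAI_API_KEYS[url_idx]
    return {"url": url, "has_key": bool(key)}
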