Timothy Jaeryang Baek 4 miesiące temu
rodzic
commit
a07ff56c50
2 zmienionych plików z 55 dodań i 50 usunięć
  1. 20 14
      backend/open_webui/main.py
  2. 35 36
      backend/open_webui/routers/openai.py

+ 20 - 14
backend/open_webui/main.py

@@ -1009,9 +1009,12 @@ async def get_body_and_model_and_user(request, models):
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if not request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        if not (
+            request.method == "POST"
+            and any(
+                endpoint in request.url.path
+                for endpoint in ["/ollama/api/chat", "/chat/completions"]
+            )
         ):
             return await call_next(request)
         log.debug(f"request.url.path: {request.url.path}")
@@ -1214,9 +1217,12 @@ app.add_middleware(ChatCompletionMiddleware)
 
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if not request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        if not (
+            request.method == "POST"
+            and any(
+                endpoint in request.url.path
+                for endpoint in ["/ollama/api/chat", "/chat/completions"]
+            )
         ):
             return await call_next(request)
 
@@ -1664,17 +1670,17 @@ async def generate_function_chat_completion(form_data, user, models: dict = {}):
         return openai_chat_completion_message_template(form_data["model"], message)
 
 
-async def get_all_base_models():
+async def get_all_base_models(request):
     function_models = []
     openai_models = []
     ollama_models = []
 
     if app.state.config.ENABLE_OPENAI_API:
-        openai_models = await openai.get_all_models()
+        openai_models = await openai.get_all_models(request)
         openai_models = openai_models["data"]
 
     if app.state.config.ENABLE_OLLAMA_API:
-        ollama_models = await ollama.get_all_models()
+        ollama_models = await ollama.get_all_models(request)
         ollama_models = [
             {
                 "id": model["model"],
@@ -1729,8 +1735,8 @@ async def get_all_base_models():
 
 
 @cached(ttl=3)
-async def get_all_models():
-    models = await get_all_base_models()
+async def get_all_models(request):
+    models = await get_all_base_models(request)
 
     # If there are no models, return an empty list
     if len([model for model in models if not model.get("arena", False)]) == 0:
@@ -1859,8 +1865,8 @@ async def get_all_models():
 
 
 @app.get("/api/models")
-async def get_models(user=Depends(get_verified_user)):
-    models = await get_all_models()
+async def get_models(request: Request, user=Depends(get_verified_user)):
+    models = await get_all_models(request)
 
     # Filter out filter pipelines
     models = [
@@ -2042,7 +2048,7 @@ async def generate_chat_completions(
 async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
-    model_list = await get_all_models()
+    model_list = await get_all_models(request)
     models = {model["id"]: model for model in model_list}
 
     data = form_data

+ 35 - 36
backend/open_webui/routers/openai.py

@@ -245,41 +245,6 @@ async def speech(request: Request, user=Depends(get_verified_user)):
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.OPENAI_NOT_FOUND)
 
 
-def merge_models_lists(model_lists):
-    log.debug(f"merge_models_lists {model_lists}")
-    merged_list = []
-
-    for idx, models in enumerate(model_lists):
-        if models is not None and "error" not in models:
-            merged_list.extend(
-                [
-                    {
-                        **model,
-                        "name": model.get("name", model["id"]),
-                        "owned_by": "openai",
-                        "openai": model,
-                        "urlIdx": idx,
-                    }
-                    for model in models
-                    if "api.openai.com"
-                    not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
-                    or not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
-                        ]
-                    )
-                ]
-            )
-
-    return merged_list
-
-
 async def get_all_models_responses(request: Request) -> list:
     if not request.app.state.config.ENABLE_OPENAI_API:
         return []
@@ -379,7 +344,7 @@ async def get_all_models(request: Request) -> dict[str, list]:
     if not request.app.state.config.ENABLE_OPENAI_API:
         return {"data": []}
 
-    responses = await get_all_models_responses()
+    responses = await get_all_models_responses(request)
 
     def extract_data(response):
         if response and "data" in response:
@@ -388,6 +353,40 @@ async def get_all_models(request: Request) -> dict[str, list]:
             return response
         return None
 
+    def merge_models_lists(model_lists):
+        log.debug(f"merge_models_lists {model_lists}")
+        merged_list = []
+
+        for idx, models in enumerate(model_lists):
+            if models is not None and "error" not in models:
+                merged_list.extend(
+                    [
+                        {
+                            **model,
+                            "name": model.get("name", model["id"]),
+                            "owned_by": "openai",
+                            "openai": model,
+                            "urlIdx": idx,
+                        }
+                        for model in models
+                        if "api.openai.com"
+                        not in request.app.state.config.OPENAI_API_BASE_URLS[idx]
+                        or not any(
+                            name in model["id"]
+                            for name in [
+                                "babbage",
+                                "dall-e",
+                                "davinci",
+                                "embedding",
+                                "tts",
+                                "whisper",
+                            ]
+                        )
+                    ]
+                )
+
+        return merged_list
+
     models = {"data": merge_models_lists(map(extract_data, responses))}
     log.debug(f"models: {models}")