浏览代码

refac: openai /models

Timothy Jaeryang Baek 5 月之前
父节点
当前提交
b82e25cac8
共有 1 个文件被更改,包括 43 次插入36 次删除
  1. 43 36
      backend/open_webui/apps/openai/main.py

+ 43 - 36
backend/open_webui/apps/openai/main.py

@@ -342,46 +342,53 @@ async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_us
 
         r = None
 
-        try:
-            r = requests.request(method="GET", url=f"{url}/models", headers=headers)
-            r.raise_for_status()
 
-            response_data = r.json()
-
-            if "api.openai.com" in url:
-                # Filter the response data
-                response_data["data"] = [
-                    model
-                    for model in response_data["data"]
-                    if not any(
-                        name in model["id"]
-                        for name in [
-                            "babbage",
-                            "dall-e",
-                            "davinci",
-                            "embedding",
-                            "tts",
-                            "whisper",
+        timeout = aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT_OPENAI_MODEL_LIST)
+        async with aiohttp.ClientSession(timeout=timeout) as session:
+            try:
+                async with session.get(f"{url}/models", headers=headers) as r:
+                    if r.status != 200:
+                        # Extract response error details if available
+                        error_detail = f"HTTP Error: {r.status}"
+                        res = await r.json()
+                        if "error" in res:
+                            error_detail = f"External Error: {res['error']}"
+                        raise Exception(error_detail)
+
+                    response_data = await r.json()
+
+                    # Check if we're calling OpenAI API based on the URL
+                    if "api.openai.com" in url:
+                        # Filter models according to the specified conditions
+                        response_data["data"] = [
+                            model for model in response_data.get("data", [])
+                            if not any(
+                                name in model["id"]
+                                for name in [
+                                    "babbage",
+                                    "dall-e",
+                                    "davinci",
+                                    "embedding",
+                                    "tts",
+                                    "whisper",
+                                ]
+                            )
                         ]
-                    )
-                ]
 
-            return response_data
-        except Exception as e:
-            log.exception(e)
-            error_detail = "Open WebUI: Server Connection Error"
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "error" in res:
-                        error_detail = f"External: {res['error']}"
-                except Exception:
-                    error_detail = f"External: {e}"
+                    return response_data
+
+            except aiohttp.ClientError as e:
+                # ClientError covers all aiohttp requests issues
+                log.exception(f"Client error: {str(e)}")
+                # Handle aiohttp-specific connection issues, timeout etc.
+                raise HTTPException(status_code=500, detail="Open WebUI: Server Connection Error")
+            except Exception as e:
+                log.exception(f"Unexpected error: {e}")
+                # Generic error handler in case parsing JSON or other steps fail
+                error_detail = f"Unexpected error: {str(e)}"
+                raise HTTPException(status_code=500, detail=error_detail)
+
 
-            raise HTTPException(
-                status_code=r.status_code if r else 500,
-                detail=error_detail,
-            )
 
 
 @app.post("/chat/completions")