Browse Source

fix: multiple openai issue

Timothy J. Baek 1 year ago
parent
commit
1bfcd801b7
3 changed files with 31 additions and 23 deletions
  1. 9 9
      backend/apps/ollama/main.py
  2. 18 12
      backend/apps/openai/main.py
  3. 4 2
      backend/config.py

+ 9 - 9
backend/apps/ollama/main.py

@@ -98,13 +98,14 @@ def merge_models_lists(model_lists):
     merged_models = {}
     merged_models = {}
 
 
     for idx, model_list in enumerate(model_lists):
     for idx, model_list in enumerate(model_lists):
-        for model in model_list:
-            digest = model["digest"]
-            if digest not in merged_models:
-                model["urls"] = [idx]
-                merged_models[digest] = model
-            else:
-                merged_models[digest]["urls"].append(idx)
+        if model_list is not None:
+            for model in model_list:
+                digest = model["digest"]
+                if digest not in merged_models:
+                    model["urls"] = [idx]
+                    merged_models[digest] = model
+                else:
+                    merged_models[digest]["urls"].append(idx)
 
 
     return list(merged_models.values())
     return list(merged_models.values())
 
 
@@ -116,11 +117,10 @@ async def get_all_models():
     print("get_all_models")
     print("get_all_models")
     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
     tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
     responses = await asyncio.gather(*tasks)
     responses = await asyncio.gather(*tasks)
-    responses = list(filter(lambda x: x is not None, responses))
 
 
     models = {
     models = {
         "models": merge_models_lists(
         "models": merge_models_lists(
-            map(lambda response: response["models"], responses)
+            map(lambda response: response["models"] if response else None, responses)
         )
         )
     }
     }
 
 

+ 18 - 12
backend/apps/openai/main.py

@@ -168,14 +168,15 @@ def merge_models_lists(model_lists):
     merged_list = []
     merged_list = []
 
 
     for idx, models in enumerate(model_lists):
     for idx, models in enumerate(model_lists):
-        merged_list.extend(
-            [
-                {**model, "urlIdx": idx}
-                for model in models
-                if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
-                or "gpt" in model["id"]
-            ]
-        )
+        if models is not None and "error" not in models:
+            merged_list.extend(
+                [
+                    {**model, "urlIdx": idx}
+                    for model in models
+                    if "api.openai.com" not in app.state.OPENAI_API_BASE_URLS[idx]
+                    or "gpt" in model["id"]
+                ]
+            )
 
 
     return merged_list
     return merged_list
 
 
@@ -190,15 +191,20 @@ async def get_all_models():
             fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
             fetch_url(f"{url}/models", app.state.OPENAI_API_KEYS[idx])
             for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
             for idx, url in enumerate(app.state.OPENAI_API_BASE_URLS)
         ]
         ]
+
         responses = await asyncio.gather(*tasks)
         responses = await asyncio.gather(*tasks)
-        responses = list(
-            filter(lambda x: x is not None and "error" not in x, responses)
-        )
         models = {
         models = {
             "data": merge_models_lists(
             "data": merge_models_lists(
-                list(map(lambda response: response["data"], responses))
+                list(
+                    map(
+                        lambda response: response["data"] if response else None,
+                        responses,
+                    )
+                )
             )
             )
         }
         }
+
+        print(models)
         app.state.MODELS = {model["id"]: model for model in models["data"]}
         app.state.MODELS = {model["id"]: model for model in models["data"]}
 
 
         return models
         return models

+ 4 - 2
backend/config.py

@@ -250,8 +250,10 @@ OPENAI_API_BASE_URLS = (
     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
     OPENAI_API_BASE_URLS if OPENAI_API_BASE_URLS != "" else OPENAI_API_BASE_URL
 )
 )
 
 
-OPENAI_API_BASE_URLS = [url.strip() for url in OPENAI_API_BASE_URLS.split(";")]
-
+OPENAI_API_BASE_URLS = [
+    url.strip() if url != "" else "https://api.openai.com/v1"
+    for url in OPENAI_API_BASE_URLS.split(";")
+]
 
 
 ####################################
 ####################################
 # WEBUI
 # WEBUI