Jelajahi Sumber

refac: undo raw split, remove gpt-4-vision-preview

Michael Poluektov 9 bulan lalu
induk
komit
3653126179
2 file diubah dengan 14 penambahan dan 12 penghapusan
  1. 13 10
      backend/apps/openai/main.py
  2. 1 2
      backend/main.py

+ 13 - 10
backend/apps/openai/main.py

@@ -30,7 +30,7 @@ from config import (
     MODEL_FILTER_LIST,
     AppConfig,
 )
-from typing import List, Optional
+from typing import List, Optional, Literal, overload
 
 
 import hashlib
@@ -262,12 +262,22 @@ async def get_all_models_raw() -> list:
     return responses
 
 
-async def get_all_models() -> dict[str, list]:
+@overload
+async def get_all_models(raw: Literal[True]) -> list: ...
+
+
+@overload
+async def get_all_models(raw: Literal[False] = False) -> dict[str, list]: ...
+
+
+async def get_all_models(raw=False) -> dict[str, list] | list:
     log.info("get_all_models()")
     if is_openai_api_disabled():
-        return {"data": []}
+        return [] if raw else {"data": []}
 
     responses = await get_all_models_raw()
+    if raw:
+        return responses
 
     def extract_data(response):
         if response and "data" in response:
@@ -370,13 +380,6 @@ async def generate_chat_completion(
             "role": user.role,
         }
 
-    # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
-    # This is a workaround until OpenAI fixes the issue with this model
-    if payload.get("model") == "gpt-4-vision-preview":
-        if "max_tokens" not in payload:
-            payload["max_tokens"] = 4000
-        log.debug("Modified payload:", payload)
-
     # Convert the modified body back to JSON
     payload = json.dumps(payload)
 

+ 1 - 2
backend/main.py

@@ -36,7 +36,6 @@ from apps.ollama.main import (
 from apps.openai.main import (
     app as openai_app,
     get_all_models as get_openai_models,
-    get_all_models_raw as get_openai_models_raw,
     generate_chat_completion as generate_openai_chat_completion,
 )
 
@@ -1657,7 +1656,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
 
 @app.get("/api/pipelines/list")
 async def get_pipelines_list(user=Depends(get_admin_user)):
-    responses = await get_openai_models_raw()
+    responses = await get_openai_models(raw = True)
 
     print(responses)
     urlIdxs = [