Przeglądaj źródła

fix: openai proxy

Timothy J. Baek 11 miesięcy temu
rodzic
commit
e427ef767b
2 zmienionych plików z 31 dodań i 17 usunięć
  1. +31 −16
      backend/apps/openai/main.py
  2. +0 −1
      backend/apps/webui/routers/models.py

+ 31 - 16
backend/apps/openai/main.py

@@ -338,21 +338,36 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
                 model_info.params = model_info.params.model_dump()
 
                 if model_info.params:
-                    payload["temperature"] = model_info.params.get("temperature", None)
-                    payload["top_p"] = model_info.params.get("top_p", None)
-                    payload["max_tokens"] = model_info.params.get("max_tokens", None)
-                    payload["frequency_penalty"] = model_info.params.get(
-                        "frequency_penalty", None
-                    )
-                    payload["seed"] = model_info.params.get("seed", None)
-                    payload["stop"] = (
-                        [
-                            bytes(stop, "utf-8").decode("unicode_escape")
-                            for stop in model_info.params["stop"]
-                        ]
-                        if model_info.params.get("stop", None)
-                        else None
-                    )
+                    if model_info.params.get("temperature", None):
+                        payload["temperature"] = int(
+                            model_info.params.get("temperature")
+                        )
+
+                    if model_info.params.get("top_p", None):
+                        payload["top_p"] = int(model_info.params.get("top_p", None))
+
+                    if model_info.params.get("max_tokens", None):
+                        payload["max_tokens"] = int(
+                            model_info.params.get("max_tokens", None)
+                        )
+
+                    if model_info.params.get("frequency_penalty", None):
+                        payload["frequency_penalty"] = int(
+                            model_info.params.get("frequency_penalty", None)
+                        )
+
+                    if model_info.params.get("seed", None):
+                        payload["seed"] = model_info.params.get("seed", None)
+
+                    if model_info.params.get("stop", None):
+                        payload["stop"] = (
+                            [
+                                bytes(stop, "utf-8").decode("unicode_escape")
+                                for stop in model_info.params["stop"]
+                            ]
+                            if model_info.params.get("stop", None)
+                            else None
+                        )
 
                 if model_info.params.get("system", None):
                     # Check if the payload already has a system message
@@ -376,7 +391,6 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
             else:
                 pass
 
-            print(app.state.MODELS)
             model = app.state.MODELS[payload.get("model")]
 
             idx = model["urlIdx"]
@@ -442,6 +456,7 @@ async def proxy(path: str, request: Request, user=Depends(get_verified_user)):
         if r is not None:
             try:
                 res = r.json()
+                print(res)
                 if "error" in res:
                     error_detail = f"External: {res['error']['message'] if 'message' in res['error'] else res['error']}"
             except:

+ 0 - 1
backend/apps/webui/routers/models.py

@@ -82,7 +82,6 @@ async def update_model_by_id(
     else:
         if form_data.id in request.app.state.MODELS:
             model = Models.insert_new_model(form_data, user.id)
-            print(model)
             if model:
                 return model
             else: