Browse Source

refac: litellm model name validation

Timothy J. Baek 1 year ago
parent
commit
4651db8c09
1 changed files with 18 additions and 8 deletions
  1. 18 8
      backend/apps/litellm/main.py

+ 18 - 8
backend/apps/litellm/main.py

@@ -12,7 +12,7 @@ import json
 import time
 import requests
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 from typing import Optional, List
 
 from utils.utils import get_verified_user, get_current_user, get_admin_user
@@ -25,6 +25,7 @@ log.setLevel(SRC_LOG_LEVELS["LITELLM"])
 
 from config import MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, DATA_DIR
 
+from litellm.utils import get_llm_provider
 
 import asyncio
 import subprocess
@@ -165,6 +166,8 @@ class LiteLLMConfigForm(BaseModel):
     model_list: Optional[List[dict]] = None
     router_settings: Optional[dict] = None
 
+    model_config = ConfigDict(protected_namespaces=())
+
 
 @app.post("/config/update")
 async def update_config(form_data: LiteLLMConfigForm, user=Depends(get_admin_user)):
@@ -236,21 +239,28 @@ class AddLiteLLMModelForm(BaseModel):
     model_name: str
     litellm_params: dict
 
+    model_config = ConfigDict(protected_namespaces=())
+
 
 @app.post("/model/new")
 async def add_model_to_config(
     form_data: AddLiteLLMModelForm, user=Depends(get_admin_user)
 ):
-    # TODO: Validate model form
+    try:
+        get_llm_provider(model=form_data.model_name)
+        app.state.CONFIG["model_list"].append(form_data.model_dump())
 
-    app.state.CONFIG["model_list"].append(form_data.model_dump())
+        with open(LITELLM_CONFIG_DIR, "w") as file:
+            yaml.dump(app.state.CONFIG, file)
 
-    with open(LITELLM_CONFIG_DIR, "w") as file:
-        yaml.dump(app.state.CONFIG, file)
+        await restart_litellm()
 
-    await restart_litellm()
-
-    return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
+        return {"message": MESSAGES.MODEL_ADDED(form_data.model_name)}
+    except Exception as e:
+        print(e)
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
+        )
 
 
 class DeleteLiteLLMModelForm(BaseModel):