|
@@ -981,11 +981,20 @@ async def get_models(user=Depends(get_verified_user)):
|
|
|
@app.post("/api/chat/completions")
|
|
|
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
|
|
model_id = form_data["model"]
|
|
|
+
|
|
|
if model_id not in app.state.MODELS:
|
|
|
raise HTTPException(
|
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
|
detail="Model not found",
|
|
|
)
|
|
|
+
|
|
|
+ if app.state.config.ENABLE_MODEL_FILTER:
|
|
|
+ if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_403_FORBIDDEN,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
model = app.state.MODELS[model_id]
|
|
|
if model.get("pipe"):
|
|
|
return await generate_function_chat_completion(form_data, user=user)
|