浏览代码

refac: use get_task_model_id()

Michael Poluektov 10 月之前
父节点
当前提交
1d20c27553
共有 1 个文件被更改,包括 4 次插入40 次删除
  1. 4 40
      backend/main.py

+ 4 - 40
backend/main.py

@@ -1293,16 +1293,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    if app.state.MODELS[model_id]["owned_by"] == "ollama":
-        if app.state.config.TASK_MODEL:
-            task_model_id = app.state.config.TASK_MODEL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
-    else:
-        if app.state.config.TASK_MODEL_EXTERNAL:
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
+    model_id = get_task_model_id(model_id)
 
     print(model_id)
 
@@ -1361,16 +1352,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    if app.state.MODELS[model_id]["owned_by"] == "ollama":
-        if app.state.config.TASK_MODEL:
-            task_model_id = app.state.config.TASK_MODEL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
-    else:
-        if app.state.config.TASK_MODEL_EXTERNAL:
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
+    model_id = get_task_model_id(model_id)
 
     print(model_id)
 
@@ -1417,16 +1399,7 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    if app.state.MODELS[model_id]["owned_by"] == "ollama":
-        if app.state.config.TASK_MODEL:
-            task_model_id = app.state.config.TASK_MODEL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
-    else:
-        if app.state.config.TASK_MODEL_EXTERNAL:
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
+    model_id = get_task_model_id(model_id)
 
     print(model_id)
 
@@ -1483,16 +1456,7 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    if app.state.MODELS[model_id]["owned_by"] == "ollama":
-        if app.state.config.TASK_MODEL:
-            task_model_id = app.state.config.TASK_MODEL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
-    else:
-        if app.state.config.TASK_MODEL_EXTERNAL:
-            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-            if task_model_id in app.state.MODELS:
-                model_id = task_model_id
+    model_id = get_task_model_id(model_id)
 
     print(model_id)
     template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE