
refac: max_tokens -> max_completion_tokens

Timothy J. Baek · 7 months ago
commit f8fffdd288
1 changed file with 30 additions and 14 deletions:
  backend/open_webui/main.py

backend/open_webui/main.py  (+30 −14)

@@ -1398,9 +1398,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
+    task_model_id = get_task_model_id(model_id)
 
-    print(model_id)
+    print(task_model_id)
 
     if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
@@ -1427,10 +1427,16 @@ Prompt: {{prompt:middletruncate:8000}}"""
     )
 
     payload = {
-        "model": model_id,
+        "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
         "stream": False,
-        "max_tokens": 50,
+        **(
+            {"max_tokens": 50}
+            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            else {
+                "max_completion_tokens": 50,
+            }
+        ),
         "chat_id": form_data.get("chat_id", None),
         "metadata": {"task": str(TASKS.TITLE_GENERATION)},
     }
@@ -1475,9 +1481,8 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
+    task_model_id = get_task_model_id(model_id)
+    print(task_model_id)
 
     if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
         template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
@@ -1499,10 +1504,16 @@ Search Query:"""
     print("content", content)
 
     payload = {
-        "model": model_id,
+        "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
         "stream": False,
-        "max_tokens": 30,
+        **(
+            {"max_tokens": 30}
+            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            else {
+                "max_completion_tokens": 30,
+            }
+        ),
         "metadata": {"task": str(TASKS.QUERY_GENERATION)},
     }
 
@@ -1541,9 +1552,8 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
 
     # Check if the user has a custom task model
     # If the user has a custom task model, use that model
-    model_id = get_task_model_id(model_id)
-
-    print(model_id)
+    task_model_id = get_task_model_id(model_id)
+    print(task_model_id)
 
     template = '''
 Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
@@ -1561,10 +1571,16 @@ Message: """{{prompt}}"""
     )
 
     payload = {
-        "model": model_id,
+        "model": task_model_id,
         "messages": [{"role": "user", "content": content}],
         "stream": False,
-        "max_tokens": 4,
+        **(
+            {"max_tokens": 4}
+            if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
+            else {
+                "max_completion_tokens": 4,
+            }
+        ),
         "chat_id": form_data.get("chat_id", None),
         "metadata": {"task": str(TASKS.EMOJI_GENERATION)},
     }
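
The same conditional is repeated in all three task endpoints (title generation, search query generation, emoji generation): the payload spreads in `max_tokens: N` when the resolved task model is owned by `ollama`, and `max_completion_tokens: N` otherwise. Below is a minimal sketch of that pattern, with a hypothetical `token_limit_field` helper and a stand-in `models` dict in place of `app.state.MODELS`; in main.py the branch is written inline with `**(...)` rather than factored out.

```python
def token_limit_field(models: dict, model_id: str, limit: int) -> dict:
    """Pick the token-limit key for the target backend.

    Ollama models still expect `max_tokens`, while OpenAI-compatible
    endpoints accept `max_completion_tokens`.
    """
    if models[model_id]["owned_by"] == "ollama":
        return {"max_tokens": limit}
    return {"max_completion_tokens": limit}


# Usage, mirroring the payload construction in generate_title
# (model id and ownership below are illustrative only):
models = {"gpt-4o-mini": {"owned_by": "openai"}}
payload = {
    "model": "gpt-4o-mini",
    "messages": [{"role": "user", "content": "..."}],
    "stream": False,
    **token_limit_field(models, "gpt-4o-mini", 50),
}
# payload now carries {"max_completion_tokens": 50}
```

Keeping the branch inline, as the commit does, leaves the endpoint signatures untouched, at the cost of repeating the same conditional in each of the three payloads.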