Browse Source

logit_bias: handle comma seperated values

dannyl1u 2 months ago
parent
commit
90aa29528c
2 changed files with 16 additions and 3 deletions
  1. 4 3
      backend/open_webui/utils/middleware.py
  2. 12 0
      backend/open_webui/utils/misc.py

+ 4 - 3
backend/open_webui/utils/middleware.py

@@ -68,6 +68,7 @@ from open_webui.utils.misc import (
     get_last_user_message,
     get_last_user_message,
     get_last_assistant_message,
     get_last_assistant_message,
     prepend_to_first_user_message_content,
     prepend_to_first_user_message_content,
+    convert_logit_bias_input_to_json
 )
 )
 from open_webui.utils.tools import get_tools
 from open_webui.utils.tools import get_tools
 from open_webui.utils.plugin import load_function_module_by_id
 from open_webui.utils.plugin import load_function_module_by_id
@@ -593,9 +594,9 @@ def apply_params_to_form_data(form_data, model):
             form_data["reasoning_effort"] = params["reasoning_effort"]
             form_data["reasoning_effort"] = params["reasoning_effort"]
         if "logit_bias" in params:
         if "logit_bias" in params:
             try:
             try:
-                form_data["logit_bias"] = json.loads(params["logit_bias"])
-            except json.JSONDecodeError:
-                print("Invalid JSON format for logit_bias")
+                form_data["logit_bias"] = json.loads(convert_logit_bias_input_to_json(params["logit_bias"]))
+            except Exception as e:
+                print(f"Error parsing logit_bias: {e}")
 
 
     return form_data
     return form_data
 
 

+ 12 - 0
backend/open_webui/utils/misc.py

@@ -5,6 +5,7 @@ import uuid
 from datetime import timedelta
 from datetime import timedelta
 from pathlib import Path
 from pathlib import Path
 from typing import Callable, Optional
 from typing import Callable, Optional
+import json
 
 
 
 
 import collections.abc
 import collections.abc
@@ -445,3 +446,14 @@ def parse_ollama_modelfile(model_text):
         data["params"]["messages"] = messages
         data["params"]["messages"] = messages
 
 
     return data
     return data
+
+def convert_logit_bias_input_to_json(user_input):
+    logit_bias_pairs = user_input.split(',')
+    logit_bias_json = {}
+    for pair in logit_bias_pairs:
+        token, bias = pair.split(':')
+        token = str(token.strip())
+        bias = int(bias.strip())
+        bias = 100 if bias > 100 else -100 if bias < -100 else bias
+        logit_bias_json[token] = bias
+    return json.dumps(logit_bias_json)