
refac: task flag

Co-Authored-By: Michael Poluektov <78477503+michaelpoluektov@users.noreply.github.com>
Timothy J. Baek 10 months ago
commit c83704d6ca
2 changed files with 16 additions and 8 deletions:
  1. backend/constants.py (+11, -0)
  2. backend/main.py (+5, -8)

backend/constants.py (+11, -0)

@@ -89,3 +89,14 @@ class ERROR_MESSAGES(str, Enum):
     OLLAMA_API_DISABLED = (
         "The Ollama API is disabled. Please enable it to use this feature."
     )
+
+
+class TASKS(str, Enum):
+    def __str__(self) -> str:
+        return super().__str__()
+
+    DEFAULT = lambda task="": f"{task if task else 'default'}"
+    TITLE_GENERATION = "Title Generation"
+    EMOJI_GENERATION = "Emoji Generation"
+    QUERY_GENERATION = "Query Generation"
+    FUNCTION_CALLING = "Function Calling"
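
A note on the new TASKS enum: the str mixin makes each member compare equal to its plain string value, and the __str__ override resolves to str.__str__ via the MRO, so members print as their values instead of as TASKS.TITLE_GENERATION. DEFAULT is the odd one out: Python's Enum machinery treats callables defined in the class body as methods rather than members, so it stays a plain function that produces a fallback label. A minimal sketch of the resulting behavior (assuming the module is importable as constants, as in main.py):

    from constants import TASKS

    # str mixin: members compare equal to their raw string values.
    assert TASKS.TITLE_GENERATION == "Title Generation"

    # The __str__ override routes to str.__str__, so printing yields the
    # value itself rather than the qualified member name.
    assert str(TASKS.TITLE_GENERATION) == "Title Generation"

    # DEFAULT is a lambda; Enum treats callables as methods, not members,
    # so it remains callable and yields a fallback task label.
    assert TASKS.DEFAULT() == "default"
    assert TASKS.DEFAULT("Custom Task") == "Custom Task"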

backend/main.py (+5, -8)

@@ -126,7 +126,7 @@ from config import (
     WEBUI_SESSION_COOKIE_SECURE,
     AppConfig,
 )
-from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
 from utils.webhook import post_webhook
 
 if SAFE_MODE:
@@ -311,6 +311,7 @@ async def get_function_call_response(
             {"role": "user", "content": f"Query: {prompt}"},
         ],
         "stream": False,
+        "task": TASKS.FUNCTION_CALLING,
     }
 
     try:
@@ -323,7 +324,6 @@ async def get_function_call_response(
     response = None
     try:
         response = await generate_chat_completions(form_data=payload, user=user)
-
         content = None
 
         if hasattr(response, "body_iterator"):
@@ -833,9 +833,6 @@ def filter_pipeline(payload, user):
                 pass
 
     if "pipeline" not in app.state.MODELS[model_id]:
-        if "title" in payload:
-            del payload["title"]
-
         if "task" in payload:
             del payload["task"]
 
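With the boolean title key gone, every internal request is tagged through the single task field, which filter_pipeline strips before the payload reaches a model that has no pipeline attached. Because the flag now carries a distinct value per task instead of a bare True, a pipeline filter can branch on which task produced the request. A hypothetical inlet-style sketch (the function shape and the tweaks are illustrative, not part of this commit):

    from constants import TASKS

    def inlet(payload: dict) -> dict:
        # Hypothetical pipeline filter: the task flag now identifies
        # which internal task assembled this request.
        task = payload.get("task")
        if task == TASKS.TITLE_GENERATION:
            # Titles are short; clamp the token budget defensively.
            payload["max_tokens"] = min(payload.get("max_tokens", 50), 50)
        elif task == TASKS.FUNCTION_CALLING:
            payload["temperature"] = 0  # keep tool selection deterministic
        return payload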
@@ -1338,7 +1335,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
         "stream": False,
         "max_tokens": 50,
         "chat_id": form_data.get("chat_id", None),
-        "title": True,
+        "task": TASKS.TITLE_GENERATION,
     }
 
     log.debug(payload)
@@ -1401,7 +1398,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "max_tokens": 30,
-        "task": True,
+        "task": TASKS.QUERY_GENERATION,
     }
 
     print(payload)
@@ -1468,7 +1465,7 @@ Message: """{{prompt}}"""
         "stream": False,
         "max_tokens": 4,
         "chat_id": form_data.get("chat_id", None),
-        "task": True,
+        "task": TASKS.EMOJI_GENERATION,
     }
 
     log.debug(payload)
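
One practical consequence of the str mixin: when any of these payloads is serialized for an upstream OpenAI- or Ollama-compatible endpoint, the enum member encodes as its plain string value. A self-contained illustration (the enum here is an abridged copy of the one added in constants.py):

    import json
    from enum import Enum

    class TASKS(str, Enum):  # abridged copy, for illustration only
        def __str__(self) -> str:
            return super().__str__()

        EMOJI_GENERATION = "Emoji Generation"

    payload = {"stream": False, "max_tokens": 4, "task": TASKS.EMOJI_GENERATION}
    print(json.dumps(payload))
    # -> {"stream": false, "max_tokens": 4, "task": "Emoji Generation"}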