Timothy J. Baek 9 months ago
parent
commit
f462744fc8
4 changed files with 33 additions and 12 deletions
  1. +6 −4  backend/apps/ollama/main.py
  2. +2 −0  backend/apps/openai/main.py
  3. +9 −2  backend/apps/webui/main.py
  4. +16 −6  backend/main.py

+ 6 - 4
backend/apps/ollama/main.py

@@ -728,8 +728,10 @@ async def generate_chat_completion(
     )
 
     payload = {
-        **form_data.model_dump(exclude_none=True),
+        **form_data.model_dump(exclude_none=True, exclude=["metadata"]),
     }
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)
@@ -894,9 +896,9 @@ async def generate_openai_chat_completion(
 ):
     form_data = OpenAIChatCompletionForm(**form_data)
 
-    payload = {
-        **form_data.model_dump(exclude_none=True),
-    }
+    payload = {**form_data}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.model
     model_info = Models.get_model_by_id(model_id)

+ 2 - 0
backend/apps/openai/main.py

@@ -357,6 +357,8 @@ async def generate_chat_completion(
 ):
     idx = 0
     payload = {**form_data}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)

+ 9 - 2
backend/apps/webui/main.py

@@ -20,7 +20,6 @@ from apps.webui.routers import (
 )
 from apps.webui.models.functions import Functions
 from apps.webui.models.models import Models
-
 from apps.webui.utils import load_function_module_by_id
 
 from utils.misc import stream_message_template
@@ -53,7 +52,7 @@ import uuid
 import time
 import json
 
-from typing import Iterator, Generator
+from typing import Iterator, Generator, Optional
 from pydantic import BaseModel
 
 app = FastAPI()
@@ -193,6 +192,14 @@ async def generate_function_chat_completion(form_data, user):
     model_id = form_data.get("model")
     model_info = Models.get_model_by_id(model_id)
 
+    metadata = None
+    if "metadata" in form_data:
+        metadata = form_data["metadata"]
+        del form_data["metadata"]
+
+    if metadata:
+        print(metadata)
+
     if model_info:
         if model_info.base_model_id:
             form_data["model"] = model_info.base_model_id

+ 16 - 6
backend/main.py

@@ -618,6 +618,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                     content={"detail": str(e)},
                 )
 
+            # `task` field is used to determine the type of the request, e.g. `title_generation`, `query_generation`, etc.
+            task = None
+            if "task" in body:
+                task = body["task"]
+                del body["task"]
+
             # Extract session_id, chat_id and message_id from the request body
             session_id = None
             if "session_id" in body:
@@ -632,6 +638,8 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 message_id = body["id"]
                 del body["id"]
 
+           
+
             __event_emitter__ = await get_event_emitter(
                 {"chat_id": chat_id, "message_id": message_id, "session_id": session_id}
             )
@@ -691,6 +699,13 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             if len(citations) > 0:
                 data_items.append({"citations": citations})
 
+            body["metadata"] = {
+                "session_id": session_id,
+                "chat_id": chat_id,
+                "message_id": message_id,
+                "task": task,
+            }
+
             modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
@@ -811,9 +826,6 @@ def filter_pipeline(payload, user):
                 if "detail" in res:
                     raise Exception(r.status_code, res["detail"])
 
-    if "pipeline" not in app.state.MODELS[model_id] and "task" in payload:
-        del payload["task"]
-
     return payload
 
 
@@ -1024,11 +1036,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
             status_code=status.HTTP_404_NOT_FOUND,
             detail="Model not found",
         )
-
     model = app.state.MODELS[model_id]
 
-    pipe = model.get("pipe")
-    if pipe:
+    if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)