|
@@ -15,6 +15,7 @@ import uuid
|
|
|
import inspect
|
|
|
import asyncio
|
|
|
|
|
|
+from fastapi.concurrency import run_in_threadpool
|
|
|
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from fastapi.responses import JSONResponse
|
|
@@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
-from typing import List, Optional
|
|
|
+from typing import List, Optional, Iterator, Generator, Union
|
|
|
|
|
|
from apps.webui.models.models import Models, ModelModel
|
|
|
from apps.webui.models.tools import Tools
|
|
@@ -66,7 +67,11 @@ from utils.task import (
|
|
|
search_query_generation_template,
|
|
|
tools_function_calling_generation_template,
|
|
|
)
|
|
|
-from utils.misc import get_last_user_message, add_or_update_system_message
|
|
|
+from utils.misc import (
|
|
|
+ get_last_user_message,
|
|
|
+ add_or_update_system_message,
|
|
|
+ stream_message_template,
|
|
|
+)
|
|
|
|
|
|
from apps.rag.utils import get_rag_context, rag_template
|
|
|
|
|
@@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
model = app.state.MODELS[model_id]
|
|
|
|
|
|
# Check if the model has any filters
|
|
|
- for filter_id in model["info"]["meta"].get("filterIds", []):
|
|
|
- filter = Functions.get_function_by_id(filter_id)
|
|
|
- if filter:
|
|
|
- if filter_id in webui_app.state.FUNCTIONS:
|
|
|
- function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
- else:
|
|
|
- function_module, function_type = load_function_module_by_id(
|
|
|
- filter_id
|
|
|
- )
|
|
|
- webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
+ for filter_id in model["info"]["meta"].get("filterIds", []):
|
|
|
+ filter = Functions.get_function_by_id(filter_id)
|
|
|
+ if filter:
|
|
|
+ if filter_id in webui_app.state.FUNCTIONS:
|
|
|
+ function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
+ else:
|
|
|
+ function_module, function_type = load_function_module_by_id(
|
|
|
+ filter_id
|
|
|
+ )
|
|
|
+ webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
|
|
|
- # Check if the function has a file_handler variable
|
|
|
- if getattr(function_module, "file_handler"):
|
|
|
- skip_files = True
|
|
|
+ # Check if the function has a file_handler variable
|
|
|
+ if getattr(function_module, "file_handler"):
|
|
|
+ skip_files = True
|
|
|
|
|
|
- try:
|
|
|
- if hasattr(function_module, "inlet"):
|
|
|
- data = function_module.inlet(
|
|
|
- data,
|
|
|
- {
|
|
|
- "id": user.id,
|
|
|
- "email": user.email,
|
|
|
- "name": user.name,
|
|
|
- "role": user.role,
|
|
|
- },
|
|
|
+ try:
|
|
|
+ if hasattr(function_module, "inlet"):
|
|
|
+ data = function_module.inlet(
|
|
|
+ data,
|
|
|
+ {
|
|
|
+ "id": user.id,
|
|
|
+ "email": user.email,
|
|
|
+ "name": user.name,
|
|
|
+ "role": user.role,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ content={"detail": str(e)},
|
|
|
)
|
|
|
- except Exception as e:
|
|
|
- print(f"Error: {e}")
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
- content={"detail": str(e)},
|
|
|
- )
|
|
|
|
|
|
# Set the task model
|
|
|
task_model_id = data["model"]
|
|
@@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
|
|
model = app.state.MODELS[model_id]
|
|
|
print(model)
|
|
|
|
|
|
-
|
|
|
+ pipe = model.get("pipe")
|
|
|
+ if pipe:
|
|
|
+
|
|
|
+ def job():
|
|
|
+ pipe_id = form_data["model"]
|
|
|
+ if "." in pipe_id:
|
|
|
+ pipe_id, sub_pipe_id = pipe_id.split(".", 1)
|
|
|
+ print(pipe_id)
|
|
|
+
|
|
|
+ pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
|
|
|
+ if form_data["stream"]:
|
|
|
+
|
|
|
+ def stream_content():
|
|
|
+ res = pipe(body=form_data)
|
|
|
+
|
|
|
+ if isinstance(res, str):
|
|
|
+ message = stream_message_template(form_data["model"], res)
|
|
|
+ yield f"data: {json.dumps(message)}\n\n"
|
|
|
+
|
|
|
+ if isinstance(res, Iterator):
|
|
|
+ for line in res:
|
|
|
+ if isinstance(line, BaseModel):
|
|
|
+ line = line.model_dump_json()
|
|
|
+ line = f"data: {line}"
|
|
|
+ try:
|
|
|
+ line = line.decode("utf-8")
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ if line.startswith("data:"):
|
|
|
+ yield f"{line}\n\n"
|
|
|
+ else:
|
|
|
+ line = stream_message_template(form_data["model"], line)
|
|
|
+ yield f"data: {json.dumps(line)}\n\n"
|
|
|
+
|
|
|
+ if isinstance(res, str) or isinstance(res, Generator):
|
|
|
+ finish_message = {
|
|
|
+ "id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
|
|
+ "object": "chat.completion.chunk",
|
|
|
+ "created": int(time.time()),
|
|
|
+ "model": form_data["model"],
|
|
|
+ "choices": [
|
|
|
+ {
|
|
|
+ "index": 0,
|
|
|
+ "delta": {},
|
|
|
+ "logprobs": None,
|
|
|
+ "finish_reason": "stop",
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ }
|
|
|
+
|
|
|
+ yield f"data: {json.dumps(finish_message)}\n\n"
|
|
|
+ yield f"data: [DONE]"
|
|
|
|
|
|
- if model.get('pipe') == True:
|
|
|
- print('hi')
|
|
|
-
|
|
|
-
|
|
|
-
|
|
|
+ return StreamingResponse(
|
|
|
+ stream_content(), media_type="text/event-stream"
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ res = pipe(body=form_data)
|
|
|
+
|
|
|
+ if isinstance(res, dict):
|
|
|
+ return res
|
|
|
+ elif isinstance(res, BaseModel):
|
|
|
+ return res.model_dump()
|
|
|
+ else:
|
|
|
+ message = ""
|
|
|
+ if isinstance(res, str):
|
|
|
+ message = res
|
|
|
+ if isinstance(res, Generator):
|
|
|
+ for stream in res:
|
|
|
+ message = f"{message}{stream}"
|
|
|
+
|
|
|
+ return {
|
|
|
+ "id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
|
|
+ "object": "chat.completion",
|
|
|
+ "created": int(time.time()),
|
|
|
+ "model": form_data["model"],
|
|
|
+ "choices": [
|
|
|
+ {
|
|
|
+ "index": 0,
|
|
|
+ "message": {
|
|
|
+ "role": "assistant",
|
|
|
+ "content": message,
|
|
|
+ },
|
|
|
+ "logprobs": None,
|
|
|
+ "finish_reason": "stop",
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ }
|
|
|
+
|
|
|
+ return await run_in_threadpool(job)
|
|
|
if model["owned_by"] == "ollama":
|
|
|
return await generate_ollama_chat_completion(form_data, user=user)
|
|
|
else:
|
|
@@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
|
pass
|
|
|
|
|
|
# Check if the model has any filters
|
|
|
- for filter_id in model["info"]["meta"].get("filterIds", []):
|
|
|
- filter = Functions.get_function_by_id(filter_id)
|
|
|
- if filter:
|
|
|
- if filter_id in webui_app.state.FUNCTIONS:
|
|
|
- function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
- else:
|
|
|
- function_module, function_type = load_function_module_by_id(filter_id)
|
|
|
- webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
+ for filter_id in model["info"]["meta"].get("filterIds", []):
|
|
|
+ filter = Functions.get_function_by_id(filter_id)
|
|
|
+ if filter:
|
|
|
+ if filter_id in webui_app.state.FUNCTIONS:
|
|
|
+ function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
+ else:
|
|
|
+ function_module, function_type = load_function_module_by_id(
|
|
|
+ filter_id
|
|
|
+ )
|
|
|
+ webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
|
|
|
- try:
|
|
|
- if hasattr(function_module, "outlet"):
|
|
|
- data = function_module.outlet(
|
|
|
- data,
|
|
|
- {
|
|
|
- "id": user.id,
|
|
|
- "email": user.email,
|
|
|
- "name": user.name,
|
|
|
- "role": user.role,
|
|
|
- },
|
|
|
+ try:
|
|
|
+ if hasattr(function_module, "outlet"):
|
|
|
+ data = function_module.outlet(
|
|
|
+ data,
|
|
|
+ {
|
|
|
+ "id": user.id,
|
|
|
+ "email": user.email,
|
|
|
+ "name": user.name,
|
|
|
+ "role": user.role,
|
|
|
+ },
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ content={"detail": str(e)},
|
|
|
)
|
|
|
- except Exception as e:
|
|
|
- print(f"Error: {e}")
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
- content={"detail": str(e)},
|
|
|
- )
|
|
|
|
|
|
return data
|
|
|
|