@@ -15,9 +15,11 @@ import requests
import mimetypes
import shutil
import os
+import uuid
import inspect
import asyncio

+from fastapi.concurrency import run_in_threadpool
from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
from fastapi.staticfiles import StaticFiles
from fastapi.responses import JSONResponse
@@ -46,16 +48,19 @@ from apps.openai.main import (
from apps.audio.main import app as audio_app
from apps.images.main import app as images_app
from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app
+from apps.webui.main import app as webui_app, get_pipe_models


from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union

from apps.webui.models.auths import Auths
from apps.webui.models.models import Models, ModelModel
from apps.webui.models.tools import Tools
+from apps.webui.models.functions import Functions
from apps.webui.models.users import Users
+
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
from apps.webui.utils import load_toolkit_module_by_id

from utils.misc import parse_duration
@@ -72,7 +77,11 @@ from utils.task import (
    search_query_generation_template,
    tools_function_calling_generation_template,
)
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)

from apps.rag.utils import get_rag_context, rag_template

@@ -85,6 +94,7 @@ from config import (
    VERSION,
    CHANGELOG,
    FRONTEND_BUILD_DIR,
+    UPLOAD_DIR,
    CACHE_DIR,
    STATIC_DIR,
    ENABLE_OPENAI_API,
@@ -184,7 +194,16 @@ app.state.MODELS = {}
origins = ["*"]


-async def get_function_call_response(messages, tool_id, template, task_model_id, user):
+##################################
+#
+# ChatCompletion Middleware
+#
+##################################
+
+
+async def get_function_call_response(
+    messages, files, tool_id, template, task_model_id, user
+):
    tool = Tools.get_tool_by_id(tool_id)
    tools_specs = json.dumps(tool.specs, indent=2)
    content = tools_function_calling_generation_template(template, tools_specs)
@@ -222,9 +241,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
    response = None
    try:
        if model["owned_by"] == "ollama":
-            response = await generate_ollama_chat_completion(
-                OpenAIChatCompletionForm(**payload), user=user
-            )
+            response = await generate_ollama_chat_completion(payload, user=user)
        else:
            response = await generate_openai_chat_completion(payload, user=user)

@@ -247,6 +264,7 @@ async def get_function_call_response(messages, tool_id, template, task_model_id,
            result = json.loads(content)
            print(result)

+            citation = None
            # Call the function
            if "name" in result:
                if tool_id in webui_app.state.TOOLS:
@@ -255,76 +273,170 @@
                    toolkit_module = load_toolkit_module_by_id(tool_id)
                    webui_app.state.TOOLS[tool_id] = toolkit_module

+                file_handler = False
+                # check if toolkit_module has file_handler self variable
+                if hasattr(toolkit_module, "file_handler"):
+                    file_handler = True
+                    print("file_handler: ", file_handler)
+
                function = getattr(toolkit_module, result["name"])
                function_result = None
                try:
                    # Get the signature of the function
                    sig = inspect.signature(function)
-                    # Check if '__user__' is a parameter of the function
+                    params = result["parameters"]
+
                    if "__user__" in sig.parameters:
                        # Call the function with the '__user__' parameter included
-                        function_result = function(
-                            **{
-                                **result["parameters"],
-                                "__user__": {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                },
-                            }
-                        )
+                        params = {
+                            **params,
+                            "__user__": {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        }
+
+                    if "__messages__" in sig.parameters:
+                        # Call the function with the '__messages__' parameter included
+                        params = {
+                            **params,
+                            "__messages__": messages,
+                        }
+
+                    if "__files__" in sig.parameters:
+                        # Call the function with the '__files__' parameter included
+                        params = {
+                            **params,
+                            "__files__": files,
+                        }
+
+                    if "__model__" in sig.parameters:
+                        # Call the function with the '__model__' parameter included
+                        params = {
+                            **params,
+                            "__model__": model,
+                        }
+
+                    if "__id__" in sig.parameters:
+                        # Call the function with the '__id__' parameter included
+                        params = {
+                            **params,
+                            "__id__": tool_id,
+                        }
+
+                    if inspect.iscoroutinefunction(function):
+                        function_result = await function(**params)
                    else:
-                        # Call the function without modifying the parameters
-                        function_result = function(**result["parameters"])
+                        function_result = function(**params)
+
+                    if hasattr(toolkit_module, "citation") and toolkit_module.citation:
+                        citation = {
+                            "source": {"name": f"TOOL:{tool.name}/{result['name']}"},
+                            "document": [function_result],
+                            "metadata": [{"source": result["name"]}],
+                        }
                except Exception as e:
                    print(e)

                # Add the function result to the system prompt
-                if function_result:
-                    return function_result
+                if function_result is not None:
+                    return function_result, citation, file_handler
    except Exception as e:
        print(f"Error: {e}")

-    return None
+    return None, None, False


class ChatCompletionMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
-        return_citations = False
+        data_items = []

-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
+        show_citations = False
+        citations = []
+
+        if request.method == "POST" and any(
+            endpoint in request.url.path
+            for endpoint in ["/ollama/api/chat", "/chat/completions"]
        ):
            log.debug(f"request.url.path: {request.url.path}")

            # Read the original request body
            body = await request.body()
-            # Decode body to string
            body_str = body.decode("utf-8")
-            # Parse string to JSON
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
-                get_http_authorization_cred(request.headers.get("Authorization"))
+                request,
+                get_http_authorization_cred(request.headers.get("Authorization")),
            )
-
-            # Remove the citations from the body
-            return_citations = data.get("citations", False)
-            if "citations" in data:
+            # Flag to skip RAG completions if file_handler is present in tools/functions
+            skip_files = False
+            if data.get("citations"):
+                show_citations = True
                del data["citations"]

-            # Set the task model
-            task_model_id = data["model"]
-            if task_model_id not in app.state.MODELS:
+            model_id = data["model"]
+            if model_id not in app.state.MODELS:
                raise HTTPException(
                    status_code=status.HTTP_404_NOT_FOUND,
                    detail="Model not found",
                )
+            model = app.state.MODELS[model_id]
+
+            # Check if the model has any filters
+            if "info" in model and "meta" in model["info"]:
+                for filter_id in model["info"]["meta"].get("filterIds", []):
+                    filter = Functions.get_function_by_id(filter_id)
+                    if filter:
+                        if filter_id in webui_app.state.FUNCTIONS:
+                            function_module = webui_app.state.FUNCTIONS[filter_id]
+                        else:
+                            function_module, function_type = load_function_module_by_id(
+                                filter_id
+                            )
+                            webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                        # Check if the function has a file_handler variable
+                        if hasattr(function_module, "file_handler"):
+                            skip_files = function_module.file_handler
+
+                        try:
+                            if hasattr(function_module, "inlet"):
+                                inlet = function_module.inlet
+
+                                if inspect.iscoroutinefunction(inlet):
+                                    data = await inlet(
+                                        data,
+                                        {
+                                            "id": user.id,
+                                            "email": user.email,
+                                            "name": user.name,
+                                            "role": user.role,
+                                        },
+                                    )
+                                else:
+                                    data = inlet(
+                                        data,
+                                        {
+                                            "id": user.id,
+                                            "email": user.email,
+                                            "name": user.name,
+                                            "role": user.role,
+                                        },
+                                    )
+
+                        except Exception as e:
+                            print(f"Error: {e}")
+                            return JSONResponse(
+                                status_code=status.HTTP_400_BAD_REQUEST,
+                                content={"detail": str(e)},
+                            )

-            # Check if the user has a custom task model
-            # If the user has a custom task model, use that model
+            # Set the task model
+            task_model_id = data["model"]
+            # Check if the user has a custom task model and use that model
            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
                if (
                    app.state.config.TASK_MODEL
@@ -347,55 +459,71 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                for tool_id in data["tool_ids"]:
                    print(tool_id)
                    try:
-                        response = await get_function_call_response(
-                            messages=data["messages"],
-                            tool_id=tool_id,
-                            template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                            task_model_id=task_model_id,
-                            user=user,
+                        response, citation, file_handler = (
+                            await get_function_call_response(
+                                messages=data["messages"],
+                                files=data.get("files", []),
+                                tool_id=tool_id,
+                                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                                task_model_id=task_model_id,
+                                user=user,
+                            )
                        )

-                        if response:
+                        print(file_handler)
+                        if isinstance(response, str):
                            context += ("\n" if context != "" else "") + response
+
+                        if citation:
+                            citations.append(citation)
+                            show_citations = True
+
+                        if file_handler:
+                            skip_files = True
+
                    except Exception as e:
                        print(f"Error: {e}")
                del data["tool_ids"]

            print(f"tool_context: {context}")

-            # If docs field is present, generate RAG completions
-            if "docs" in data:
-                data = {**data}
-                rag_context, citations = get_rag_context(
-                    docs=data["docs"],
-                    messages=data["messages"],
-                    embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-                    k=rag_app.state.config.TOP_K,
-                    reranking_function=rag_app.state.sentence_transformer_rf,
-                    r=rag_app.state.config.RELEVANCE_THRESHOLD,
-                    hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                )
+            # If files field is present, generate RAG completions
+            # If skip_files is True, skip the RAG completions
+            if "files" in data:
+                if not skip_files:
+                    data = {**data}
+                    rag_context, rag_citations = get_rag_context(
+                        files=data["files"],
+                        messages=data["messages"],
+                        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+                        k=rag_app.state.config.TOP_K,
+                        reranking_function=rag_app.state.sentence_transformer_rf,
+                        r=rag_app.state.config.RELEVANCE_THRESHOLD,
+                        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+                    )
+                    if rag_context:
+                        context += ("\n" if context != "" else "") + rag_context
+
+                    log.debug(f"rag_context: {rag_context}, citations: {citations}")

-                if rag_context:
-                    context += ("\n" if context != "" else "") + rag_context
+                    if rag_citations:
+                        citations.extend(rag_citations)

-                del data["docs"]
+                del data["files"]

-                log.debug(f"rag_context: {rag_context}, citations: {citations}")
+            if show_citations and len(citations) > 0:
+                data_items.append({"citations": citations})

            if context != "":
                system_prompt = rag_template(
                    rag_app.state.config.RAG_TEMPLATE, context, prompt
                )
-
                print(system_prompt)
-
                data["messages"] = add_or_update_system_message(
-                    f"\n{system_prompt}", data["messages"]
+                    system_prompt, data["messages"]
                )

            modified_body_bytes = json.dumps(data).encode("utf-8")
-
            # Replace the request body with the modified one
            request._body = modified_body_bytes
            # Set custom header to ensure content-length matches new body length
@@ -408,43 +536,54 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                ],
            ]

-        response = await call_next(request)
-
-        if return_citations:
-            # Inject the citations into the response
+            response = await call_next(request)
            if isinstance(response, StreamingResponse):
                # If it's a streaming response, inject it as SSE event or NDJSON line
                content_type = response.headers.get("Content-Type")
                if "text/event-stream" in content_type:
                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        self.openai_stream_wrapper(response.body_iterator, data_items),
                    )
                if "application/x-ndjson" in content_type:
                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        self.ollama_stream_wrapper(response.body_iterator, data_items),
                    )
+            else:
+                return response

+        # If it's not a chat completion request, just pass it through
+        response = await call_next(request)
        return response

    async def _receive(self, body: bytes):
        return {"type": "http.request", "body": body, "more_body": False}

-    async def openai_stream_wrapper(self, original_generator, citations):
-        yield f"data: {json.dumps({'citations': citations})}\n\n"
+    async def openai_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"data: {json.dumps(item)}\n\n"
+
        async for data in original_generator:
            yield data

-    async def ollama_stream_wrapper(self, original_generator, citations):
-        yield f"{json.dumps({'citations': citations})}\n"
+    async def ollama_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"{json.dumps(item)}\n"
+
        async for data in original_generator:
            yield data


app.add_middleware(ChatCompletionMiddleware)

+##################################
+#
+# Pipeline Middleware
+#
+##################################
+

def filter_pipeline(payload, user):
-    user = {"id": user.id, "name": user.name, "role": user.role}
+    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
    model_id = payload["model"]
    filters = [
        model
@@ -532,7 +671,8 @@ class PipelineMiddleware(BaseHTTPMiddleware):
            data = json.loads(body_str) if body_str else {}

            user = get_current_user(
-                get_http_authorization_cred(request.headers.get("Authorization"))
+                request,
+                get_http_authorization_cred(request.headers.get("Authorization")),
            )

            try:
@@ -600,7 +740,6 @@ async def update_embedding_function(request: Request, call_next):

app.mount("/ws", socket_app)

-
app.mount("/ollama", ollama_app)
app.mount("/openai", openai_app)

@@ -614,17 +753,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION


async def get_all_models():
+    pipe_models = []
    openai_models = []
    ollama_models = []

+    pipe_models = await get_pipe_models()
+
    if app.state.config.ENABLE_OPENAI_API:
        openai_models = await get_openai_models()
-
        openai_models = openai_models["data"]

    if app.state.config.ENABLE_OLLAMA_API:
        ollama_models = await get_ollama_models()
-
        ollama_models = [
            {
                "id": model["model"],
@@ -637,9 +777,9 @@ async def get_all_models():
            for model in ollama_models["models"]
        ]

-    models = openai_models + ollama_models
-    custom_models = Models.get_all_models()
+    models = pipe_models + openai_models + ollama_models

+    custom_models = Models.get_all_models()
    for custom_model in custom_models:
        if custom_model.base_model_id == None:
            for model in models:
@@ -702,6 +842,253 @@ async def get_models(user=Depends(get_verified_user)):
    return {"data": models}


+@app.post("/api/chat/completions")
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    model = app.state.MODELS[model_id]
+    print(model)
+
+    pipe = model.get("pipe")
+    if pipe:
+        form_data["user"] = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        async def job():
+            pipe_id = form_data["model"]
+            if "." in pipe_id:
+                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+            print(pipe_id)
+
+            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            if form_data["stream"]:
+
+                async def stream_content():
+                    if inspect.iscoroutinefunction(pipe):
+                        res = await pipe(body=form_data)
+                    else:
+                        res = pipe(body=form_data)
+
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data["model"], res)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    if isinstance(res, Iterator):
+                        for line in res:
+                            if isinstance(line, BaseModel):
+                                line = line.model_dump_json()
+                                line = f"data: {line}"
+                            try:
+                                line = line.decode("utf-8")
+                            except:
+                                pass
+
+                            if line.startswith("data:"):
+                                yield f"{line}\n\n"
+                            else:
+                                line = stream_message_template(form_data["model"], line)
+                                yield f"data: {json.dumps(line)}\n\n"
+
+                    if isinstance(res, str) or isinstance(res, Generator):
+                        finish_message = {
+                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": form_data["model"],
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {},
+                                    "logprobs": None,
+                                    "finish_reason": "stop",
+                                }
+                            ],
+                        }
+
+                        yield f"data: {json.dumps(finish_message)}\n\n"
+                        yield f"data: [DONE]"
+
+                return StreamingResponse(
+                    stream_content(), media_type="text/event-stream"
+                )
+            else:
+                if inspect.iscoroutinefunction(pipe):
+                    res = await pipe(body=form_data)
+                else:
+                    res = pipe(body=form_data)
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            message = f"{message}{stream}"
+
+                    return {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await job()
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(form_data, user=user)
+    else:
+        return await generate_openai_chat_completion(form_data, user=user)
+
+
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    data = form_data
+    model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    # Check if the model has any filters
+    if "info" in model and "meta" in model["info"]:
+        for filter_id in model["info"]["meta"].get("filterIds", []):
+            filter = Functions.get_function_by_id(filter_id)
+            if filter:
+                if filter_id in webui_app.state.FUNCTIONS:
+                    function_module = webui_app.state.FUNCTIONS[filter_id]
+                else:
+                    function_module, function_type = load_function_module_by_id(
+                        filter_id
+                    )
+                    webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                try:
+                    if hasattr(function_module, "outlet"):
+                        outlet = function_module.outlet
+                        if inspect.iscoroutinefunction(outlet):
+                            data = await outlet(
+                                data,
+                                {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            )
+                        else:
+                            data = outlet(
+                                data,
+                                {
+                                    "id": user.id,
+                                    "email": user.email,
+                                    "name": user.name,
+                                    "role": user.role,
+                                },
+                            )
+
+                except Exception as e:
+                    print(f"Error: {e}")
+                    return JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content={"detail": str(e)},
+                    )
+
+    return data
+
+
+##################################
+#
+# Task Endpoints
+#
+##################################
+
+
+# TODO: Refactor task API endpoints below into a separate file
+
+
@app.get("/api/task/config")
async def get_task_config(user=Depends(get_verified_user)):
    return {
@@ -780,7 +1167,12 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
        template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE

    content = title_generation_template(
-        template, form_data["prompt"], user.model_dump()
+        template,
+        form_data["prompt"],
+        {
+            "name": user.name,
+            "location": user.info.get("location") if user.info else None,
+        },
    )

    payload = {
@@ -792,7 +1184,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
        "title": True,
    }

-    print(payload)
+    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
@@ -803,9 +1195,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
        )

    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**payload), user=user
-        )
+        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)

@@ -846,7 +1236,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
    template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE

    content = search_query_generation_template(
-        template, form_data["prompt"], user.model_dump()
+        template, form_data["prompt"], {"name": user.name}
    )

    payload = {
@@ -868,9 +1258,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
        )

    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**payload), user=user
-        )
+        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)

@@ -909,7 +1297,12 @@ Message: """{{prompt}}"""
'''

    content = title_generation_template(
-        template, form_data["prompt"], user.model_dump()
+        template,
+        form_data["prompt"],
+        {
+            "name": user.name,
+            "location": user.info.get("location") if user.info else None,
+        },
    )

    payload = {
@@ -921,7 +1314,7 @@ Message: """{{prompt}}"""
        "task": True,
    }

-    print(payload)
+    log.debug(payload)

    try:
        payload = filter_pipeline(payload, user)
@@ -932,9 +1325,7 @@ Message: """{{prompt}}"""
        )

    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**payload), user=user
-        )
+        return await generate_ollama_chat_completion(payload, user=user)
    else:
        return await generate_openai_chat_completion(payload, user=user)

@@ -967,8 +1358,13 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE

    try:
-        context = await get_function_call_response(
-            form_data["messages"], form_data["tool_id"], template, model_id, user
+        context, citation, file_handler = await get_function_call_response(
+            form_data["messages"],
+            form_data.get("files", []),
+            form_data["tool_id"],
+            template,
+            model_id,
+            user,
        )
        return context
    except Exception as e:
@@ -978,94 +1374,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
        )


-@app.post("/api/chat/completions")
-async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
-
-    model = app.state.MODELS[model_id]
-    print(model)
+##################################
+#
+# Pipelines Endpoints
+#
+##################################

-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(
-            OpenAIChatCompletionForm(**form_data), user=user
-        )
-    else:
-        return await generate_openai_chat_completion(form_data, user=user)

-
-@app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
-    data = form_data
-    model_id = data["model"]
-
-    filters = [
-        model
-        for model in app.state.MODELS.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
-    print(model_id)
-
-    if model_id in app.state.MODELS:
-        model = app.state.MODELS[model_id]
-        if "pipeline" in model:
-            sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {"id": user.id, "name": user.name, "role": user.role},
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except:
-                    pass
-
-            else:
-                pass
-
-    return data
+# TODO: Refactor pipelines API endpoints below into a separate file


@app.get("/api/pipelines/list")
@@ -1388,6 +1704,13 @@ async def update_pipeline_valves(
        )


+##################################
+#
+# Config Endpoints
+#
+##################################
+
+
@app.get("/api/config")
async def get_app_config():
    # Checking and Handling the Absence of 'ui' in CONFIG_DATA
@@ -1457,6 +1780,9 @@ async def update_model_filter_config(
    }


+# TODO: webhook endpoint should be under config endpoints
+
+
@app.get("/api/webhook")
async def get_webhook_url(user=Depends(get_admin_user)):
    return {