|
@@ -13,8 +13,12 @@ import logging
|
|
|
import aiohttp
|
|
|
import requests
|
|
|
import mimetypes
|
|
|
+import shutil
|
|
|
+import os
|
|
|
+import inspect
|
|
|
+import asyncio
|
|
|
|
|
|
-from fastapi import FastAPI, Request, Depends, status
|
|
|
+from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
|
|
|
from fastapi.staticfiles import StaticFiles
|
|
|
from fastapi.responses import JSONResponse
|
|
|
from fastapi import HTTPException
|
|
@@ -27,21 +31,33 @@ from starlette.responses import StreamingResponse, Response, RedirectResponse
|
|
|
|
|
|
|
|
|
from apps.socket.main import app as socket_app
|
|
|
-from apps.ollama.main import app as ollama_app, get_all_models as get_ollama_models
|
|
|
-from apps.openai.main import app as openai_app, get_all_models as get_openai_models
|
|
|
+from apps.ollama.main import (
|
|
|
+ app as ollama_app,
|
|
|
+ OpenAIChatCompletionForm,
|
|
|
+ get_all_models as get_ollama_models,
|
|
|
+ generate_openai_chat_completion as generate_ollama_chat_completion,
|
|
|
+)
|
|
|
+from apps.openai.main import (
|
|
|
+ app as openai_app,
|
|
|
+ get_all_models as get_openai_models,
|
|
|
+ generate_chat_completion as generate_openai_chat_completion,
|
|
|
+)
|
|
|
|
|
|
from apps.audio.main import app as audio_app
|
|
|
from apps.images.main import app as images_app
|
|
|
from apps.rag.main import app as rag_app
|
|
|
from apps.webui.main import app as webui_app
|
|
|
|
|
|
-import asyncio
|
|
|
+
|
|
|
from pydantic import BaseModel
|
|
|
from typing import List, Optional
|
|
|
|
|
|
from apps.webui.models.auths import Auths
|
|
|
-from apps.webui.models.models import Models
|
|
|
+from apps.webui.models.models import Models, ModelModel
|
|
|
+from apps.webui.models.tools import Tools
|
|
|
from apps.webui.models.users import Users
|
|
|
+from apps.webui.utils import load_toolkit_module_by_id
|
|
|
+
|
|
|
from utils.misc import parse_duration
|
|
|
from utils.utils import (
|
|
|
get_admin_user,
|
|
@@ -51,7 +67,14 @@ from utils.utils import (
|
|
|
get_password_hash,
|
|
|
create_token,
|
|
|
)
|
|
|
-from apps.rag.utils import rag_messages
|
|
|
+from utils.task import (
|
|
|
+ title_generation_template,
|
|
|
+ search_query_generation_template,
|
|
|
+ tools_function_calling_generation_template,
|
|
|
+)
|
|
|
+from utils.misc import get_last_user_message, add_or_update_system_message
|
|
|
+
|
|
|
+from apps.rag.utils import get_rag_context, rag_template
|
|
|
|
|
|
from config import (
|
|
|
CONFIG_DATA,
|
|
@@ -72,14 +95,20 @@ from config import (
|
|
|
SRC_LOG_LEVELS,
|
|
|
WEBHOOK_URL,
|
|
|
ENABLE_ADMIN_EXPORT,
|
|
|
- AppConfig,
|
|
|
WEBUI_BUILD_HASH,
|
|
|
+ TASK_MODEL,
|
|
|
+ TASK_MODEL_EXTERNAL,
|
|
|
+ TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
+ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
OAUTH_PROVIDERS,
|
|
|
ENABLE_OAUTH_SIGNUP,
|
|
|
OAUTH_MERGE_ACCOUNTS_BY_EMAIL,
|
|
|
WEBUI_SECRET_KEY,
|
|
|
WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
|
WEBUI_SESSION_COOKIE_SECURE,
|
|
|
+ AppConfig,
|
|
|
)
|
|
|
from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
|
|
from utils.webhook import post_webhook
|
|
@@ -134,27 +163,133 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|
|
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
|
|
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|
|
|
|
|
-
|
|
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
|
|
|
|
|
|
|
|
+app.state.config.TASK_MODEL = TASK_MODEL
|
|
|
+app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
|
|
+app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
|
|
+app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+)
|
|
|
+app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
|
|
+ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
|
|
+)
|
|
|
+app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
+ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
+)
|
|
|
+
|
|
|
app.state.MODELS = {}
|
|
|
|
|
|
origins = ["*"]
|
|
|
|
|
|
-# Custom middleware to add security headers
|
|
|
-# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
|
|
-# async def dispatch(self, request: Request, call_next):
|
|
|
-# response: Response = await call_next(request)
|
|
|
-# response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
|
|
|
-# response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
|
|
|
-# return response
|
|
|
|
|
|
+async def get_function_call_response(messages, tool_id, template, task_model_id, user):
|
|
|
+ tool = Tools.get_tool_by_id(tool_id)
|
|
|
+ tools_specs = json.dumps(tool.specs, indent=2)
|
|
|
+ content = tools_function_calling_generation_template(template, tools_specs)
|
|
|
+
|
|
|
+ user_message = get_last_user_message(messages)
|
|
|
+ prompt = (
|
|
|
+ "History:\n"
|
|
|
+ + "\n".join(
|
|
|
+ [
|
|
|
+ f"{message['role'].upper()}: \"\"\"{message['content']}\"\"\""
|
|
|
+ for message in messages[::-1][:4]
|
|
|
+ ]
|
|
|
+ )
|
|
|
+ + f"\nQuery: {user_message}"
|
|
|
+ )
|
|
|
+
|
|
|
+ print(prompt)
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": task_model_id,
|
|
|
+ "messages": [
|
|
|
+ {"role": "system", "content": content},
|
|
|
+ {"role": "user", "content": f"Query: {prompt}"},
|
|
|
+ ],
|
|
|
+ "stream": False,
|
|
|
+ }
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
+ except Exception as e:
|
|
|
+ raise e
|
|
|
+
|
|
|
+ model = app.state.MODELS[task_model_id]
|
|
|
|
|
|
-# app.add_middleware(SecurityHeadersMiddleware)
|
|
|
+ response = None
|
|
|
+ try:
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ response = await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**payload), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ response = await generate_openai_chat_completion(payload, user=user)
|
|
|
|
|
|
+ content = None
|
|
|
|
|
|
-class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
+ if hasattr(response, "body_iterator"):
|
|
|
+ async for chunk in response.body_iterator:
|
|
|
+ data = json.loads(chunk.decode("utf-8"))
|
|
|
+ content = data["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+ # Cleanup any remaining background tasks if necessary
|
|
|
+ if response.background is not None:
|
|
|
+ await response.background()
|
|
|
+ else:
|
|
|
+ content = response["choices"][0]["message"]["content"]
|
|
|
+
|
|
|
+ # Parse the function response
|
|
|
+ if content is not None:
|
|
|
+ print(f"content: {content}")
|
|
|
+ result = json.loads(content)
|
|
|
+ print(result)
|
|
|
+
|
|
|
+ # Call the function
|
|
|
+ if "name" in result:
|
|
|
+ if tool_id in webui_app.state.TOOLS:
|
|
|
+ toolkit_module = webui_app.state.TOOLS[tool_id]
|
|
|
+ else:
|
|
|
+ toolkit_module = load_toolkit_module_by_id(tool_id)
|
|
|
+ webui_app.state.TOOLS[tool_id] = toolkit_module
|
|
|
+
|
|
|
+ function = getattr(toolkit_module, result["name"])
|
|
|
+ function_result = None
|
|
|
+ try:
|
|
|
+ # Get the signature of the function
|
|
|
+ sig = inspect.signature(function)
|
|
|
+ # Check if '__user__' is a parameter of the function
|
|
|
+ if "__user__" in sig.parameters:
|
|
|
+ # Call the function with the '__user__' parameter included
|
|
|
+ function_result = function(
|
|
|
+ **{
|
|
|
+ **result["parameters"],
|
|
|
+ "__user__": {
|
|
|
+ "id": user.id,
|
|
|
+ "email": user.email,
|
|
|
+ "name": user.name,
|
|
|
+ "role": user.role,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # Call the function without modifying the parameters
|
|
|
+ function_result = function(**result["parameters"])
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+
|
|
|
+ # Add the function result to the system prompt
|
|
|
+ if function_result:
|
|
|
+ return function_result
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
|
return_citations = False
|
|
|
|
|
@@ -171,35 +306,98 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
# Parse string to JSON
|
|
|
data = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
+ user = get_current_user(
|
|
|
+ get_http_authorization_cred(request.headers.get("Authorization"))
|
|
|
+ )
|
|
|
+
|
|
|
+ # Remove the citations from the body
|
|
|
return_citations = data.get("citations", False)
|
|
|
if "citations" in data:
|
|
|
del data["citations"]
|
|
|
|
|
|
- # Example: Add a new key-value pair or modify existing ones
|
|
|
- # data["modified"] = True # Example modification
|
|
|
+ # Set the task model
|
|
|
+ task_model_id = data["model"]
|
|
|
+ if task_model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
+ if (
|
|
|
+ app.state.config.TASK_MODEL
|
|
|
+ and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
+ ):
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ else:
|
|
|
+ if (
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
+ ):
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+
|
|
|
+ prompt = get_last_user_message(data["messages"])
|
|
|
+ context = ""
|
|
|
+
|
|
|
+ # If tool_ids field is present, call the functions
|
|
|
+ if "tool_ids" in data:
|
|
|
+ print(data["tool_ids"])
|
|
|
+ for tool_id in data["tool_ids"]:
|
|
|
+ print(tool_id)
|
|
|
+ try:
|
|
|
+ response = await get_function_call_response(
|
|
|
+ messages=data["messages"],
|
|
|
+ tool_id=tool_id,
|
|
|
+ template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
+ task_model_id=task_model_id,
|
|
|
+ user=user,
|
|
|
+ )
|
|
|
+
|
|
|
+ if response:
|
|
|
+ context += ("\n" if context != "" else "") + response
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ del data["tool_ids"]
|
|
|
+
|
|
|
+ print(f"tool_context: {context}")
|
|
|
+
|
|
|
+ # If docs field is present, generate RAG completions
|
|
|
if "docs" in data:
|
|
|
data = {**data}
|
|
|
- data["messages"], citations = rag_messages(
|
|
|
+ rag_context, citations = get_rag_context(
|
|
|
docs=data["docs"],
|
|
|
messages=data["messages"],
|
|
|
- template=rag_app.state.config.RAG_TEMPLATE,
|
|
|
embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
|
k=rag_app.state.config.TOP_K,
|
|
|
reranking_function=rag_app.state.sentence_transformer_rf,
|
|
|
r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
|
hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
|
)
|
|
|
+
|
|
|
+ if rag_context:
|
|
|
+ context += ("\n" if context != "" else "") + rag_context
|
|
|
+
|
|
|
del data["docs"]
|
|
|
|
|
|
- log.debug(
|
|
|
- f"data['messages']: {data['messages']}, citations: {citations}"
|
|
|
+ log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
|
+
|
|
|
+ if context != "":
|
|
|
+ system_prompt = rag_template(
|
|
|
+ rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
|
+ )
|
|
|
+
|
|
|
+ print(system_prompt)
|
|
|
+
|
|
|
+ data["messages"] = add_or_update_system_message(
|
|
|
+ f"\n{system_prompt}", data["messages"]
|
|
|
)
|
|
|
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
|
|
|
# Replace the request body with the modified one
|
|
|
request._body = modified_body_bytes
|
|
|
-
|
|
|
# Set custom header to ensure content-length matches new body length
|
|
|
request.headers.__dict__["_list"] = [
|
|
|
(b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
|
|
@@ -242,7 +440,80 @@ class RAGMiddleware(BaseHTTPMiddleware):
|
|
|
yield data
|
|
|
|
|
|
|
|
|
-app.add_middleware(RAGMiddleware)
|
|
|
+app.add_middleware(ChatCompletionMiddleware)
|
|
|
+
|
|
|
+
|
|
|
+def filter_pipeline(payload, user):
|
|
|
+ user = {"id": user.id, "name": user.name, "role": user.role}
|
|
|
+ model_id = payload["model"]
|
|
|
+ filters = [
|
|
|
+ model
|
|
|
+ for model in app.state.MODELS.values()
|
|
|
+ if "pipeline" in model
|
|
|
+ and "type" in model["pipeline"]
|
|
|
+ and model["pipeline"]["type"] == "filter"
|
|
|
+ and (
|
|
|
+ model["pipeline"]["pipelines"] == ["*"]
|
|
|
+ or any(
|
|
|
+ model_id == target_model_id
|
|
|
+ for target_model_id in model["pipeline"]["pipelines"]
|
|
|
+ )
|
|
|
+ )
|
|
|
+ ]
|
|
|
+ sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
|
+
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
+
|
|
|
+ if "pipeline" in model:
|
|
|
+ sorted_filters.append(model)
|
|
|
+
|
|
|
+ for filter in sorted_filters:
|
|
|
+ r = None
|
|
|
+ try:
|
|
|
+ urlIdx = filter["urlIdx"]
|
|
|
+
|
|
|
+ url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
|
+ key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
|
+
|
|
|
+ if key != "":
|
|
|
+ headers = {"Authorization": f"Bearer {key}"}
|
|
|
+ r = requests.post(
|
|
|
+ f"{url}/{filter['id']}/filter/inlet",
|
|
|
+ headers=headers,
|
|
|
+ json={
|
|
|
+ "user": user,
|
|
|
+ "body": payload,
|
|
|
+ },
|
|
|
+ )
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ payload = r.json()
|
|
|
+ except Exception as e:
|
|
|
+ # Handle connection error here
|
|
|
+ print(f"Connection error: {e}")
|
|
|
+
|
|
|
+ if r is not None:
|
|
|
+ try:
|
|
|
+ res = r.json()
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+ if "detail" in res:
|
|
|
+ raise Exception(r.status_code, res["detail"])
|
|
|
+
|
|
|
+ else:
|
|
|
+ pass
|
|
|
+
|
|
|
+ if "pipeline" not in app.state.MODELS[model_id]:
|
|
|
+ if "chat_id" in payload:
|
|
|
+ del payload["chat_id"]
|
|
|
+
|
|
|
+ if "title" in payload:
|
|
|
+ del payload["title"]
|
|
|
+
|
|
|
+ if "task" in payload:
|
|
|
+ del payload["task"]
|
|
|
+
|
|
|
+ return payload
|
|
|
|
|
|
|
|
|
class PipelineMiddleware(BaseHTTPMiddleware):
|
|
@@ -260,85 +531,17 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
|
# Parse string to JSON
|
|
|
data = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
- model_id = data["model"]
|
|
|
- filters = [
|
|
|
- model
|
|
|
- for model in app.state.MODELS.values()
|
|
|
- if "pipeline" in model
|
|
|
- and "type" in model["pipeline"]
|
|
|
- and model["pipeline"]["type"] == "filter"
|
|
|
- and (
|
|
|
- model["pipeline"]["pipelines"] == ["*"]
|
|
|
- or any(
|
|
|
- model_id == target_model_id
|
|
|
- for target_model_id in model["pipeline"]["pipelines"]
|
|
|
- )
|
|
|
- )
|
|
|
- ]
|
|
|
- sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
|
-
|
|
|
- user = None
|
|
|
- if len(sorted_filters) > 0:
|
|
|
- try:
|
|
|
- user = get_current_user(
|
|
|
- get_http_authorization_cred(
|
|
|
- request.headers.get("Authorization")
|
|
|
- )
|
|
|
- )
|
|
|
- user = {"id": user.id, "name": user.name, "role": user.role}
|
|
|
- except:
|
|
|
- pass
|
|
|
-
|
|
|
- model = app.state.MODELS[model_id]
|
|
|
-
|
|
|
- if "pipeline" in model:
|
|
|
- sorted_filters.append(model)
|
|
|
-
|
|
|
- for filter in sorted_filters:
|
|
|
- r = None
|
|
|
- try:
|
|
|
- urlIdx = filter["urlIdx"]
|
|
|
-
|
|
|
- url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
|
- key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
|
-
|
|
|
- if key != "":
|
|
|
- headers = {"Authorization": f"Bearer {key}"}
|
|
|
- r = requests.post(
|
|
|
- f"{url}/{filter['id']}/filter/inlet",
|
|
|
- headers=headers,
|
|
|
- json={
|
|
|
- "user": user,
|
|
|
- "body": data,
|
|
|
- },
|
|
|
- )
|
|
|
-
|
|
|
- r.raise_for_status()
|
|
|
- data = r.json()
|
|
|
- except Exception as e:
|
|
|
- # Handle connection error here
|
|
|
- print(f"Connection error: {e}")
|
|
|
-
|
|
|
- if r is not None:
|
|
|
- try:
|
|
|
- res = r.json()
|
|
|
- if "detail" in res:
|
|
|
- return JSONResponse(
|
|
|
- status_code=r.status_code,
|
|
|
- content=res,
|
|
|
- )
|
|
|
- except:
|
|
|
- pass
|
|
|
-
|
|
|
- else:
|
|
|
- pass
|
|
|
-
|
|
|
- if "pipeline" not in app.state.MODELS[model_id]:
|
|
|
- if "chat_id" in data:
|
|
|
- del data["chat_id"]
|
|
|
+ user = get_current_user(
|
|
|
+ get_http_authorization_cred(request.headers.get("Authorization"))
|
|
|
+ )
|
|
|
|
|
|
- if "title" in data:
|
|
|
- del data["title"]
|
|
|
+ try:
|
|
|
+ data = filter_pipeline(data, user)
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=e.args[0],
|
|
|
+ content={"detail": e.args[1]},
|
|
|
+ )
|
|
|
|
|
|
modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
# Replace the request body with the modified one
|
|
@@ -499,6 +702,302 @@ async def get_models(user=Depends(get_verified_user)):
|
|
|
return {"data": models}
|
|
|
|
|
|
|
|
|
+@app.get("/api/task/config")
|
|
|
+async def get_task_config(user=Depends(get_verified_user)):
|
|
|
+ return {
|
|
|
+ "TASK_MODEL": app.state.config.TASK_MODEL,
|
|
|
+ "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
+ "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
+ "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class TaskConfigForm(BaseModel):
|
|
|
+ TASK_MODEL: Optional[str]
|
|
|
+ TASK_MODEL_EXTERNAL: Optional[str]
|
|
|
+ TITLE_GENERATION_PROMPT_TEMPLATE: str
|
|
|
+ SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
|
|
+ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
|
|
|
+ TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/task/config/update")
|
|
|
+async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
|
|
|
+ app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
|
|
+ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
+ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
+ app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
|
|
|
+ form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
|
|
|
+ )
|
|
|
+ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
|
+ form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
+
|
|
|
+ return {
|
|
|
+ "TASK_MODEL": app.state.config.TASK_MODEL,
|
|
|
+ "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
+ "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
|
|
|
+ "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/task/title/completions")
|
|
|
+async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("generate_title")
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
+
|
|
|
+ template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ content = title_generation_template(
|
|
|
+ template, form_data["prompt"], user.model_dump()
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": model_id,
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
+ "stream": False,
|
|
|
+ "max_tokens": 50,
|
|
|
+ "chat_id": form_data.get("chat_id", None),
|
|
|
+ "title": True,
|
|
|
+ }
|
|
|
+
|
|
|
+ print(payload)
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=e.args[0],
|
|
|
+ content={"detail": e.args[1]},
|
|
|
+ )
|
|
|
+
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ return await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**payload), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return await generate_openai_chat_completion(payload, user=user)
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/task/query/completions")
|
|
|
+async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("generate_search_query")
|
|
|
+
|
|
|
+ if len(form_data["prompt"]) < app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ detail=f"Skip search query generation for short prompts (< {app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD} characters)",
|
|
|
+ )
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
+
|
|
|
+ template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ content = search_query_generation_template(
|
|
|
+ template, form_data["prompt"], user.model_dump()
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": model_id,
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
+ "stream": False,
|
|
|
+ "max_tokens": 30,
|
|
|
+ "task": True,
|
|
|
+ }
|
|
|
+
|
|
|
+ print(payload)
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=e.args[0],
|
|
|
+ content={"detail": e.args[1]},
|
|
|
+ )
|
|
|
+
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ return await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**payload), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return await generate_openai_chat_completion(payload, user=user)
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/task/emoji/completions")
|
|
|
+async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("generate_emoji")
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
+
|
|
|
+ template = '''
|
|
|
+Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
|
|
+
|
|
|
+Message: """{{prompt}}"""
|
|
|
+'''
|
|
|
+
|
|
|
+ content = title_generation_template(
|
|
|
+ template, form_data["prompt"], user.model_dump()
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": model_id,
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
+ "stream": False,
|
|
|
+ "max_tokens": 4,
|
|
|
+ "chat_id": form_data.get("chat_id", None),
|
|
|
+ "task": True,
|
|
|
+ }
|
|
|
+
|
|
|
+ print(payload)
|
|
|
+
|
|
|
+ try:
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=e.args[0],
|
|
|
+ content={"detail": e.args[1]},
|
|
|
+ )
|
|
|
+
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ return await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**payload), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return await generate_openai_chat_completion(payload, user=user)
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/task/tools/completions")
|
|
|
+async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("get_tools_function_calling")
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
+ template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ try:
|
|
|
+ context = await get_function_call_response(
|
|
|
+ form_data["messages"], form_data["tool_id"], template, model_id, user
|
|
|
+ )
|
|
|
+ return context
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=e.args[0],
|
|
|
+ content={"detail": e.args[1]},
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/chat/completions")
|
|
|
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
+ print(model)
|
|
|
+
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ return await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**form_data), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return await generate_openai_chat_completion(form_data, user=user)
|
|
|
+
|
|
|
+
|
|
|
@app.post("/api/chat/completed")
|
|
|
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
|
data = form_data
|
|
@@ -591,6 +1090,63 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
|
|
|
}
|
|
|
|
|
|
|
|
|
+@app.post("/api/pipelines/upload")
|
|
|
+async def upload_pipeline(
|
|
|
+ urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
|
|
+):
|
|
|
+ print("upload_pipeline", urlIdx, file.filename)
|
|
|
+ # Check if the uploaded file is a python file
|
|
|
+ if not file.filename.endswith(".py"):
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ detail="Only Python (.py) files are allowed.",
|
|
|
+ )
|
|
|
+
|
|
|
+ upload_folder = f"{CACHE_DIR}/pipelines"
|
|
|
+ os.makedirs(upload_folder, exist_ok=True)
|
|
|
+ file_path = os.path.join(upload_folder, file.filename)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # Save the uploaded file
|
|
|
+ with open(file_path, "wb") as buffer:
|
|
|
+ shutil.copyfileobj(file.file, buffer)
|
|
|
+
|
|
|
+ url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
|
|
|
+ key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
|
|
|
+
|
|
|
+ headers = {"Authorization": f"Bearer {key}"}
|
|
|
+
|
|
|
+ with open(file_path, "rb") as f:
|
|
|
+ files = {"file": f}
|
|
|
+ r = requests.post(f"{url}/pipelines/upload", headers=headers, files=files)
|
|
|
+
|
|
|
+ r.raise_for_status()
|
|
|
+ data = r.json()
|
|
|
+
|
|
|
+ return {**data}
|
|
|
+ except Exception as e:
|
|
|
+ # Handle connection error here
|
|
|
+ print(f"Connection error: {e}")
|
|
|
+
|
|
|
+ detail = "Pipeline not found"
|
|
|
+ if r is not None:
|
|
|
+ try:
|
|
|
+ res = r.json()
|
|
|
+ if "detail" in res:
|
|
|
+ detail = res["detail"]
|
|
|
+ except:
|
|
|
+ pass
|
|
|
+
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=(r.status_code if r is not None else status.HTTP_404_NOT_FOUND),
|
|
|
+ detail=detail,
|
|
|
+ )
|
|
|
+ finally:
|
|
|
+ # Ensure the file is deleted after the upload is completed or on failure
|
|
|
+ if os.path.exists(file_path):
|
|
|
+ os.remove(file_path)
|
|
|
+
|
|
|
+
|
|
|
class AddPipelineForm(BaseModel):
|
|
|
url: str
|
|
|
urlIdx: int
|
|
@@ -857,6 +1413,15 @@ async def get_app_config():
|
|
|
"enable_community_sharing": webui_app.state.config.ENABLE_COMMUNITY_SHARING,
|
|
|
"enable_admin_export": ENABLE_ADMIN_EXPORT,
|
|
|
},
|
|
|
+ "audio": {
|
|
|
+ "tts": {
|
|
|
+ "engine": audio_app.state.config.TTS_ENGINE,
|
|
|
+ "voice": audio_app.state.config.TTS_VOICE,
|
|
|
+ },
|
|
|
+ "stt": {
|
|
|
+ "engine": audio_app.state.config.STT_ENGINE,
|
|
|
+ },
|
|
|
+ },
|
|
|
"oauth": {
|
|
|
"providers": {
|
|
|
name: config.get("name", name)
|
|
@@ -925,7 +1490,7 @@ async def get_app_changelog():
|
|
|
@app.get("/api/version/updates")
|
|
|
async def get_app_latest_release_version():
|
|
|
try:
|
|
|
- async with aiohttp.ClientSession() as session:
|
|
|
+ async with aiohttp.ClientSession(trust_env=True) as session:
|
|
|
async with session.get(
|
|
|
"https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
|
|
) as response:
|