|
@@ -1,5 +1,6 @@
|
|
import base64
|
|
import base64
|
|
import uuid
|
|
import uuid
|
|
|
|
+import subprocess
|
|
from contextlib import asynccontextmanager
|
|
from contextlib import asynccontextmanager
|
|
|
|
|
|
from authlib.integrations.starlette_client import OAuth
|
|
from authlib.integrations.starlette_client import OAuth
|
|
@@ -27,6 +28,7 @@ from fastapi.responses import JSONResponse
|
|
from fastapi import HTTPException
|
|
from fastapi import HTTPException
|
|
from fastapi.middleware.wsgi import WSGIMiddleware
|
|
from fastapi.middleware.wsgi import WSGIMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
+from sqlalchemy import text
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.exceptions import HTTPException as StarletteHTTPException
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
from starlette.middleware.sessions import SessionMiddleware
|
|
@@ -54,6 +56,7 @@ from apps.webui.main import (
|
|
get_pipe_models,
|
|
get_pipe_models,
|
|
generate_function_chat_completion,
|
|
generate_function_chat_completion,
|
|
)
|
|
)
|
|
|
|
+from apps.webui.internal.db import Session, SessionLocal
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
@@ -125,8 +128,10 @@ from config import (
|
|
WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
WEBUI_SESSION_COOKIE_SAME_SITE,
|
|
WEBUI_SESSION_COOKIE_SECURE,
|
|
WEBUI_SESSION_COOKIE_SECURE,
|
|
AppConfig,
|
|
AppConfig,
|
|
|
|
+ BACKEND_DIR,
|
|
|
|
+ DATABASE_URL,
|
|
)
|
|
)
|
|
-from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES
|
|
|
|
|
|
+from constants import ERROR_MESSAGES, WEBHOOK_MESSAGES, TASKS
|
|
from utils.webhook import post_webhook
|
|
from utils.webhook import post_webhook
|
|
|
|
|
|
if SAFE_MODE:
|
|
if SAFE_MODE:
|
|
@@ -167,8 +172,20 @@ https://github.com/open-webui/open-webui
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+def run_migrations():
|
|
|
|
+ try:
|
|
|
|
+ from alembic.config import Config
|
|
|
|
+ from alembic import command
|
|
|
|
+
|
|
|
|
+ alembic_cfg = Config("alembic.ini")
|
|
|
|
+ command.upgrade(alembic_cfg, "head")
|
|
|
|
+ except Exception as e:
|
|
|
|
+ print(f"Error: {e}")
|
|
|
|
+
|
|
|
|
+
|
|
@asynccontextmanager
|
|
@asynccontextmanager
|
|
async def lifespan(app: FastAPI):
|
|
async def lifespan(app: FastAPI):
|
|
|
|
+ run_migrations()
|
|
yield
|
|
yield
|
|
|
|
|
|
|
|
|
|
@@ -285,6 +302,7 @@ async def get_function_call_response(
|
|
user,
|
|
user,
|
|
model,
|
|
model,
|
|
__event_emitter__=None,
|
|
__event_emitter__=None,
|
|
|
|
+ __event_call__=None,
|
|
):
|
|
):
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
@@ -311,7 +329,7 @@ async def get_function_call_response(
|
|
{"role": "user", "content": f"Query: {prompt}"},
|
|
{"role": "user", "content": f"Query: {prompt}"},
|
|
],
|
|
],
|
|
"stream": False,
|
|
"stream": False,
|
|
- "function": True,
|
|
|
|
|
|
+ "task": TASKS.FUNCTION_CALLING,
|
|
}
|
|
}
|
|
|
|
|
|
try:
|
|
try:
|
|
@@ -324,7 +342,6 @@ async def get_function_call_response(
|
|
response = None
|
|
response = None
|
|
try:
|
|
try:
|
|
response = await generate_chat_completions(form_data=payload, user=user)
|
|
response = await generate_chat_completions(form_data=payload, user=user)
|
|
-
|
|
|
|
content = None
|
|
content = None
|
|
|
|
|
|
if hasattr(response, "body_iterator"):
|
|
if hasattr(response, "body_iterator"):
|
|
@@ -429,6 +446,13 @@ async def get_function_call_response(
|
|
"__event_emitter__": __event_emitter__,
|
|
"__event_emitter__": __event_emitter__,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if "__event_call__" in sig.parameters:
|
|
|
|
+ # Call the function with the '__event_call__' parameter included
|
|
|
|
+ params = {
|
|
|
|
+ **params,
|
|
|
|
+ "__event_call__": __event_call__,
|
|
|
|
+ }
|
|
|
|
+
|
|
if inspect.iscoroutinefunction(function):
|
|
if inspect.iscoroutinefunction(function):
|
|
function_result = await function(**params)
|
|
function_result = await function(**params)
|
|
else:
|
|
else:
|
|
@@ -452,7 +476,9 @@ async def get_function_call_response(
|
|
return None, None, False
|
|
return None, None, False
|
|
|
|
|
|
|
|
|
|
-async def chat_completion_functions_handler(body, model, user, __event_emitter__):
|
|
|
|
|
|
+async def chat_completion_functions_handler(
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
|
|
+):
|
|
skip_files = None
|
|
skip_files = None
|
|
|
|
|
|
filter_ids = get_filter_function_ids(model)
|
|
filter_ids = get_filter_function_ids(model)
|
|
@@ -518,12 +544,19 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
|
|
**params,
|
|
**params,
|
|
"__model__": model,
|
|
"__model__": model,
|
|
}
|
|
}
|
|
|
|
+
|
|
if "__event_emitter__" in sig.parameters:
|
|
if "__event_emitter__" in sig.parameters:
|
|
params = {
|
|
params = {
|
|
**params,
|
|
**params,
|
|
"__event_emitter__": __event_emitter__,
|
|
"__event_emitter__": __event_emitter__,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if "__event_call__" in sig.parameters:
|
|
|
|
+ params = {
|
|
|
|
+ **params,
|
|
|
|
+ "__event_call__": __event_call__,
|
|
|
|
+ }
|
|
|
|
+
|
|
if inspect.iscoroutinefunction(inlet):
|
|
if inspect.iscoroutinefunction(inlet):
|
|
body = await inlet(**params)
|
|
body = await inlet(**params)
|
|
else:
|
|
else:
|
|
@@ -540,7 +573,9 @@ async def chat_completion_functions_handler(body, model, user, __event_emitter__
|
|
return body, {}
|
|
return body, {}
|
|
|
|
|
|
|
|
|
|
-async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
|
|
|
|
|
+async def chat_completion_tools_handler(
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
|
|
+):
|
|
skip_files = None
|
|
skip_files = None
|
|
|
|
|
|
contexts = []
|
|
contexts = []
|
|
@@ -563,6 +598,7 @@ async def chat_completion_tools_handler(body, model, user, __event_emitter__):
|
|
user=user,
|
|
user=user,
|
|
model=model,
|
|
model=model,
|
|
__event_emitter__=__event_emitter__,
|
|
__event_emitter__=__event_emitter__,
|
|
|
|
+ __event_call__=__event_call__,
|
|
)
|
|
)
|
|
|
|
|
|
print(file_handler)
|
|
print(file_handler)
|
|
@@ -660,6 +696,14 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
to=session_id,
|
|
to=session_id,
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ async def __event_call__(data):
|
|
|
|
+ response = await sio.call(
|
|
|
|
+ "chat-events",
|
|
|
|
+ {"chat_id": chat_id, "message_id": message_id, "data": data},
|
|
|
|
+ to=session_id,
|
|
|
|
+ )
|
|
|
|
+ return response
|
|
|
|
+
|
|
# Initialize data_items to store additional data to be sent to the client
|
|
# Initialize data_items to store additional data to be sent to the client
|
|
data_items = []
|
|
data_items = []
|
|
|
|
|
|
@@ -669,7 +713,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
try:
|
|
try:
|
|
body, flags = await chat_completion_functions_handler(
|
|
body, flags = await chat_completion_functions_handler(
|
|
- body, model, user, __event_emitter__
|
|
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
)
|
|
)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -679,7 +723,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
|
try:
|
|
try:
|
|
body, flags = await chat_completion_tools_handler(
|
|
body, flags = await chat_completion_tools_handler(
|
|
- body, model, user, __event_emitter__
|
|
|
|
|
|
+ body, model, user, __event_emitter__, __event_call__
|
|
)
|
|
)
|
|
|
|
|
|
contexts.extend(flags.get("contexts", []))
|
|
contexts.extend(flags.get("contexts", []))
|
|
@@ -834,9 +878,8 @@ def filter_pipeline(payload, user):
|
|
pass
|
|
pass
|
|
|
|
|
|
if "pipeline" not in app.state.MODELS[model_id]:
|
|
if "pipeline" not in app.state.MODELS[model_id]:
|
|
- for key in ["title", "task", "function"]:
|
|
|
|
- if key in payload:
|
|
|
|
- del payload[key]
|
|
|
|
|
|
+ if "task" in payload:
|
|
|
|
+ del payload["task"]
|
|
|
|
|
|
return payload
|
|
return payload
|
|
|
|
|
|
@@ -901,6 +944,14 @@ app.add_middleware(
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
+@app.middleware("http")
|
|
|
|
+async def commit_session_after_request(request: Request, call_next):
|
|
|
|
+ response = await call_next(request)
|
|
|
|
+ log.debug("Commit session after request")
|
|
|
|
+ Session.commit()
|
|
|
|
+ return response
|
|
|
|
+
|
|
|
|
+
|
|
@app.middleware("http")
|
|
@app.middleware("http")
|
|
async def check_url(request: Request, call_next):
|
|
async def check_url(request: Request, call_next):
|
|
if len(app.state.MODELS) == 0:
|
|
if len(app.state.MODELS) == 0:
|
|
@@ -977,12 +1028,16 @@ async def get_all_models():
|
|
model["info"] = custom_model.model_dump()
|
|
model["info"] = custom_model.model_dump()
|
|
else:
|
|
else:
|
|
owned_by = "openai"
|
|
owned_by = "openai"
|
|
|
|
+ pipe = None
|
|
|
|
+
|
|
for model in models:
|
|
for model in models:
|
|
if (
|
|
if (
|
|
custom_model.base_model_id == model["id"]
|
|
custom_model.base_model_id == model["id"]
|
|
or custom_model.base_model_id == model["id"].split(":")[0]
|
|
or custom_model.base_model_id == model["id"].split(":")[0]
|
|
):
|
|
):
|
|
owned_by = model["owned_by"]
|
|
owned_by = model["owned_by"]
|
|
|
|
+ if "pipe" in model:
|
|
|
|
+ pipe = model["pipe"]
|
|
break
|
|
break
|
|
|
|
|
|
models.append(
|
|
models.append(
|
|
@@ -994,11 +1049,11 @@ async def get_all_models():
|
|
"owned_by": owned_by,
|
|
"owned_by": owned_by,
|
|
"info": custom_model.model_dump(),
|
|
"info": custom_model.model_dump(),
|
|
"preset": True,
|
|
"preset": True,
|
|
|
|
+ **({"pipe": pipe} if pipe is not None else {}),
|
|
}
|
|
}
|
|
)
|
|
)
|
|
|
|
|
|
app.state.MODELS = {model["id"]: model for model in models}
|
|
app.state.MODELS = {model["id"]: model for model in models}
|
|
-
|
|
|
|
webui_app.state.MODELS = app.state.MODELS
|
|
webui_app.state.MODELS = app.state.MODELS
|
|
|
|
|
|
return models
|
|
return models
|
|
@@ -1133,6 +1188,14 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
to=data["session_id"],
|
|
to=data["session_id"],
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ async def __event_call__(data):
|
|
|
|
+ response = await sio.call(
|
|
|
|
+ "chat-events",
|
|
|
|
+ {"chat_id": data["chat_id"], "message_id": data["id"], "data": data},
|
|
|
|
+ to=data["session_id"],
|
|
|
|
+ )
|
|
|
|
+ return response
|
|
|
|
+
|
|
def get_priority(function_id):
|
|
def get_priority(function_id):
|
|
function = Functions.get_function_by_id(function_id)
|
|
function = Functions.get_function_by_id(function_id)
|
|
if function is not None and hasattr(function, "valves"):
|
|
if function is not None and hasattr(function, "valves"):
|
|
@@ -1220,6 +1283,12 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
"__event_emitter__": __event_emitter__,
|
|
"__event_emitter__": __event_emitter__,
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+ if "__event_call__" in sig.parameters:
|
|
|
|
+ params = {
|
|
|
|
+ **params,
|
|
|
|
+ "__event_call__": __event_call__,
|
|
|
|
+ }
|
|
|
|
+
|
|
if inspect.iscoroutinefunction(outlet):
|
|
if inspect.iscoroutinefunction(outlet):
|
|
data = await outlet(**params)
|
|
data = await outlet(**params)
|
|
else:
|
|
else:
|
|
@@ -1337,7 +1406,7 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
"stream": False,
|
|
"stream": False,
|
|
"max_tokens": 50,
|
|
"max_tokens": 50,
|
|
"chat_id": form_data.get("chat_id", None),
|
|
"chat_id": form_data.get("chat_id", None),
|
|
- "title": True,
|
|
|
|
|
|
+ "task": TASKS.TITLE_GENERATION,
|
|
}
|
|
}
|
|
|
|
|
|
log.debug(payload)
|
|
log.debug(payload)
|
|
@@ -1400,7 +1469,7 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|
"messages": [{"role": "user", "content": content}],
|
|
"messages": [{"role": "user", "content": content}],
|
|
"stream": False,
|
|
"stream": False,
|
|
"max_tokens": 30,
|
|
"max_tokens": 30,
|
|
- "task": True,
|
|
|
|
|
|
+ "task": TASKS.QUERY_GENERATION,
|
|
}
|
|
}
|
|
|
|
|
|
print(payload)
|
|
print(payload)
|
|
@@ -1467,7 +1536,7 @@ Message: """{{prompt}}"""
|
|
"stream": False,
|
|
"stream": False,
|
|
"max_tokens": 4,
|
|
"max_tokens": 4,
|
|
"chat_id": form_data.get("chat_id", None),
|
|
"chat_id": form_data.get("chat_id", None),
|
|
- "task": True,
|
|
|
|
|
|
+ "task": TASKS.EMOJI_GENERATION,
|
|
}
|
|
}
|
|
|
|
|
|
log.debug(payload)
|
|
log.debug(payload)
|
|
@@ -1742,7 +1811,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
|
|
|
|
|
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
|
async def get_pipeline_valves(
|
|
async def get_pipeline_valves(
|
|
- urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
models = await get_all_models()
|
|
models = await get_all_models()
|
|
r = None
|
|
r = None
|
|
@@ -1780,7 +1851,9 @@ async def get_pipeline_valves(
|
|
|
|
|
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
|
async def get_pipeline_valves_spec(
|
|
async def get_pipeline_valves_spec(
|
|
- urlIdx: Optional[int], pipeline_id: str, user=Depends(get_admin_user)
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
models = await get_all_models()
|
|
models = await get_all_models()
|
|
|
|
|
|
@@ -2066,7 +2139,8 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|
if existing_user:
|
|
if existing_user:
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
raise HTTPException(400, detail=ERROR_MESSAGES.EMAIL_TAKEN)
|
|
|
|
|
|
- picture_url = user_data.get("picture", "")
|
|
|
|
|
|
+ picture_claim = webui_app.state.config.OAUTH_PICTURE_CLAIM
|
|
|
|
+ picture_url = user_data.get(picture_claim, "")
|
|
if picture_url:
|
|
if picture_url:
|
|
# Download the profile image into a base64 string
|
|
# Download the profile image into a base64 string
|
|
try:
|
|
try:
|
|
@@ -2086,6 +2160,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|
picture_url = ""
|
|
picture_url = ""
|
|
if not picture_url:
|
|
if not picture_url:
|
|
picture_url = "/user.png"
|
|
picture_url = "/user.png"
|
|
|
|
+ username_claim = webui_app.state.config.OAUTH_USERNAME_CLAIM
|
|
role = (
|
|
role = (
|
|
"admin"
|
|
"admin"
|
|
if Users.get_num_users() == 0
|
|
if Users.get_num_users() == 0
|
|
@@ -2096,7 +2171,7 @@ async def oauth_callback(provider: str, request: Request, response: Response):
|
|
password=get_password_hash(
|
|
password=get_password_hash(
|
|
str(uuid.uuid4())
|
|
str(uuid.uuid4())
|
|
), # Random password, not used
|
|
), # Random password, not used
|
|
- name=user_data.get("name", "User"),
|
|
|
|
|
|
+ name=user_data.get(username_claim, "User"),
|
|
profile_image_url=picture_url,
|
|
profile_image_url=picture_url,
|
|
role=role,
|
|
role=role,
|
|
oauth_sub=provider_sub,
|
|
oauth_sub=provider_sub,
|
|
@@ -2154,7 +2229,7 @@ async def get_opensearch_xml():
|
|
<ShortName>{WEBUI_NAME}</ShortName>
|
|
<ShortName>{WEBUI_NAME}</ShortName>
|
|
<Description>Search {WEBUI_NAME}</Description>
|
|
<Description>Search {WEBUI_NAME}</Description>
|
|
<InputEncoding>UTF-8</InputEncoding>
|
|
<InputEncoding>UTF-8</InputEncoding>
|
|
- <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/favicon.png</Image>
|
|
|
|
|
|
+ <Image width="16" height="16" type="image/x-icon">{WEBUI_URL}/static/favicon.png</Image>
|
|
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
|
|
<Url type="text/html" method="get" template="{WEBUI_URL}/?q={"{searchTerms}"}"/>
|
|
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
|
|
<moz:SearchForm>{WEBUI_URL}</moz:SearchForm>
|
|
</OpenSearchDescription>
|
|
</OpenSearchDescription>
|
|
@@ -2167,6 +2242,12 @@ async def healthcheck():
|
|
return {"status": True}
|
|
return {"status": True}
|
|
|
|
|
|
|
|
|
|
|
|
+@app.get("/health/db")
|
|
|
|
+async def healthcheck_with_db():
|
|
|
|
+ Session.execute(text("SELECT 1;")).all()
|
|
|
|
+ return {"status": True}
|
|
|
|
+
|
|
|
|
+
|
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
|
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
|
|
app.mount("/cache", StaticFiles(directory=CACHE_DIR), name="cache")
|
|
|
|
|