|
@@ -208,8 +208,6 @@ app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
app.state.MODELS = {}
|
|
app.state.MODELS = {}
|
|
|
|
|
|
|
|
|
|
-
|
|
|
|
-
|
|
|
|
##################################
|
|
##################################
|
|
#
|
|
#
|
|
# ChatCompletion Middleware
|
|
# ChatCompletion Middleware
|
|
@@ -223,14 +221,14 @@ def get_task_model_id(default_model_id):
|
|
# Check if the user has a custom task model and use that model
|
|
# Check if the user has a custom task model and use that model
|
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
if (
|
|
if (
|
|
- app.state.config.TASK_MODEL
|
|
|
|
- and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
|
|
|
+ app.state.config.TASK_MODEL
|
|
|
|
+ and app.state.config.TASK_MODEL in app.state.MODELS
|
|
):
|
|
):
|
|
task_model_id = app.state.config.TASK_MODEL
|
|
task_model_id = app.state.config.TASK_MODEL
|
|
else:
|
|
else:
|
|
if (
|
|
if (
|
|
- app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
- and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
+ and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
):
|
|
):
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
|
|
@@ -367,7 +365,7 @@ async def get_content_from_response(response) -> Optional[str]:
|
|
|
|
|
|
|
|
|
|
async def chat_completion_tools_handler(
|
|
async def chat_completion_tools_handler(
|
|
- body: dict, user: UserModel, extra_params: dict
|
|
|
|
|
|
+ body: dict, user: UserModel, extra_params: dict
|
|
) -> tuple[dict, dict]:
|
|
) -> tuple[dict, dict]:
|
|
# If tool_ids field is present, call the functions
|
|
# If tool_ids field is present, call the functions
|
|
metadata = body.get("metadata", {})
|
|
metadata = body.get("metadata", {})
|
|
@@ -681,15 +679,15 @@ def get_sorted_filters(model_id):
|
|
model
|
|
model
|
|
for model in app.state.MODELS.values()
|
|
for model in app.state.MODELS.values()
|
|
if "pipeline" in model
|
|
if "pipeline" in model
|
|
- and "type" in model["pipeline"]
|
|
|
|
- and model["pipeline"]["type"] == "filter"
|
|
|
|
- and (
|
|
|
|
- model["pipeline"]["pipelines"] == ["*"]
|
|
|
|
- or any(
|
|
|
|
- model_id == target_model_id
|
|
|
|
- for target_model_id in model["pipeline"]["pipelines"]
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
|
|
+ and "type" in model["pipeline"]
|
|
|
|
+ and model["pipeline"]["type"] == "filter"
|
|
|
|
+ and (
|
|
|
|
+ model["pipeline"]["pipelines"] == ["*"]
|
|
|
|
+ or any(
|
|
|
|
+ model_id == target_model_id
|
|
|
|
+ for target_model_id in model["pipeline"]["pipelines"]
|
|
|
|
+ )
|
|
|
|
+ )
|
|
]
|
|
]
|
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
|
|
return sorted_filters
|
|
return sorted_filters
|
|
@@ -875,8 +873,8 @@ async def update_embedding_function(request: Request, call_next):
|
|
@app.middleware("http")
|
|
@app.middleware("http")
|
|
async def inspect_websocket(request: Request, call_next):
|
|
async def inspect_websocket(request: Request, call_next):
|
|
if (
|
|
if (
|
|
- "/ws/socket.io" in request.url.path
|
|
|
|
- and request.query_params.get("transport") == "websocket"
|
|
|
|
|
|
+ "/ws/socket.io" in request.url.path
|
|
|
|
+ and request.query_params.get("transport") == "websocket"
|
|
):
|
|
):
|
|
upgrade = (request.headers.get("Upgrade") or "").lower()
|
|
upgrade = (request.headers.get("Upgrade") or "").lower()
|
|
connection = (request.headers.get("Connection") or "").lower().split(",")
|
|
connection = (request.headers.get("Connection") or "").lower().split(",")
|
|
@@ -945,8 +943,8 @@ async def get_all_models():
|
|
if custom_model.base_model_id is None:
|
|
if custom_model.base_model_id is None:
|
|
for model in models:
|
|
for model in models:
|
|
if (
|
|
if (
|
|
- custom_model.id == model["id"]
|
|
|
|
- or custom_model.id == model["id"].split(":")[0]
|
|
|
|
|
|
+ custom_model.id == model["id"]
|
|
|
|
+ or custom_model.id == model["id"].split(":")[0]
|
|
):
|
|
):
|
|
model["name"] = custom_model.name
|
|
model["name"] = custom_model.name
|
|
model["info"] = custom_model.model_dump()
|
|
model["info"] = custom_model.model_dump()
|
|
@@ -963,8 +961,8 @@ async def get_all_models():
|
|
|
|
|
|
for model in models:
|
|
for model in models:
|
|
if (
|
|
if (
|
|
- custom_model.base_model_id == model["id"]
|
|
|
|
- or custom_model.base_model_id == model["id"].split(":")[0]
|
|
|
|
|
|
+ custom_model.base_model_id == model["id"]
|
|
|
|
+ or custom_model.base_model_id == model["id"].split(":")[0]
|
|
):
|
|
):
|
|
owned_by = model["owned_by"]
|
|
owned_by = model["owned_by"]
|
|
if "pipe" in model:
|
|
if "pipe" in model:
|
|
@@ -1840,7 +1838,7 @@ async def get_pipelines_list(user=Depends(get_admin_user)):
|
|
|
|
|
|
@app.post("/api/pipelines/upload")
|
|
@app.post("/api/pipelines/upload")
|
|
async def upload_pipeline(
|
|
async def upload_pipeline(
|
|
- urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
|
|
|
|
|
+ urlIdx: int = Form(...), file: UploadFile = File(...), user=Depends(get_admin_user)
|
|
):
|
|
):
|
|
print("upload_pipeline", urlIdx, file.filename)
|
|
print("upload_pipeline", urlIdx, file.filename)
|
|
# Check if the uploaded file is a python file
|
|
# Check if the uploaded file is a python file
|
|
@@ -2017,9 +2015,9 @@ async def get_pipelines(urlIdx: Optional[int] = None, user=Depends(get_admin_use
|
|
|
|
|
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
|
@app.get("/api/pipelines/{pipeline_id}/valves")
|
|
async def get_pipeline_valves(
|
|
async def get_pipeline_valves(
|
|
- urlIdx: Optional[int],
|
|
|
|
- pipeline_id: str,
|
|
|
|
- user=Depends(get_admin_user),
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -2055,9 +2053,9 @@ async def get_pipeline_valves(
|
|
|
|
|
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
|
@app.get("/api/pipelines/{pipeline_id}/valves/spec")
|
|
async def get_pipeline_valves_spec(
|
|
async def get_pipeline_valves_spec(
|
|
- urlIdx: Optional[int],
|
|
|
|
- pipeline_id: str,
|
|
|
|
- user=Depends(get_admin_user),
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -2092,10 +2090,10 @@ async def get_pipeline_valves_spec(
|
|
|
|
|
|
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
|
@app.post("/api/pipelines/{pipeline_id}/valves/update")
|
|
async def update_pipeline_valves(
|
|
async def update_pipeline_valves(
|
|
- urlIdx: Optional[int],
|
|
|
|
- pipeline_id: str,
|
|
|
|
- form_data: dict,
|
|
|
|
- user=Depends(get_admin_user),
|
|
|
|
|
|
+ urlIdx: Optional[int],
|
|
|
|
+ pipeline_id: str,
|
|
|
|
+ form_data: dict,
|
|
|
|
+ user=Depends(get_admin_user),
|
|
):
|
|
):
|
|
r = None
|
|
r = None
|
|
try:
|
|
try:
|
|
@@ -2219,7 +2217,7 @@ class ModelFilterConfigForm(BaseModel):
|
|
|
|
|
|
@app.post("/api/config/model/filter")
|
|
@app.post("/api/config/model/filter")
|
|
async def update_model_filter_config(
|
|
async def update_model_filter_config(
|
|
- form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
|
|
|
|
|
+ form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
|
):
|
|
):
|
|
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
|
app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
|
app.state.config.MODEL_FILTER_LIST = form_data.models
|
|
app.state.config.MODEL_FILTER_LIST = form_data.models
|
|
@@ -2274,7 +2272,7 @@ async def get_app_latest_release_version():
|
|
timeout = aiohttp.ClientTimeout(total=1)
|
|
timeout = aiohttp.ClientTimeout(total=1)
|
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
|
async with aiohttp.ClientSession(timeout=timeout, trust_env=True) as session:
|
|
async with session.get(
|
|
async with session.get(
|
|
- "https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
|
|
|
|
|
+ "https://api.github.com/repos/open-webui/open-webui/releases/latest"
|
|
) as response:
|
|
) as response:
|
|
response.raise_for_status()
|
|
response.raise_for_status()
|
|
data = await response.json()
|
|
data = await response.json()
|