@@ -75,6 +75,11 @@ from open_webui.routers.retrieval import (
     get_ef,
     get_rf,
 )
+from open_webui.routers.pipelines import (
+    process_pipeline_inlet_filter,
+    process_pipeline_outlet_filter,
+)
+
 from open_webui.retrieval.utils import get_sources_from_files


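The two imported helpers replace the local filter_pipeline / outlet loop removed further down. Judging only from the call sites in this diff, they take (request, payload, user, models) and hand back the filtered payload; a rough sketch of the assumed interface (inferred from usage, not the actual routers/pipelines source):

# Hypothetical signatures inferred from the call sites below.
def process_pipeline_inlet_filter(request, payload: dict, user, models: dict) -> dict:
    """Run the payload through each matching pipeline's /filter/inlet endpoint."""
    ...

def process_pipeline_outlet_filter(request, payload: dict, user, models: dict) -> dict:
    """Run a completed chat through each matching pipeline's /filter/outlet endpoint."""
    ...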
@@ -290,6 +295,7 @@ from open_webui.utils.response import (
 )

 from open_webui.utils.task import (
+    get_task_model_id,
     rag_template,
     tools_function_calling_generation_template,
 )
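get_task_model_id is now imported from open_webui.utils.task rather than defined in main.py; the deleted local copy appears in a later hunk, so its behavior carries over unchanged. A typical call might look like this (the TASK_MODEL / TASK_MODEL_EXTERNAL config names are assumptions):

# Hypothetical call site for the newly imported helper.
task_model_id = get_task_model_id(
    default_model_id=form_data["model"],
    task_model=app.state.config.TASK_MODEL,
    task_model_external=app.state.config.TASK_MODEL_EXTERNAL,
    models=app.state.MODELS,
)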
@@ -662,34 +668,35 @@ app.state.MODELS = {}
 ##################################


-def get_filter_function_ids(model):
-    def get_priority(function_id):
-        function = Functions.get_function_by_id(function_id)
-        if function is not None and hasattr(function, "valves"):
-            # TODO: Fix FunctionModel
-            return (function.valves if function.valves else {}).get("priority", 0)
-        return 0
-
-    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
-    if "info" in model and "meta" in model["info"]:
-        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-    filter_ids = list(set(filter_ids))
+async def chat_completion_filter_functions_handler(body, model, extra_params):
+    skip_files = None
+
+    def get_filter_function_ids(model):
+        def get_priority(function_id):
+            function = Functions.get_function_by_id(function_id)
+            if function is not None and hasattr(function, "valves"):
+                # TODO: Fix FunctionModel
+                return (function.valves if function.valves else {}).get("priority", 0)
+            return 0

-    enabled_filter_ids = [
-        function.id
-        for function in Functions.get_functions_by_type("filter", active_only=True)
-    ]
+        filter_ids = [
+            function.id for function in Functions.get_global_filter_functions()
+        ]
+        if "info" in model and "meta" in model["info"]:
+            filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))

-    filter_ids = [
-        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-    ]
+        enabled_filter_ids = [
+            function.id
+            for function in Functions.get_functions_by_type("filter", active_only=True)
+        ]

-    filter_ids.sort(key=get_priority)
-    return filter_ids
-
-
-async def chat_completion_filter_functions_handler(body, model, extra_params):
-    skip_files = None
+        filter_ids = [
+            filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+        ]
+
+        filter_ids.sort(key=get_priority)
+        return filter_ids

     filter_ids = get_filter_function_ids(model)
     for filter_id in filter_ids:
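Nesting get_filter_function_ids inside the handler changes scope, not behavior: global filters plus the model's filterIds are deduplicated, restricted to active filter functions, and sorted ascending by each function's priority valve. A self-contained illustration of that ordering, with a plain dict standing in for the Functions table:

# Stub data; priorities normally come from each filter function's valves.
functions = {
    "censor": {"priority": 10},
    "rate_limit": {"priority": 0},
    "translate": {"priority": 5},
}

def get_priority(function_id: str) -> int:
    # Missing entries fall back to 0, matching get_priority above.
    return functions.get(function_id, {}).get("priority", 0)

filter_ids = ["censor", "translate", "rate_limit"]
filter_ids.sort(key=get_priority)
print(filter_ids)  # ['rate_limit', 'translate', 'censor'] -- lowest priority runs first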
@@ -791,22 +798,6 @@ async def get_content_from_response(response) -> Optional[str]:
     return content


-def get_task_model_id(
-    default_model_id: str, task_model: str, task_model_external: str, models
-) -> str:
-    # Set the task model
-    task_model_id = default_model_id
-    # Check if the user has a custom task model and use that model
-    if models[task_model_id]["owned_by"] == "ollama":
-        if task_model and task_model in models:
-            task_model_id = task_model
-    else:
-        if task_model_external and task_model_external in models:
-            task_model_id = task_model_external
-
-    return task_model_id
-
-
 async def chat_completion_tools_handler(
     body: dict, user: UserModel, models, extra_params: dict
 ) -> tuple[dict, dict]:
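The deleted helper matches the one now imported from open_webui.utils.task, so the selection rule is unchanged: an ollama-owned default model prefers task_model, any other owner prefers task_model_external, and the default is kept when no valid override exists. For example:

models = {
    "llama3": {"owned_by": "ollama"},
    "llama3:8b": {"owned_by": "ollama"},
    "gpt-4o": {"owned_by": "openai"},
    "gpt-4o-mini": {"owned_by": "openai"},
}

get_task_model_id("llama3", "llama3:8b", "gpt-4o-mini", models)  # -> "llama3:8b"
get_task_model_id("gpt-4o", "llama3:8b", "gpt-4o-mini", models)  # -> "gpt-4o-mini"
get_task_model_id("gpt-4o", "", "", models)                      # -> "gpt-4o"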
@@ -857,7 +848,7 @@ async def chat_completion_tools_handler(
     )

     try:
-        payload = filter_pipeline(payload, user, models)
+        payload = process_pipeline_inlet_filter(request, payload, user, models)
     except Exception as e:
         raise e

@@ -1153,7 +1144,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         if prompt is None:
             raise Exception("No user message found")
         if (
-            retrieval_app.state.config.RELEVANCE_THRESHOLD == 0
+            app.state.config.RELEVANCE_THRESHOLD == 0
             and context_string.strip() == ""
         ):
             log.debug(
@@ -1164,16 +1155,12 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         # TODO: replace with add_or_update_system_message
         if model["owned_by"] == "ollama":
             body["messages"] = prepend_to_first_user_message_content(
-                rag_template(
-                    retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                ),
+                rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                 body["messages"],
             )
         else:
             body["messages"] = add_or_update_system_message(
-                rag_template(
-                    retrieval_app.state.config.RAG_TEMPLATE, context_string, prompt
-                ),
+                rag_template(app.state.config.RAG_TEMPLATE, context_string, prompt),
                 body["messages"],
             )

@@ -1225,77 +1212,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 app.add_middleware(ChatCompletionMiddleware)


-##################################
-#
-# Pipeline Middleware
-#
-##################################
-
-
-def get_sorted_filters(model_id, models):
-    filters = [
-        model
-        for model in models.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-    return sorted_filters
-
-
-def filter_pipeline(payload, user, models):
-    user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
-    model_id = payload["model"]
-
-    sorted_filters = get_sorted_filters(model_id, models)
-    model = models[model_id]
-
-    if "pipeline" in model:
-        sorted_filters.append(model)
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key == "":
-                continue
-
-            headers = {"Authorization": f"Bearer {key}"}
-            r = requests.post(
-                f"{url}/{filter['id']}/filter/inlet",
-                headers=headers,
-                json={
-                    "user": user,
-                    "body": payload,
-                },
-            )
-
-            r.raise_for_status()
-            payload = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                res = r.json()
-                if "detail" in res:
-                    raise Exception(r.status_code, res["detail"])
-
-    return payload
-
-
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         if not request.method == "POST" and any(
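The removed code also documents the wire contract that routers/pipelines must now implement: POST {"user": ..., "body": payload} to <url>/<filter_id>/filter/inlet with a bearer key, then adopt the returned JSON as the new payload. A minimal standalone client for one filter (URL and key are placeholders):

import requests

def inlet_filter(url: str, key: str, filter_id: str, user: dict, payload: dict) -> dict:
    # Mirrors one iteration of the removed filter_pipeline loop.
    r = requests.post(
        f"{url}/{filter_id}/filter/inlet",
        headers={"Authorization": f"Bearer {key}"},
        json={"user": user, "body": payload},
    )
    r.raise_for_status()
    return r.json()  # the filtered payload replaces the original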
@@ -1335,11 +1251,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
                 content={"detail": e.detail},
             )

-        model_list = await get_all_models()
-        models = {model["id"]: model for model in model_list}
+        await get_all_models()
+        models = app.state.MODELS

         try:
-            data = filter_pipeline(data, user, models)
+            data = process_pipeline_inlet_filter(request, data, user, models)
         except Exception as e:
             if len(e.args) > 1:
                 return JSONResponse(
@@ -1447,8 +1363,8 @@ app.include_router(ollama.router, prefix="/ollama", tags=["ollama"])
 app.include_router(openai.router, prefix="/openai", tags=["openai"])


-app.include_router(pipelines.router, prefix="/pipelines", tags=["pipelines"])
-app.include_router(tasks.router, prefix="/tasks", tags=["tasks"])
+app.include_router(pipelines.router, prefix="/api/pipelines", tags=["pipelines"])
+app.include_router(tasks.router, prefix="/api/tasks", tags=["tasks"])


 app.include_router(images.router, prefix="/api/v1/images", tags=["images"])
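Moving these routers under /api is a breaking URL change for anyone calling them directly; for instance (the /list endpoint and token here are illustrative):

import requests

headers = {"Authorization": "Bearer YOUR_API_TOKEN"}  # placeholder token

# before: requests.get("http://localhost:8080/pipelines/list", headers=headers)
# after:
requests.get("http://localhost:8080/api/pipelines/list", headers=headers)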
@@ -2105,7 +2021,6 @@ async def generate_chat_completions(
     if model["owned_by"] == "ollama":
         # Using /ollama/api/chat endpoint
         form_data = convert_payload_openai_to_ollama(form_data)
-        form_data = GenerateChatCompletionForm(**form_data)
         response = await generate_ollama_chat_completion(
             form_data=form_data, user=user, bypass_filter=bypass_filter
         )
@@ -2124,7 +2039,9 @@ async def generate_chat_completions(


 @app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+async def chat_completed(
+    request: Request, form_data: dict, user=Depends(get_verified_user)
+):
     model_list = await get_all_models()
     models = {model["id"]: model for model in model_list}

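Adding request: Request lets the endpoint forward the incoming request to the outlet filter below; FastAPI injects it from the annotation alone, so API clients are unaffected. The pattern in isolation:

from fastapi import FastAPI, Request

app = FastAPI()

@app.post("/api/chat/completed")
async def chat_completed(request: Request, form_data: dict):
    # FastAPI supplies `request` automatically; handlers can pass it to
    # helpers that need headers or app state, as main.py does here.
    return {"path": request.url.path}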
@@ -2137,53 +2054,14 @@ async def chat_completed(
         )

     model = models[model_id]
-    sorted_filters = get_sorted_filters(model_id, models)
-    if "pipeline" in model:
-        sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {
-                            "id": user.id,
-                            "name": user.name,
-                            "email": user.email,
-                            "role": user.role,
-                        },
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except Exception:
-                    pass
-
-            else:
-                pass
+    try:
+        data = process_pipeline_outlet_filter(request, data, user, models)
+    except Exception as e:
+        return HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=str(e),
+        )
+

     __event_emitter__ = get_event_emitter(
         {
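One caveat with the replacement block: it returns the HTTPException rather than raising it, and FastAPI only turns raised HTTPExceptions into error responses. If that is not deliberate upstream, the conventional form of the same block would be:

from fastapi import HTTPException, status

try:
    data = process_pipeline_outlet_filter(request, data, user, models)
except Exception as e:
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail=str(e),
    ) from e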
@@ -2455,8 +2333,8 @@ async def get_app_config(request: Request):
             "enable_login_form": app.state.config.ENABLE_LOGIN_FORM,
             **(
                 {
-                    "enable_web_search": retrieval_app.state.config.ENABLE_RAG_WEB_SEARCH,
-                    "enable_image_generation": images_app.state.config.ENABLED,
+                    "enable_web_search": app.state.config.ENABLE_RAG_WEB_SEARCH,
+                    "enable_image_generation": app.state.config.ENABLE_IMAGE_GENERATION,
                     "enable_community_sharing": app.state.config.ENABLE_COMMUNITY_SHARING,
                     "enable_message_rating": app.state.config.ENABLE_MESSAGE_RATING,
                     "enable_admin_export": ENABLE_ADMIN_EXPORT,
@@ -2472,17 +2350,17 @@ async def get_app_config(request: Request):
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
         "audio": {
             "tts": {
-                "engine": audio_app.state.config.TTS_ENGINE,
-                "voice": audio_app.state.config.TTS_VOICE,
-                "split_on": audio_app.state.config.TTS_SPLIT_ON,
+                "engine": app.state.config.TTS_ENGINE,
+                "voice": app.state.config.TTS_VOICE,
+                "split_on": app.state.config.TTS_SPLIT_ON,
             },
             "stt": {
-                "engine": audio_app.state.config.STT_ENGINE,
+                "engine": app.state.config.STT_ENGINE,
             },
         },
         "file": {
-            "max_size": retrieval_app.state.config.FILE_MAX_SIZE,
-            "max_count": retrieval_app.state.config.FILE_MAX_COUNT,
+            "max_size": app.state.config.FILE_MAX_SIZE,
+            "max_count": app.state.config.FILE_MAX_COUNT,
         },
         "permissions": {**app.state.config.USER_PERMISSIONS},
     }
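These last hunks follow from folding the former sub-apps into the single main app, so every config value is now read from one state object:

# Old sub-app state                        -> consolidated state
# audio_app.state.config.TTS_ENGINE        -> app.state.config.TTS_ENGINE
# retrieval_app.state.config.FILE_MAX_SIZE -> app.state.config.FILE_MAX_SIZE
# images_app.state.config.ENABLED          -> app.state.config.ENABLE_IMAGE_GENERATION (renamed)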