|
@@ -11,6 +11,7 @@ import random
|
|
from contextlib import asynccontextmanager
|
|
from contextlib import asynccontextmanager
|
|
from typing import Optional
|
|
from typing import Optional
|
|
|
|
|
|
|
|
+from aiocache import cached
|
|
import aiohttp
|
|
import aiohttp
|
|
import requests
|
|
import requests
|
|
from fastapi import (
|
|
from fastapi import (
|
|
@@ -45,6 +46,7 @@ from open_webui.apps.openai.main import (
|
|
app as openai_app,
|
|
app as openai_app,
|
|
generate_chat_completion as generate_openai_chat_completion,
|
|
generate_chat_completion as generate_openai_chat_completion,
|
|
get_all_models as get_openai_models,
|
|
get_all_models as get_openai_models,
|
|
|
|
+ get_all_models_responses as get_openai_models_responses,
|
|
)
|
|
)
|
|
from open_webui.apps.retrieval.main import app as retrieval_app
|
|
from open_webui.apps.retrieval.main import app as retrieval_app
|
|
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
|
|
from open_webui.apps.retrieval.utils import get_rag_context, rag_template
|
|
@@ -70,13 +72,11 @@ from open_webui.config import (
|
|
DEFAULT_LOCALE,
|
|
DEFAULT_LOCALE,
|
|
ENABLE_ADMIN_CHAT_ACCESS,
|
|
ENABLE_ADMIN_CHAT_ACCESS,
|
|
ENABLE_ADMIN_EXPORT,
|
|
ENABLE_ADMIN_EXPORT,
|
|
- ENABLE_MODEL_FILTER,
|
|
|
|
ENABLE_OLLAMA_API,
|
|
ENABLE_OLLAMA_API,
|
|
ENABLE_OPENAI_API,
|
|
ENABLE_OPENAI_API,
|
|
ENABLE_TAGS_GENERATION,
|
|
ENABLE_TAGS_GENERATION,
|
|
ENV,
|
|
ENV,
|
|
FRONTEND_BUILD_DIR,
|
|
FRONTEND_BUILD_DIR,
|
|
- MODEL_FILTER_LIST,
|
|
|
|
OAUTH_PROVIDERS,
|
|
OAUTH_PROVIDERS,
|
|
ENABLE_SEARCH_QUERY,
|
|
ENABLE_SEARCH_QUERY,
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
@@ -135,6 +135,7 @@ from open_webui.utils.utils import (
|
|
get_http_authorization_cred,
|
|
get_http_authorization_cred,
|
|
get_verified_user,
|
|
get_verified_user,
|
|
)
|
|
)
|
|
|
|
+from open_webui.utils.access_control import has_access
|
|
|
|
|
|
if SAFE_MODE:
|
|
if SAFE_MODE:
|
|
print("SAFE MODE ENABLED")
|
|
print("SAFE MODE ENABLED")
|
|
@@ -183,7 +184,10 @@ async def lifespan(app: FastAPI):
|
|
|
|
|
|
|
|
|
|
app = FastAPI(
|
|
app = FastAPI(
|
|
- docs_url="/docs" if ENV == "dev" else None, openapi_url="/openapi.json" if ENV == "dev" else None, redoc_url=None, lifespan=lifespan
|
|
|
|
|
|
+ docs_url="/docs" if ENV == "dev" else None,
|
|
|
|
+ openapi_url="/openapi.json" if ENV == "dev" else None,
|
|
|
|
+ redoc_url=None,
|
|
|
|
+ lifespan=lifespan,
|
|
)
|
|
)
|
|
|
|
|
|
app.state.config = AppConfig()
|
|
app.state.config = AppConfig()
|
|
@@ -191,27 +195,26 @@ app.state.config = AppConfig()
|
|
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
|
app.state.config.ENABLE_OPENAI_API = ENABLE_OPENAI_API
|
|
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|
app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|
|
|
|
|
-app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
|
|
|
-app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|
|
|
-
|
|
|
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
|
|
|
|
|
app.state.config.TASK_MODEL = TASK_MODEL
|
|
app.state.config.TASK_MODEL = TASK_MODEL
|
|
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
|
app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
|
|
|
+
|
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
|
-app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
|
|
|
|
|
+
|
|
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
|
app.state.config.ENABLE_TAGS_GENERATION = ENABLE_TAGS_GENERATION
|
|
|
|
+app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE = TAGS_GENERATION_PROMPT_TEMPLATE
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
|
|
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
)
|
|
)
|
|
-app.state.config.ENABLE_SEARCH_QUERY = ENABLE_SEARCH_QUERY
|
|
|
|
|
|
+
|
|
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
|
|
)
|
|
)
|
|
|
|
|
|
-app.state.MODELS = {}
|
|
|
|
-
|
|
|
|
-
|
|
|
|
##################################
|
|
##################################
|
|
#
|
|
#
|
|
# ChatCompletion Middleware
|
|
# ChatCompletion Middleware
|
|
@@ -219,26 +222,6 @@ app.state.MODELS = {}
|
|
##################################
|
|
##################################
|
|
|
|
|
|
|
|
|
|
-def get_task_model_id(default_model_id):
|
|
|
|
- # Set the task model
|
|
|
|
- task_model_id = default_model_id
|
|
|
|
- # Check if the user has a custom task model and use that model
|
|
|
|
- if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
|
- if (
|
|
|
|
- app.state.config.TASK_MODEL
|
|
|
|
- and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
|
- ):
|
|
|
|
- task_model_id = app.state.config.TASK_MODEL
|
|
|
|
- else:
|
|
|
|
- if (
|
|
|
|
- app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
- and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
|
- ):
|
|
|
|
- task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
-
|
|
|
|
- return task_model_id
|
|
|
|
-
|
|
|
|
-
|
|
|
|
def get_filter_function_ids(model):
|
|
def get_filter_function_ids(model):
|
|
def get_priority(function_id):
|
|
def get_priority(function_id):
|
|
function = Functions.get_function_by_id(function_id)
|
|
function = Functions.get_function_by_id(function_id)
|
|
@@ -368,8 +351,24 @@ async def get_content_from_response(response) -> Optional[str]:
|
|
return content
|
|
return content
|
|
|
|
|
|
|
|
|
|
|
|
+def get_task_model_id(
|
|
|
|
+ default_model_id: str, task_model: str, task_model_external: str, models
|
|
|
|
+) -> str:
|
|
|
|
+ # Set the task model
|
|
|
|
+ task_model_id = default_model_id
|
|
|
|
+ # Check if the user has a custom task model and use that model
|
|
|
|
+ if models[task_model_id]["owned_by"] == "ollama":
|
|
|
|
+ if task_model and task_model in models:
|
|
|
|
+ task_model_id = task_model
|
|
|
|
+ else:
|
|
|
|
+ if task_model_external and task_model_external in models:
|
|
|
|
+ task_model_id = task_model_external
|
|
|
|
+
|
|
|
|
+ return task_model_id
|
|
|
|
+
|
|
|
|
+
|
|
async def chat_completion_tools_handler(
|
|
async def chat_completion_tools_handler(
|
|
- body: dict, user: UserModel, extra_params: dict
|
|
|
|
|
|
+ body: dict, user: UserModel, models, extra_params: dict
|
|
) -> tuple[dict, dict]:
|
|
) -> tuple[dict, dict]:
|
|
# If tool_ids field is present, call the functions
|
|
# If tool_ids field is present, call the functions
|
|
metadata = body.get("metadata", {})
|
|
metadata = body.get("metadata", {})
|
|
@@ -383,14 +382,19 @@ async def chat_completion_tools_handler(
|
|
contexts = []
|
|
contexts = []
|
|
citations = []
|
|
citations = []
|
|
|
|
|
|
- task_model_id = get_task_model_id(body["model"])
|
|
|
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
|
+ body["model"],
|
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
|
+ models,
|
|
|
|
+ )
|
|
tools = get_tools(
|
|
tools = get_tools(
|
|
webui_app,
|
|
webui_app,
|
|
tool_ids,
|
|
tool_ids,
|
|
user,
|
|
user,
|
|
{
|
|
{
|
|
**extra_params,
|
|
**extra_params,
|
|
- "__model__": app.state.MODELS[task_model_id],
|
|
|
|
|
|
+ "__model__": models[task_model_id],
|
|
"__messages__": body["messages"],
|
|
"__messages__": body["messages"],
|
|
"__files__": metadata.get("files", []),
|
|
"__files__": metadata.get("files", []),
|
|
},
|
|
},
|
|
@@ -414,7 +418,7 @@ async def chat_completion_tools_handler(
|
|
)
|
|
)
|
|
|
|
|
|
try:
|
|
try:
|
|
- payload = filter_pipeline(payload, user)
|
|
|
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
raise e
|
|
raise e
|
|
|
|
|
|
@@ -515,16 +519,16 @@ def is_chat_completion_request(request):
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
-async def get_body_and_model_and_user(request):
|
|
|
|
|
|
+async def get_body_and_model_and_user(request, models):
|
|
# Read the original request body
|
|
# Read the original request body
|
|
body = await request.body()
|
|
body = await request.body()
|
|
body_str = body.decode("utf-8")
|
|
body_str = body.decode("utf-8")
|
|
body = json.loads(body_str) if body_str else {}
|
|
body = json.loads(body_str) if body_str else {}
|
|
|
|
|
|
model_id = body["model"]
|
|
model_id = body["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise Exception("Model not found")
|
|
raise Exception("Model not found")
|
|
- model = app.state.MODELS[model_id]
|
|
|
|
|
|
+ model = models[model_id]
|
|
|
|
|
|
user = get_current_user(
|
|
user = get_current_user(
|
|
request,
|
|
request,
|
|
@@ -540,14 +544,27 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
return await call_next(request)
|
|
return await call_next(request)
|
|
log.debug(f"request.url.path: {request.url.path}")
|
|
log.debug(f"request.url.path: {request.url.path}")
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
try:
|
|
try:
|
|
- body, model, user = await get_body_and_model_and_user(request)
|
|
|
|
|
|
+ body, model, user = await get_body_and_model_and_user(request, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
status_code=status.HTTP_400_BAD_REQUEST,
|
|
content={"detail": str(e)},
|
|
content={"detail": str(e)},
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ model_info = Models.get_model_by_id(model["id"])
|
|
|
|
+ if user.role == "user":
|
|
|
|
+ if model_info and not has_access(
|
|
|
|
+ user.id, type="read", access_control=model_info.access_control
|
|
|
|
+ ):
|
|
|
|
+ return JSONResponse(
|
|
|
|
+ status_code=status.HTTP_403_FORBIDDEN,
|
|
|
|
+ content={"detail": "User does not have access to the model"},
|
|
|
|
+ )
|
|
|
|
+
|
|
metadata = {
|
|
metadata = {
|
|
"chat_id": body.pop("chat_id", None),
|
|
"chat_id": body.pop("chat_id", None),
|
|
"message_id": body.pop("id", None),
|
|
"message_id": body.pop("id", None),
|
|
@@ -584,15 +601,20 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
content={"detail": str(e)},
|
|
content={"detail": str(e)},
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ tool_ids = body.pop("tool_ids", None)
|
|
|
|
+ files = body.pop("files", None)
|
|
|
|
+
|
|
metadata = {
|
|
metadata = {
|
|
**metadata,
|
|
**metadata,
|
|
- "tool_ids": body.pop("tool_ids", None),
|
|
|
|
- "files": body.pop("files", None),
|
|
|
|
|
|
+ "tool_ids": tool_ids,
|
|
|
|
+ "files": files,
|
|
}
|
|
}
|
|
body["metadata"] = metadata
|
|
body["metadata"] = metadata
|
|
|
|
|
|
try:
|
|
try:
|
|
- body, flags = await chat_completion_tools_handler(body, user, extra_params)
|
|
|
|
|
|
+ body, flags = await chat_completion_tools_handler(
|
|
|
|
+ body, user, models, extra_params
|
|
|
|
+ )
|
|
contexts.extend(flags.get("contexts", []))
|
|
contexts.extend(flags.get("contexts", []))
|
|
citations.extend(flags.get("citations", []))
|
|
citations.extend(flags.get("citations", []))
|
|
except Exception as e:
|
|
except Exception as e:
|
|
@@ -689,10 +711,10 @@ app.add_middleware(ChatCompletionMiddleware)
|
|
##################################
|
|
##################################
|
|
|
|
|
|
|
|
|
|
-def get_sorted_filters(model_id):
|
|
|
|
|
|
+def get_sorted_filters(model_id, models):
|
|
filters = [
|
|
filters = [
|
|
model
|
|
model
|
|
- for model in app.state.MODELS.values()
|
|
|
|
|
|
+ for model in models.values()
|
|
if "pipeline" in model
|
|
if "pipeline" in model
|
|
and "type" in model["pipeline"]
|
|
and "type" in model["pipeline"]
|
|
and model["pipeline"]["type"] == "filter"
|
|
and model["pipeline"]["type"] == "filter"
|
|
@@ -708,12 +730,12 @@ def get_sorted_filters(model_id):
|
|
return sorted_filters
|
|
return sorted_filters
|
|
|
|
|
|
|
|
|
|
-def filter_pipeline(payload, user):
|
|
|
|
|
|
+def filter_pipeline(payload, user, models):
|
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
|
user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
|
|
model_id = payload["model"]
|
|
model_id = payload["model"]
|
|
- sorted_filters = get_sorted_filters(model_id)
|
|
|
|
|
|
|
|
- model = app.state.MODELS[model_id]
|
|
|
|
|
|
+ sorted_filters = get_sorted_filters(model_id, models)
|
|
|
|
+ model = models[model_id]
|
|
|
|
|
|
if "pipeline" in model:
|
|
if "pipeline" in model:
|
|
sorted_filters.append(model)
|
|
sorted_filters.append(model)
|
|
@@ -784,8 +806,11 @@ class PipelineMiddleware(BaseHTTPMiddleware):
|
|
content={"detail": "Not authenticated"},
|
|
content={"detail": "Not authenticated"},
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
try:
|
|
try:
|
|
- data = filter_pipeline(data, user)
|
|
|
|
|
|
+ data = filter_pipeline(data, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if len(e.args) > 1:
|
|
if len(e.args) > 1:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -864,16 +889,10 @@ async def commit_session_after_request(request: Request, call_next):
|
|
|
|
|
|
@app.middleware("http")
|
|
@app.middleware("http")
|
|
async def check_url(request: Request, call_next):
|
|
async def check_url(request: Request, call_next):
|
|
- if len(app.state.MODELS) == 0:
|
|
|
|
- await get_all_models()
|
|
|
|
- else:
|
|
|
|
- pass
|
|
|
|
-
|
|
|
|
start_time = int(time.time())
|
|
start_time = int(time.time())
|
|
response = await call_next(request)
|
|
response = await call_next(request)
|
|
process_time = int(time.time()) - start_time
|
|
process_time = int(time.time()) - start_time
|
|
response.headers["X-Process-Time"] = str(process_time)
|
|
response.headers["X-Process-Time"] = str(process_time)
|
|
-
|
|
|
|
return response
|
|
return response
|
|
|
|
|
|
|
|
|
|
@@ -913,12 +932,10 @@ app.mount("/retrieval/api/v1", retrieval_app)
|
|
|
|
|
|
app.mount("/api/v1", webui_app)
|
|
app.mount("/api/v1", webui_app)
|
|
|
|
|
|
-
|
|
|
|
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
|
|
webui_app.state.EMBEDDING_FUNCTION = retrieval_app.state.EMBEDDING_FUNCTION
|
|
|
|
|
|
|
|
|
|
-async def get_all_models():
|
|
|
|
- # TODO: Optimize this function
|
|
|
|
|
|
+async def get_all_base_models():
|
|
open_webui_models = []
|
|
open_webui_models = []
|
|
openai_models = []
|
|
openai_models = []
|
|
ollama_models = []
|
|
ollama_models = []
|
|
@@ -944,9 +961,15 @@ async def get_all_models():
|
|
open_webui_models = await get_open_webui_models()
|
|
open_webui_models = await get_open_webui_models()
|
|
|
|
|
|
models = open_webui_models + openai_models + ollama_models
|
|
models = open_webui_models + openai_models + ollama_models
|
|
|
|
+ return models
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+@cached(ttl=1)
|
|
|
|
+async def get_all_models():
|
|
|
|
+ models = await get_all_base_models()
|
|
|
|
|
|
# If there are no models, return an empty list
|
|
# If there are no models, return an empty list
|
|
- if len([model for model in models if model["owned_by"] != "arena"]) == 0:
|
|
|
|
|
|
+ if len([model for model in models if not model.get("arena", False)]) == 0:
|
|
return []
|
|
return []
|
|
|
|
|
|
global_action_ids = [
|
|
global_action_ids = [
|
|
@@ -965,15 +988,23 @@ async def get_all_models():
|
|
custom_model.id == model["id"]
|
|
custom_model.id == model["id"]
|
|
or custom_model.id == model["id"].split(":")[0]
|
|
or custom_model.id == model["id"].split(":")[0]
|
|
):
|
|
):
|
|
- model["name"] = custom_model.name
|
|
|
|
- model["info"] = custom_model.model_dump()
|
|
|
|
|
|
+ if custom_model.is_active:
|
|
|
|
+ model["name"] = custom_model.name
|
|
|
|
+ model["info"] = custom_model.model_dump()
|
|
|
|
+
|
|
|
|
+ action_ids = []
|
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
|
+ action_ids.extend(
|
|
|
|
+ model["info"]["meta"].get("actionIds", [])
|
|
|
|
+ )
|
|
|
|
|
|
- action_ids = []
|
|
|
|
- if "info" in model and "meta" in model["info"]:
|
|
|
|
- action_ids.extend(model["info"]["meta"].get("actionIds", []))
|
|
|
|
|
|
+ model["action_ids"] = action_ids
|
|
|
|
+ else:
|
|
|
|
+ models.remove(model)
|
|
|
|
|
|
- model["action_ids"] = action_ids
|
|
|
|
- else:
|
|
|
|
|
|
+ elif custom_model.is_active and (
|
|
|
|
+ custom_model.id not in [model["id"] for model in models]
|
|
|
|
+ ):
|
|
owned_by = "openai"
|
|
owned_by = "openai"
|
|
pipe = None
|
|
pipe = None
|
|
action_ids = []
|
|
action_ids = []
|
|
@@ -995,7 +1026,7 @@ async def get_all_models():
|
|
|
|
|
|
models.append(
|
|
models.append(
|
|
{
|
|
{
|
|
- "id": custom_model.id,
|
|
|
|
|
|
+ "id": f"{custom_model.id}",
|
|
"name": custom_model.name,
|
|
"name": custom_model.name,
|
|
"object": "model",
|
|
"object": "model",
|
|
"created": custom_model.created_at,
|
|
"created": custom_model.created_at,
|
|
@@ -1007,66 +1038,54 @@ async def get_all_models():
|
|
}
|
|
}
|
|
)
|
|
)
|
|
|
|
|
|
- for model in models:
|
|
|
|
- action_ids = []
|
|
|
|
- if "action_ids" in model:
|
|
|
|
- action_ids = model["action_ids"]
|
|
|
|
- del model["action_ids"]
|
|
|
|
|
|
+ # Process action_ids to get the actions
|
|
|
|
+ def get_action_items_from_module(module):
|
|
|
|
+ actions = []
|
|
|
|
+ if hasattr(module, "actions"):
|
|
|
|
+ actions = module.actions
|
|
|
|
+ return [
|
|
|
|
+ {
|
|
|
|
+ "id": f"{module.id}.{action['id']}",
|
|
|
|
+ "name": action.get("name", f"{module.name} ({action['id']})"),
|
|
|
|
+ "description": module.meta.description,
|
|
|
|
+ "icon_url": action.get(
|
|
|
|
+ "icon_url", module.meta.manifest.get("icon_url", None)
|
|
|
|
+ ),
|
|
|
|
+ }
|
|
|
|
+ for action in actions
|
|
|
|
+ ]
|
|
|
|
+ else:
|
|
|
|
+ return [
|
|
|
|
+ {
|
|
|
|
+ "id": module.id,
|
|
|
|
+ "name": module.name,
|
|
|
|
+ "description": module.meta.description,
|
|
|
|
+ "icon_url": module.meta.manifest.get("icon_url", None),
|
|
|
|
+ }
|
|
|
|
+ ]
|
|
|
|
+
|
|
|
|
+ def get_function_module_by_id(function_id):
|
|
|
|
+ if function_id in webui_app.state.FUNCTIONS:
|
|
|
|
+ function_module = webui_app.state.FUNCTIONS[function_id]
|
|
|
|
+ else:
|
|
|
|
+ function_module, _, _ = load_function_module_by_id(function_id)
|
|
|
|
+ webui_app.state.FUNCTIONS[function_id] = function_module
|
|
|
|
|
|
- action_ids = action_ids + global_action_ids
|
|
|
|
- action_ids = list(set(action_ids))
|
|
|
|
|
|
+ for model in models:
|
|
action_ids = [
|
|
action_ids = [
|
|
- action_id for action_id in action_ids if action_id in enabled_action_ids
|
|
|
|
|
|
+ action_id
|
|
|
|
+ for action_id in list(set(model.pop("action_ids", []) + global_action_ids))
|
|
|
|
+ if action_id in enabled_action_ids
|
|
]
|
|
]
|
|
|
|
|
|
model["actions"] = []
|
|
model["actions"] = []
|
|
for action_id in action_ids:
|
|
for action_id in action_ids:
|
|
- action = Functions.get_function_by_id(action_id)
|
|
|
|
- if action is None:
|
|
|
|
|
|
+ action_function = Functions.get_function_by_id(action_id)
|
|
|
|
+ if action_function is None:
|
|
raise Exception(f"Action not found: {action_id}")
|
|
raise Exception(f"Action not found: {action_id}")
|
|
|
|
|
|
- if action_id in webui_app.state.FUNCTIONS:
|
|
|
|
- function_module = webui_app.state.FUNCTIONS[action_id]
|
|
|
|
- else:
|
|
|
|
- function_module, _, _ = load_function_module_by_id(action_id)
|
|
|
|
- webui_app.state.FUNCTIONS[action_id] = function_module
|
|
|
|
-
|
|
|
|
- __webui__ = False
|
|
|
|
- if hasattr(function_module, "__webui__"):
|
|
|
|
- __webui__ = function_module.__webui__
|
|
|
|
-
|
|
|
|
- if hasattr(function_module, "actions"):
|
|
|
|
- actions = function_module.actions
|
|
|
|
- model["actions"].extend(
|
|
|
|
- [
|
|
|
|
- {
|
|
|
|
- "id": f"{action_id}.{_action['id']}",
|
|
|
|
- "name": _action.get(
|
|
|
|
- "name", f"{action.name} ({_action['id']})"
|
|
|
|
- ),
|
|
|
|
- "description": action.meta.description,
|
|
|
|
- "icon_url": _action.get(
|
|
|
|
- "icon_url", action.meta.manifest.get("icon_url", None)
|
|
|
|
- ),
|
|
|
|
- **({"__webui__": __webui__} if __webui__ else {}),
|
|
|
|
- }
|
|
|
|
- for _action in actions
|
|
|
|
- ]
|
|
|
|
- )
|
|
|
|
- else:
|
|
|
|
- model["actions"].append(
|
|
|
|
- {
|
|
|
|
- "id": action_id,
|
|
|
|
- "name": action.name,
|
|
|
|
- "description": action.meta.description,
|
|
|
|
- "icon_url": action.meta.manifest.get("icon_url", None),
|
|
|
|
- **({"__webui__": __webui__} if __webui__ else {}),
|
|
|
|
- }
|
|
|
|
- )
|
|
|
|
-
|
|
|
|
- app.state.MODELS = {model["id"]: model for model in models}
|
|
|
|
- webui_app.state.MODELS = app.state.MODELS
|
|
|
|
-
|
|
|
|
|
|
+ function_module = get_function_module_by_id(action_id)
|
|
|
|
+ model["actions"].extend(get_action_items_from_module(function_module))
|
|
return models
|
|
return models
|
|
|
|
|
|
|
|
|
|
@@ -1081,16 +1100,29 @@ async def get_models(user=Depends(get_verified_user)):
|
|
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
|
|
if "pipeline" not in model or model["pipeline"].get("type", None) != "filter"
|
|
]
|
|
]
|
|
|
|
|
|
- if app.state.config.ENABLE_MODEL_FILTER:
|
|
|
|
- if user.role == "user":
|
|
|
|
- models = list(
|
|
|
|
- filter(
|
|
|
|
- lambda model: model["id"] in app.state.config.MODEL_FILTER_LIST,
|
|
|
|
- models,
|
|
|
|
- )
|
|
|
|
- )
|
|
|
|
- return {"data": models}
|
|
|
|
|
|
+ # Filter out models that the user does not have access to
|
|
|
|
+ if user.role == "user":
|
|
|
|
+ filtered_models = []
|
|
|
|
+ for model in models:
|
|
|
|
+ model_info = Models.get_model_by_id(model["id"])
|
|
|
|
+ if model_info:
|
|
|
|
+ if has_access(
|
|
|
|
+ user.id, type="read", access_control=model_info.access_control
|
|
|
|
+ ):
|
|
|
|
+ filtered_models.append(model)
|
|
|
|
+ else:
|
|
|
|
+ filtered_models.append(model)
|
|
|
|
+ models = filtered_models
|
|
|
|
+
|
|
|
|
+ return {"data": models}
|
|
|
|
+
|
|
|
|
|
|
|
|
+@app.get("/api/models/base")
|
|
|
|
+async def get_base_models(user=Depends(get_admin_user)):
|
|
|
|
+ models = await get_all_base_models()
|
|
|
|
+
|
|
|
|
+ # Filter out arena models
|
|
|
|
+ models = [model for model in models if not model.get("arena", False)]
|
|
return {"data": models}
|
|
return {"data": models}
|
|
|
|
|
|
|
|
|
|
@@ -1098,23 +1130,28 @@ async def get_models(user=Depends(get_verified_user)):
|
|
async def generate_chat_completions(
|
|
async def generate_chat_completions(
|
|
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
|
|
form_data: dict, user=Depends(get_verified_user), bypass_filter: bool = False
|
|
):
|
|
):
|
|
- model_id = form_data["model"]
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ model_id = form_data["model"]
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
)
|
|
)
|
|
|
|
|
|
- if not bypass_filter and app.state.config.ENABLE_MODEL_FILTER:
|
|
|
|
- if user.role == "user" and model_id not in app.state.config.MODEL_FILTER_LIST:
|
|
|
|
|
|
+ model = models[model_id]
|
|
|
|
+ # Check if user has access to the model
|
|
|
|
+ if user.role == "user":
|
|
|
|
+ model_info = Models.get_model_by_id(model_id)
|
|
|
|
+ if not has_access(
|
|
|
|
+ user.id, type="read", access_control=model_info.access_control
|
|
|
|
+ ):
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
- status_code=status.HTTP_403_FORBIDDEN,
|
|
|
|
|
|
+ status_code=403,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
)
|
|
)
|
|
|
|
|
|
- model = app.state.MODELS[model_id]
|
|
|
|
-
|
|
|
|
if model["owned_by"] == "arena":
|
|
if model["owned_by"] == "arena":
|
|
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
|
model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
|
|
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
|
filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
|
|
@@ -1161,14 +1198,18 @@ async def generate_chat_completions(
|
|
),
|
|
),
|
|
"selected_model_id": selected_model_id,
|
|
"selected_model_id": selected_model_id,
|
|
}
|
|
}
|
|
|
|
+
|
|
if model.get("pipe"):
|
|
if model.get("pipe"):
|
|
- return await generate_function_chat_completion(form_data, user=user)
|
|
|
|
|
|
+ # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
|
|
|
|
+ return await generate_function_chat_completion(
|
|
|
|
+ form_data, user=user, models=models
|
|
|
|
+ )
|
|
if model["owned_by"] == "ollama":
|
|
if model["owned_by"] == "ollama":
|
|
# Using /ollama/api/chat endpoint
|
|
# Using /ollama/api/chat endpoint
|
|
form_data = convert_payload_openai_to_ollama(form_data)
|
|
form_data = convert_payload_openai_to_ollama(form_data)
|
|
form_data = GenerateChatCompletionForm(**form_data)
|
|
form_data = GenerateChatCompletionForm(**form_data)
|
|
response = await generate_ollama_chat_completion(
|
|
response = await generate_ollama_chat_completion(
|
|
- form_data=form_data, user=user, bypass_filter=True
|
|
|
|
|
|
+ form_data=form_data, user=user, bypass_filter=bypass_filter
|
|
)
|
|
)
|
|
if form_data.stream:
|
|
if form_data.stream:
|
|
response.headers["content-type"] = "text/event-stream"
|
|
response.headers["content-type"] = "text/event-stream"
|
|
@@ -1179,21 +1220,27 @@ async def generate_chat_completions(
|
|
else:
|
|
else:
|
|
return convert_response_ollama_to_openai(response)
|
|
return convert_response_ollama_to_openai(response)
|
|
else:
|
|
else:
|
|
- return await generate_openai_chat_completion(form_data, user=user)
|
|
|
|
|
|
+ return await generate_openai_chat_completion(
|
|
|
|
+ form_data, user=user, bypass_filter=bypass_filter
|
|
|
|
+ )
|
|
|
|
|
|
|
|
|
|
@app.post("/api/chat/completed")
|
|
@app.post("/api/chat/completed")
|
|
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
|
|
|
|
+
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
data = form_data
|
|
data = form_data
|
|
model_id = data["model"]
|
|
model_id = data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
)
|
|
)
|
|
- model = app.state.MODELS[model_id]
|
|
|
|
|
|
|
|
- sorted_filters = get_sorted_filters(model_id)
|
|
|
|
|
|
+ model = models[model_id]
|
|
|
|
+ sorted_filters = get_sorted_filters(model_id, models)
|
|
if "pipeline" in model:
|
|
if "pipeline" in model:
|
|
sorted_filters = [model] + sorted_filters
|
|
sorted_filters = [model] + sorted_filters
|
|
|
|
|
|
@@ -1368,14 +1415,18 @@ async def chat_action(action_id: str, form_data: dict, user=Depends(get_verified
|
|
detail="Action not found",
|
|
detail="Action not found",
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
data = form_data
|
|
data = form_data
|
|
model_id = data["model"]
|
|
model_id = data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
)
|
|
)
|
|
- model = app.state.MODELS[model_id]
|
|
|
|
|
|
+ model = models[model_id]
|
|
|
|
|
|
__event_emitter__ = get_event_emitter(
|
|
__event_emitter__ = get_event_emitter(
|
|
{
|
|
{
|
|
@@ -1529,8 +1580,11 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
|
|
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
print("generate_title")
|
|
print("generate_title")
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
model_id = form_data["model"]
|
|
model_id = form_data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
@@ -1538,10 +1592,16 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
|
|
|
# Check if the user has a custom task model
|
|
# Check if the user has a custom task model
|
|
# If the user has a custom task model, use that model
|
|
# If the user has a custom task model, use that model
|
|
- task_model_id = get_task_model_id(model_id)
|
|
|
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
|
+ model_id,
|
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
|
+ models,
|
|
|
|
+ )
|
|
|
|
+
|
|
print(task_model_id)
|
|
print(task_model_id)
|
|
|
|
|
|
- model = app.state.MODELS[task_model_id]
|
|
|
|
|
|
+ model = models[task_model_id]
|
|
|
|
|
|
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
|
if app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE != "":
|
|
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
@@ -1575,7 +1635,7 @@ Artificial Intelligence in Healthcare
|
|
"stream": False,
|
|
"stream": False,
|
|
**(
|
|
**(
|
|
{"max_tokens": 50}
|
|
{"max_tokens": 50}
|
|
- if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
|
|
|
|
|
|
+ if models[task_model_id]["owned_by"] == "ollama"
|
|
else {
|
|
else {
|
|
"max_completion_tokens": 50,
|
|
"max_completion_tokens": 50,
|
|
}
|
|
}
|
|
@@ -1587,7 +1647,7 @@ Artificial Intelligence in Healthcare
|
|
|
|
|
|
# Handle pipeline filters
|
|
# Handle pipeline filters
|
|
try:
|
|
try:
|
|
- payload = filter_pipeline(payload, user)
|
|
|
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if len(e.args) > 1:
|
|
if len(e.args) > 1:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -1614,8 +1674,11 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
|
content={"detail": "Tags generation is disabled"},
|
|
content={"detail": "Tags generation is disabled"},
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
model_id = form_data["model"]
|
|
model_id = form_data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
@@ -1623,7 +1686,12 @@ async def generate_chat_tags(form_data: dict, user=Depends(get_verified_user)):
|
|
|
|
|
|
# Check if the user has a custom task model
|
|
# Check if the user has a custom task model
|
|
# If the user has a custom task model, use that model
|
|
# If the user has a custom task model, use that model
|
|
- task_model_id = get_task_model_id(model_id)
|
|
|
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
|
+ model_id,
|
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
|
+ models,
|
|
|
|
+ )
|
|
print(task_model_id)
|
|
print(task_model_id)
|
|
|
|
|
|
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
|
if app.state.config.TAGS_GENERATION_PROMPT_TEMPLATE != "":
|
|
@@ -1661,7 +1729,7 @@ JSON format: { "tags": ["tag1", "tag2", "tag3"] }
|
|
|
|
|
|
# Handle pipeline filters
|
|
# Handle pipeline filters
|
|
try:
|
|
try:
|
|
- payload = filter_pipeline(payload, user)
|
|
|
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if len(e.args) > 1:
|
|
if len(e.args) > 1:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -1688,8 +1756,11 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|
detail=f"Search query generation is disabled",
|
|
detail=f"Search query generation is disabled",
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
model_id = form_data["model"]
|
|
model_id = form_data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
@@ -1697,10 +1768,15 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|
|
|
|
|
# Check if the user has a custom task model
|
|
# Check if the user has a custom task model
|
|
# If the user has a custom task model, use that model
|
|
# If the user has a custom task model, use that model
|
|
- task_model_id = get_task_model_id(model_id)
|
|
|
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
|
+ model_id,
|
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
|
+ models,
|
|
|
|
+ )
|
|
print(task_model_id)
|
|
print(task_model_id)
|
|
|
|
|
|
- model = app.state.MODELS[task_model_id]
|
|
|
|
|
|
+ model = models[task_model_id]
|
|
|
|
|
|
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
|
if app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE != "":
|
|
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
@@ -1727,7 +1803,7 @@ Search Query:"""
|
|
"stream": False,
|
|
"stream": False,
|
|
**(
|
|
**(
|
|
{"max_tokens": 30}
|
|
{"max_tokens": 30}
|
|
- if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
|
|
|
|
|
|
+ if models[task_model_id]["owned_by"] == "ollama"
|
|
else {
|
|
else {
|
|
"max_completion_tokens": 30,
|
|
"max_completion_tokens": 30,
|
|
}
|
|
}
|
|
@@ -1738,7 +1814,7 @@ Search Query:"""
|
|
|
|
|
|
# Handle pipeline filters
|
|
# Handle pipeline filters
|
|
try:
|
|
try:
|
|
- payload = filter_pipeline(payload, user)
|
|
|
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if len(e.args) > 1:
|
|
if len(e.args) > 1:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -1760,8 +1836,11 @@ Search Query:"""
|
|
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
|
async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
|
print("generate_emoji")
|
|
print("generate_emoji")
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
model_id = form_data["model"]
|
|
model_id = form_data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
@@ -1769,10 +1848,15 @@ async def generate_emoji(form_data: dict, user=Depends(get_verified_user)):
|
|
|
|
|
|
# Check if the user has a custom task model
|
|
# Check if the user has a custom task model
|
|
# If the user has a custom task model, use that model
|
|
# If the user has a custom task model, use that model
|
|
- task_model_id = get_task_model_id(model_id)
|
|
|
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
|
+ model_id,
|
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
|
+ models,
|
|
|
|
+ )
|
|
print(task_model_id)
|
|
print(task_model_id)
|
|
|
|
|
|
- model = app.state.MODELS[task_model_id]
|
|
|
|
|
|
+ model = models[task_model_id]
|
|
|
|
|
|
template = '''
|
|
template = '''
|
|
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
|
Your task is to reflect the speaker's likely facial expression through a fitting emoji. Interpret emotions from the message and reflect their facial expression using fitting, diverse emojis (e.g., 😊, 😢, 😡, 😱).
|
|
@@ -1794,7 +1878,7 @@ Message: """{{prompt}}"""
|
|
"stream": False,
|
|
"stream": False,
|
|
**(
|
|
**(
|
|
{"max_tokens": 4}
|
|
{"max_tokens": 4}
|
|
- if app.state.MODELS[task_model_id]["owned_by"] == "ollama"
|
|
|
|
|
|
+ if models[task_model_id]["owned_by"] == "ollama"
|
|
else {
|
|
else {
|
|
"max_completion_tokens": 4,
|
|
"max_completion_tokens": 4,
|
|
}
|
|
}
|
|
@@ -1806,7 +1890,7 @@ Message: """{{prompt}}"""
|
|
|
|
|
|
# Handle pipeline filters
|
|
# Handle pipeline filters
|
|
try:
|
|
try:
|
|
- payload = filter_pipeline(payload, user)
|
|
|
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if len(e.args) > 1:
|
|
if len(e.args) > 1:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -1828,8 +1912,11 @@ Message: """{{prompt}}"""
|
|
async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
|
|
async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)):
|
|
print("generate_moa_response")
|
|
print("generate_moa_response")
|
|
|
|
|
|
|
|
+ model_list = await get_all_models()
|
|
|
|
+ models = {model["id"]: model for model in model_list}
|
|
|
|
+
|
|
model_id = form_data["model"]
|
|
model_id = form_data["model"]
|
|
- if model_id not in app.state.MODELS:
|
|
|
|
|
|
+ if model_id not in models:
|
|
raise HTTPException(
|
|
raise HTTPException(
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
status_code=status.HTTP_404_NOT_FOUND,
|
|
detail="Model not found",
|
|
detail="Model not found",
|
|
@@ -1837,10 +1924,15 @@ async def generate_moa_response(form_data: dict, user=Depends(get_verified_user)
|
|
|
|
|
|
# Check if the user has a custom task model
|
|
# Check if the user has a custom task model
|
|
# If the user has a custom task model, use that model
|
|
# If the user has a custom task model, use that model
|
|
- task_model_id = get_task_model_id(model_id)
|
|
|
|
|
|
+ task_model_id = get_task_model_id(
|
|
|
|
+ model_id,
|
|
|
|
+ app.state.config.TASK_MODEL,
|
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
|
+ models,
|
|
|
|
+ )
|
|
print(task_model_id)
|
|
print(task_model_id)
|
|
|
|
|
|
- model = app.state.MODELS[task_model_id]
|
|
|
|
|
|
+ model = models[task_model_id]
|
|
|
|
|
|
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
|
|
template = """You have been provided with a set of responses from various models to the latest user query: "{{prompt}}"
|
|
|
|
|
|
@@ -1867,7 +1959,7 @@ Responses from models: {{responses}}"""
|
|
log.debug(payload)
|
|
log.debug(payload)
|
|
|
|
|
|
try:
|
|
try:
|
|
- payload = filter_pipeline(payload, user)
|
|
|
|
|
|
+ payload = filter_pipeline(payload, user, models)
|
|
except Exception as e:
|
|
except Exception as e:
|
|
if len(e.args) > 1:
|
|
if len(e.args) > 1:
|
|
return JSONResponse(
|
|
return JSONResponse(
|
|
@@ -1897,7 +1989,7 @@ Responses from models: {{responses}}"""
|
|
|
|
|
|
@app.get("/api/pipelines/list")
|
|
@app.get("/api/pipelines/list")
|
|
async def get_pipelines_list(user=Depends(get_admin_user)):
|
|
async def get_pipelines_list(user=Depends(get_admin_user)):
|
|
- responses = await get_openai_models(raw=True)
|
|
|
|
|
|
+ responses = await get_openai_models_responses()
|
|
|
|
|
|
print(responses)
|
|
print(responses)
|
|
urlIdxs = [
|
|
urlIdxs = [
|
|
@@ -2297,32 +2389,6 @@ async def get_app_config(request: Request):
|
|
}
|
|
}
|
|
|
|
|
|
|
|
|
|
-@app.get("/api/config/model/filter")
|
|
|
|
-async def get_model_filter_config(user=Depends(get_admin_user)):
|
|
|
|
- return {
|
|
|
|
- "enabled": app.state.config.ENABLE_MODEL_FILTER,
|
|
|
|
- "models": app.state.config.MODEL_FILTER_LIST,
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-class ModelFilterConfigForm(BaseModel):
|
|
|
|
- enabled: bool
|
|
|
|
- models: list[str]
|
|
|
|
-
|
|
|
|
-
|
|
|
|
-@app.post("/api/config/model/filter")
|
|
|
|
-async def update_model_filter_config(
|
|
|
|
- form_data: ModelFilterConfigForm, user=Depends(get_admin_user)
|
|
|
|
-):
|
|
|
|
- app.state.config.ENABLE_MODEL_FILTER = form_data.enabled
|
|
|
|
- app.state.config.MODEL_FILTER_LIST = form_data.models
|
|
|
|
-
|
|
|
|
- return {
|
|
|
|
- "enabled": app.state.config.ENABLE_MODEL_FILTER,
|
|
|
|
- "models": app.state.config.MODEL_FILTER_LIST,
|
|
|
|
- }
|
|
|
|
-
|
|
|
|
-
|
|
|
|
# TODO: webhook endpoint should be under config endpoints
|
|
# TODO: webhook endpoint should be under config endpoints
|
|
|
|
|
|
|
|
|