|
@@ -53,7 +53,7 @@ from utils.utils import (
|
|
|
get_current_user,
|
|
|
get_http_authorization_cred,
|
|
|
)
|
|
|
-from utils.task import title_generation_template
|
|
|
+from utils.task import title_generation_template, search_query_generation_template
|
|
|
|
|
|
from apps.rag.utils import rag_messages
|
|
|
|
|
@@ -77,7 +77,10 @@ from config import (
|
|
|
WEBHOOK_URL,
|
|
|
ENABLE_ADMIN_EXPORT,
|
|
|
WEBUI_BUILD_HASH,
|
|
|
+ TASK_MODEL,
|
|
|
+ TASK_MODEL_EXTERNAL,
|
|
|
TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
AppConfig,
|
|
|
)
|
|
|
from constants import ERROR_MESSAGES
|
|
@@ -132,9 +135,15 @@ app.state.config.ENABLE_OLLAMA_API = ENABLE_OLLAMA_API
|
|
|
app.state.config.ENABLE_MODEL_FILTER = ENABLE_MODEL_FILTER
|
|
|
app.state.config.MODEL_FILTER_LIST = MODEL_FILTER_LIST
|
|
|
|
|
|
-
|
|
|
app.state.config.WEBHOOK_URL = WEBHOOK_URL
|
|
|
+
|
|
|
+
|
|
|
+app.state.config.TASK_MODEL = TASK_MODEL
|
|
|
+app.state.config.TASK_MODEL_EXTERNAL = TASK_MODEL_EXTERNAL
|
|
|
app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = TITLE_GENERATION_PROMPT_TEMPLATE
|
|
|
+app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+)
|
|
|
|
|
|
app.state.MODELS = {}
|
|
|
|
|
@@ -494,9 +503,46 @@ async def get_models(user=Depends(get_verified_user)):
|
|
|
return {"data": models}
|
|
|
|
|
|
|
|
|
+@app.get("/api/task/config")
|
|
|
+async def get_task_config(user=Depends(get_verified_user)):
|
|
|
+ return {
|
|
|
+ "TASK_MODEL": app.state.config.TASK_MODEL,
|
|
|
+ "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
+ "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class TaskConfigForm(BaseModel):
|
|
|
+ TASK_MODEL: Optional[str]
|
|
|
+ TASK_MODEL_EXTERNAL: Optional[str]
|
|
|
+ TITLE_GENERATION_PROMPT_TEMPLATE: str
|
|
|
+ SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
|
|
|
+
|
|
|
+
|
|
|
+@app.post("/api/task/config/update")
|
|
|
+async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_user)):
|
|
|
+ app.state.config.TASK_MODEL = form_data.TASK_MODEL
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL = form_data.TASK_MODEL_EXTERNAL
|
|
|
+ app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ form_data.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
+ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
|
|
|
+ form_data.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+ )
|
|
|
+
|
|
|
+ return {
|
|
|
+ "TASK_MODEL": app.state.config.TASK_MODEL,
|
|
|
+ "TASK_MODEL_EXTERNAL": app.state.config.TASK_MODEL_EXTERNAL,
|
|
|
+ "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
@app.post("/api/task/title/completions")
|
|
|
async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
print("generate_title")
|
|
|
+
|
|
|
model_id = form_data["model"]
|
|
|
if model_id not in app.state.MODELS:
|
|
|
raise HTTPException(
|
|
@@ -504,6 +550,20 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
detail="Model not found",
|
|
|
)
|
|
|
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
model = app.state.MODELS[model_id]
|
|
|
|
|
|
template = app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE
|
|
@@ -532,6 +592,57 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
return await generate_openai_chat_completion(payload, user=user)
|
|
|
|
|
|
|
|
|
+@app.post("/api/task/query/completions")
|
|
|
+async def generate_search_query(form_data: dict, user=Depends(get_verified_user)):
|
|
|
+ print("generate_search_query")
|
|
|
+
|
|
|
+ model_id = form_data["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise HTTPException(
|
|
|
+ status_code=status.HTTP_404_NOT_FOUND,
|
|
|
+ detail="Model not found",
|
|
|
+ )
|
|
|
+
|
|
|
+ # Check if the user has a custom task model
|
|
|
+ # If the user has a custom task model, use that model
|
|
|
+ if app.state.MODELS[model_id]["owned_by"] == "ollama":
|
|
|
+ if app.state.config.TASK_MODEL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+ else:
|
|
|
+ if app.state.config.TASK_MODEL_EXTERNAL:
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ if task_model_id in app.state.MODELS:
|
|
|
+ model_id = task_model_id
|
|
|
+
|
|
|
+ print(model_id)
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
+
|
|
|
+ template = app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE
|
|
|
+
|
|
|
+ content = search_query_generation_template(
|
|
|
+ template, form_data["prompt"], user.model_dump()
|
|
|
+ )
|
|
|
+
|
|
|
+ payload = {
|
|
|
+ "model": model_id,
|
|
|
+ "messages": [{"role": "user", "content": content}],
|
|
|
+ "stream": False,
|
|
|
+ "max_tokens": 30,
|
|
|
+ }
|
|
|
+
|
|
|
+ print(payload)
|
|
|
+ payload = filter_pipeline(payload, user)
|
|
|
+
|
|
|
+ if model["owned_by"] == "ollama":
|
|
|
+ return await generate_ollama_chat_completion(
|
|
|
+ OpenAIChatCompletionForm(**payload), user=user
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ return await generate_openai_chat_completion(payload, user=user)
|
|
|
+
|
|
|
+
|
|
|
@app.post("/api/chat/completions")
|
|
|
async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
|
|
|
model_id = form_data["model"]
|
|
@@ -542,7 +653,6 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
|
|
|
)
|
|
|
|
|
|
model = app.state.MODELS[model_id]
|
|
|
-
|
|
|
print(model)
|
|
|
|
|
|
if model["owned_by"] == "ollama":
|