|
@@ -213,7 +213,7 @@ origins = ["*"]
|
|
|
|
|
|
|
|
|
async def get_function_call_response(
|
|
|
- messages, files, tool_id, template, task_model_id, user
|
|
|
+ messages, files, tool_id, template, task_model_id, user, model
|
|
|
):
|
|
|
tool = Tools.get_tool_by_id(tool_id)
|
|
|
tools_specs = json.dumps(tool.specs, indent=2)
|
|
@@ -373,233 +373,308 @@ async def get_function_call_response(
|
|
|
return None, None, False
|
|
|
|
|
|
|
|
|
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
- async def dispatch(self, request: Request, call_next):
|
|
|
- data_items = []
|
|
|
+def get_task_model_id(default_model_id):
|
|
|
+ # Set the task model
|
|
|
+ task_model_id = default_model_id
|
|
|
+ # Check if the user has a custom task model and use that model
|
|
|
+ if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
+ if (
|
|
|
+ app.state.config.TASK_MODEL
|
|
|
+ and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
+ ):
|
|
|
+ task_model_id = app.state.config.TASK_MODEL
|
|
|
+ else:
|
|
|
+ if (
|
|
|
+ app.state.config.TASK_MODEL_EXTERNAL
|
|
|
+ and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
+ ):
|
|
|
+ task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
|
|
|
- show_citations = False
|
|
|
- citations = []
|
|
|
+ return task_model_id
|
|
|
|
|
|
- if request.method == "POST" and any(
|
|
|
- endpoint in request.url.path
|
|
|
- for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
|
|
- ):
|
|
|
- log.debug(f"request.url.path: {request.url.path}")
|
|
|
|
|
|
- # Read the original request body
|
|
|
- body = await request.body()
|
|
|
- body_str = body.decode("utf-8")
|
|
|
- data = json.loads(body_str) if body_str else {}
|
|
|
+def get_filter_function_ids(model):
|
|
|
+ def get_priority(function_id):
|
|
|
+ function = Functions.get_function_by_id(function_id)
|
|
|
+ if function is not None and hasattr(function, "valves"):
|
|
|
+ return (function.valves if function.valves else {}).get("priority", 0)
|
|
|
+ return 0
|
|
|
|
|
|
- user = get_current_user(
|
|
|
- request,
|
|
|
- get_http_authorization_cred(request.headers.get("Authorization")),
|
|
|
- )
|
|
|
- # Flag to skip RAG completions if file_handler is present in tools/functions
|
|
|
- skip_files = False
|
|
|
- if data.get("citations"):
|
|
|
- show_citations = True
|
|
|
- del data["citations"]
|
|
|
-
|
|
|
- model_id = data["model"]
|
|
|
- if model_id not in app.state.MODELS:
|
|
|
- raise HTTPException(
|
|
|
- status_code=status.HTTP_404_NOT_FOUND,
|
|
|
- detail="Model not found",
|
|
|
+ filter_ids = [function.id for function in Functions.get_global_filter_functions()]
|
|
|
+ if "info" in model and "meta" in model["info"]:
|
|
|
+ filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
|
+ filter_ids = list(set(filter_ids))
|
|
|
+
|
|
|
+ enabled_filter_ids = [
|
|
|
+ function.id
|
|
|
+ for function in Functions.get_functions_by_type("filter", active_only=True)
|
|
|
+ ]
|
|
|
+
|
|
|
+ filter_ids = [
|
|
|
+ filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
|
+ ]
|
|
|
+
|
|
|
+ filter_ids.sort(key=get_priority)
|
|
|
+ return filter_ids
|
|
|
+
|
|
|
+
|
|
|
+async def chat_completion_functions_handler(body, model, user):
|
|
|
+ skip_files = None
|
|
|
+
|
|
|
+ filter_ids = get_filter_function_ids(model)
|
|
|
+ for filter_id in filter_ids:
|
|
|
+ filter = Functions.get_function_by_id(filter_id)
|
|
|
+ if filter:
|
|
|
+ if filter_id in webui_app.state.FUNCTIONS:
|
|
|
+ function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
+ else:
|
|
|
+ function_module, function_type, frontmatter = (
|
|
|
+ load_function_module_by_id(filter_id)
|
|
|
)
|
|
|
- model = app.state.MODELS[model_id]
|
|
|
+ webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
|
|
|
- def get_priority(function_id):
|
|
|
- function = Functions.get_function_by_id(function_id)
|
|
|
- if function is not None and hasattr(function, "valves"):
|
|
|
- return (function.valves if function.valves else {}).get(
|
|
|
- "priority", 0
|
|
|
- )
|
|
|
- return 0
|
|
|
+ # Check if the function has a file_handler variable
|
|
|
+ if hasattr(function_module, "file_handler"):
|
|
|
+ skip_files = function_module.file_handler
|
|
|
|
|
|
- filter_ids = [
|
|
|
- function.id for function in Functions.get_global_filter_functions()
|
|
|
- ]
|
|
|
- if "info" in model and "meta" in model["info"]:
|
|
|
- filter_ids.extend(model["info"]["meta"].get("filterIds", []))
|
|
|
- filter_ids = list(set(filter_ids))
|
|
|
-
|
|
|
- enabled_filter_ids = [
|
|
|
- function.id
|
|
|
- for function in Functions.get_functions_by_type(
|
|
|
- "filter", active_only=True
|
|
|
+ if hasattr(function_module, "valves") and hasattr(
|
|
|
+ function_module, "Valves"
|
|
|
+ ):
|
|
|
+ valves = Functions.get_function_valves_by_id(filter_id)
|
|
|
+ function_module.valves = function_module.Valves(
|
|
|
+ **(valves if valves else {})
|
|
|
)
|
|
|
- ]
|
|
|
- filter_ids = [
|
|
|
- filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
|
|
|
- ]
|
|
|
|
|
|
- filter_ids.sort(key=get_priority)
|
|
|
- for filter_id in filter_ids:
|
|
|
- filter = Functions.get_function_by_id(filter_id)
|
|
|
- if filter:
|
|
|
- if filter_id in webui_app.state.FUNCTIONS:
|
|
|
- function_module = webui_app.state.FUNCTIONS[filter_id]
|
|
|
+ try:
|
|
|
+ if hasattr(function_module, "inlet"):
|
|
|
+ inlet = function_module.inlet
|
|
|
+
|
|
|
+ # Get the signature of the function
|
|
|
+ sig = inspect.signature(inlet)
|
|
|
+ params = {"body": body}
|
|
|
+
|
|
|
+ if "__user__" in sig.parameters:
|
|
|
+ __user__ = {
|
|
|
+ "id": user.id,
|
|
|
+ "email": user.email,
|
|
|
+ "name": user.name,
|
|
|
+ "role": user.role,
|
|
|
+ }
|
|
|
+
|
|
|
+ try:
|
|
|
+ if hasattr(function_module, "UserValves"):
|
|
|
+ __user__["valves"] = function_module.UserValves(
|
|
|
+ **Functions.get_user_valves_by_id_and_user_id(
|
|
|
+ filter_id, user.id
|
|
|
+ )
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+
|
|
|
+ params = {**params, "__user__": __user__}
|
|
|
+
|
|
|
+ if "__id__" in sig.parameters:
|
|
|
+ params = {
|
|
|
+ **params,
|
|
|
+ "__id__": filter_id,
|
|
|
+ }
|
|
|
+
|
|
|
+ if "__model__" in sig.parameters:
|
|
|
+ params = {
|
|
|
+ **params,
|
|
|
+ "__model__": model,
|
|
|
+ }
|
|
|
+
|
|
|
+ if inspect.iscoroutinefunction(inlet):
|
|
|
+ body = await inlet(**params)
|
|
|
else:
|
|
|
- function_module, function_type, frontmatter = (
|
|
|
- load_function_module_by_id(filter_id)
|
|
|
- )
|
|
|
- webui_app.state.FUNCTIONS[filter_id] = function_module
|
|
|
-
|
|
|
- # Check if the function has a file_handler variable
|
|
|
- if hasattr(function_module, "file_handler"):
|
|
|
- skip_files = function_module.file_handler
|
|
|
-
|
|
|
- if hasattr(function_module, "valves") and hasattr(
|
|
|
- function_module, "Valves"
|
|
|
- ):
|
|
|
- valves = Functions.get_function_valves_by_id(filter_id)
|
|
|
- function_module.valves = function_module.Valves(
|
|
|
- **(valves if valves else {})
|
|
|
- )
|
|
|
+ body = inlet(**params)
|
|
|
|
|
|
- try:
|
|
|
- if hasattr(function_module, "inlet"):
|
|
|
- inlet = function_module.inlet
|
|
|
-
|
|
|
- # Get the signature of the function
|
|
|
- sig = inspect.signature(inlet)
|
|
|
- params = {"body": data}
|
|
|
-
|
|
|
- if "__user__" in sig.parameters:
|
|
|
- __user__ = {
|
|
|
- "id": user.id,
|
|
|
- "email": user.email,
|
|
|
- "name": user.name,
|
|
|
- "role": user.role,
|
|
|
- }
|
|
|
-
|
|
|
- try:
|
|
|
- if hasattr(function_module, "UserValves"):
|
|
|
- __user__["valves"] = function_module.UserValves(
|
|
|
- **Functions.get_user_valves_by_id_and_user_id(
|
|
|
- filter_id, user.id
|
|
|
- )
|
|
|
- )
|
|
|
- except Exception as e:
|
|
|
- print(e)
|
|
|
-
|
|
|
- params = {**params, "__user__": __user__}
|
|
|
-
|
|
|
- if "__id__" in sig.parameters:
|
|
|
- params = {
|
|
|
- **params,
|
|
|
- "__id__": filter_id,
|
|
|
- }
|
|
|
-
|
|
|
- if "__model__" in sig.parameters:
|
|
|
- params = {
|
|
|
- **params,
|
|
|
- "__model__": model,
|
|
|
- }
|
|
|
-
|
|
|
- if inspect.iscoroutinefunction(inlet):
|
|
|
- data = await inlet(**params)
|
|
|
- else:
|
|
|
- data = inlet(**params)
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"Error: {e}")
|
|
|
- return JSONResponse(
|
|
|
- status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
- content={"detail": str(e)},
|
|
|
- )
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ raise e
|
|
|
|
|
|
- # Set the task model
|
|
|
- task_model_id = data["model"]
|
|
|
- # Check if the user has a custom task model and use that model
|
|
|
- if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
|
|
|
- if (
|
|
|
- app.state.config.TASK_MODEL
|
|
|
- and app.state.config.TASK_MODEL in app.state.MODELS
|
|
|
- ):
|
|
|
- task_model_id = app.state.config.TASK_MODEL
|
|
|
- else:
|
|
|
- if (
|
|
|
- app.state.config.TASK_MODEL_EXTERNAL
|
|
|
- and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
|
|
|
- ):
|
|
|
- task_model_id = app.state.config.TASK_MODEL_EXTERNAL
|
|
|
-
|
|
|
- prompt = get_last_user_message(data["messages"])
|
|
|
- context = ""
|
|
|
-
|
|
|
- # If tool_ids field is present, call the functions
|
|
|
- if "tool_ids" in data:
|
|
|
- print(data["tool_ids"])
|
|
|
- for tool_id in data["tool_ids"]:
|
|
|
- print(tool_id)
|
|
|
- try:
|
|
|
- response, citation, file_handler = (
|
|
|
- await get_function_call_response(
|
|
|
- messages=data["messages"],
|
|
|
- files=data.get("files", []),
|
|
|
- tool_id=tool_id,
|
|
|
- template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
- task_model_id=task_model_id,
|
|
|
- user=user,
|
|
|
- )
|
|
|
- )
|
|
|
+ if skip_files:
|
|
|
+ if "files" in body:
|
|
|
+ del body["files"]
|
|
|
|
|
|
- print(file_handler)
|
|
|
- if isinstance(response, str):
|
|
|
- context += ("\n" if context != "" else "") + response
|
|
|
-
|
|
|
- if citation:
|
|
|
- citations.append(citation)
|
|
|
- show_citations = True
|
|
|
-
|
|
|
- if file_handler:
|
|
|
- skip_files = True
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"Error: {e}")
|
|
|
- del data["tool_ids"]
|
|
|
-
|
|
|
- print(f"tool_context: {context}")
|
|
|
-
|
|
|
- # If files field is present, generate RAG completions
|
|
|
- # If skip_files is True, skip the RAG completions
|
|
|
- if "files" in data:
|
|
|
- if not skip_files:
|
|
|
- data = {**data}
|
|
|
- rag_context, rag_citations = get_rag_context(
|
|
|
- files=data["files"],
|
|
|
- messages=data["messages"],
|
|
|
- embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
|
- k=rag_app.state.config.TOP_K,
|
|
|
- reranking_function=rag_app.state.sentence_transformer_rf,
|
|
|
- r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
|
- hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
|
- )
|
|
|
- if rag_context:
|
|
|
- context += ("\n" if context != "" else "") + rag_context
|
|
|
+ return body, {}
|
|
|
|
|
|
- log.debug(f"rag_context: {rag_context}, citations: {citations}")
|
|
|
|
|
|
- if rag_citations:
|
|
|
- citations.extend(rag_citations)
|
|
|
+async def chat_completion_tools_handler(body, model, user):
|
|
|
+ skip_files = None
|
|
|
|
|
|
- del data["files"]
|
|
|
+ contexts = []
|
|
|
+ citations = None
|
|
|
|
|
|
- if show_citations and len(citations) > 0:
|
|
|
- data_items.append({"citations": citations})
|
|
|
+ task_model_id = get_task_model_id(body["model"])
|
|
|
+
|
|
|
+ # If tool_ids field is present, call the functions
|
|
|
+ if "tool_ids" in body:
|
|
|
+ print(body["tool_ids"])
|
|
|
+ for tool_id in body["tool_ids"]:
|
|
|
+ print(tool_id)
|
|
|
+ try:
|
|
|
+ response, citation, file_handler = await get_function_call_response(
|
|
|
+ messages=body["messages"],
|
|
|
+ files=body.get("files", []),
|
|
|
+ tool_id=tool_id,
|
|
|
+ template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
|
|
|
+ task_model_id=task_model_id,
|
|
|
+ user=user,
|
|
|
+ model=model,
|
|
|
+ )
|
|
|
+
|
|
|
+ print(file_handler)
|
|
|
+ if isinstance(response, str):
|
|
|
+ contexts.append(response)
|
|
|
+
|
|
|
+ if citation:
|
|
|
+ if citations is None:
|
|
|
+ citations = [citation]
|
|
|
+ else:
|
|
|
+ citations.append(citation)
|
|
|
+
|
|
|
+ if file_handler:
|
|
|
+ skip_files = True
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error: {e}")
|
|
|
+ del body["tool_ids"]
|
|
|
+ print(f"tool_contexts: {contexts}")
|
|
|
+
|
|
|
+ if skip_files:
|
|
|
+ if "files" in body:
|
|
|
+ del body["files"]
|
|
|
+
|
|
|
+ return body, {
|
|
|
+ **({"contexts": contexts} if contexts is not None else {}),
|
|
|
+ **({"citations": citations} if citations is not None else {}),
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+async def chat_completion_files_handler(body):
|
|
|
+ contexts = []
|
|
|
+ citations = None
|
|
|
+
|
|
|
+ if "files" in body:
|
|
|
+ files = body["files"]
|
|
|
+ del body["files"]
|
|
|
+
|
|
|
+ contexts, citations = get_rag_context(
|
|
|
+ files=files,
|
|
|
+ messages=body["messages"],
|
|
|
+ embedding_function=rag_app.state.EMBEDDING_FUNCTION,
|
|
|
+ k=rag_app.state.config.TOP_K,
|
|
|
+ reranking_function=rag_app.state.sentence_transformer_rf,
|
|
|
+ r=rag_app.state.config.RELEVANCE_THRESHOLD,
|
|
|
+ hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
|
|
|
+ )
|
|
|
+
|
|
|
+ log.debug(f"rag_contexts: {contexts}, citations: {citations}")
|
|
|
+
|
|
|
+ return body, {
|
|
|
+ **({"contexts": contexts} if contexts is not None else {}),
|
|
|
+ **({"citations": citations} if citations is not None else {}),
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+async def get_body_and_model_and_user(request):
|
|
|
+ # Read the original request body
|
|
|
+ body = await request.body()
|
|
|
+ body_str = body.decode("utf-8")
|
|
|
+ body = json.loads(body_str) if body_str else {}
|
|
|
+
|
|
|
+ model_id = body["model"]
|
|
|
+ if model_id not in app.state.MODELS:
|
|
|
+ raise Exception("Model not found")
|
|
|
+ model = app.state.MODELS[model_id]
|
|
|
|
|
|
- if context != "":
|
|
|
- system_prompt = rag_template(
|
|
|
- rag_app.state.config.RAG_TEMPLATE, context, prompt
|
|
|
+ user = get_current_user(
|
|
|
+ request,
|
|
|
+ get_http_authorization_cred(request.headers.get("Authorization")),
|
|
|
+ )
|
|
|
+
|
|
|
+ return body, model, user
|
|
|
+
|
|
|
+
|
|
|
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
|
|
|
+ async def dispatch(self, request: Request, call_next):
|
|
|
+ if request.method == "POST" and any(
|
|
|
+ endpoint in request.url.path
|
|
|
+ for endpoint in ["/ollama/api/chat", "/chat/completions"]
|
|
|
+ ):
|
|
|
+ log.debug(f"request.url.path: {request.url.path}")
|
|
|
+
|
|
|
+ try:
|
|
|
+ body, model, user = await get_body_and_model_and_user(request)
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ content={"detail": str(e)},
|
|
|
+ )
|
|
|
+
|
|
|
+ # Extract chat_id and message_id from the request body
|
|
|
+ chat_id = None
|
|
|
+ if "chat_id" in body:
|
|
|
+ chat_id = body["chat_id"]
|
|
|
+ del body["chat_id"]
|
|
|
+ message_id = None
|
|
|
+ if "id" in body:
|
|
|
+ message_id = body["id"]
|
|
|
+ del body["id"]
|
|
|
+
|
|
|
+ # Initialize data_items to store additional data to be sent to the client
|
|
|
+ data_items = []
|
|
|
+
|
|
|
+ # Initialize context, and citations
|
|
|
+ contexts = []
|
|
|
+ citations = []
|
|
|
+
|
|
|
+ print(body)
|
|
|
+
|
|
|
+ try:
|
|
|
+ body, flags = await chat_completion_functions_handler(body, model, user)
|
|
|
+ except Exception as e:
|
|
|
+ return JSONResponse(
|
|
|
+ status_code=status.HTTP_400_BAD_REQUEST,
|
|
|
+ content={"detail": str(e)},
|
|
|
)
|
|
|
- print(system_prompt)
|
|
|
- data["messages"] = add_or_update_system_message(
|
|
|
- system_prompt, data["messages"]
|
|
|
+
|
|
|
+ try:
|
|
|
+ body, flags = await chat_completion_tools_handler(body, model, user)
|
|
|
+
|
|
|
+ contexts.extend(flags.get("contexts", []))
|
|
|
+ citations.extend(flags.get("citations", []))
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+ pass
|
|
|
+
|
|
|
+ try:
|
|
|
+ body, flags = await chat_completion_files_handler(body)
|
|
|
+
|
|
|
+ contexts.extend(flags.get("contexts", []))
|
|
|
+ citations.extend(flags.get("citations", []))
|
|
|
+ except Exception as e:
|
|
|
+ print(e)
|
|
|
+ pass
|
|
|
+
|
|
|
+ # If context is not empty, insert it into the messages
|
|
|
+ if len(contexts) > 0:
|
|
|
+ context_string = "\n".join(contexts).strip()
|
|
|
+ prompt = get_last_user_message(body["messages"])
|
|
|
+ body["messages"] = add_or_update_system_message(
|
|
|
+ rag_template(
|
|
|
+ rag_app.state.config.RAG_TEMPLATE, context_string, prompt
|
|
|
+ ),
|
|
|
+ body["messages"],
|
|
|
)
|
|
|
|
|
|
- modified_body_bytes = json.dumps(data).encode("utf-8")
|
|
|
+ # If there are citations, add them to the data_items
|
|
|
+ if len(citations) > 0:
|
|
|
+ data_items.append({"citations": citations})
|
|
|
+
|
|
|
+ modified_body_bytes = json.dumps(body).encode("utf-8")
|
|
|
# Replace the request body with the modified one
|
|
|
request._body = modified_body_bytes
|
|
|
# Set custom header to ensure content-length matches new body length
|
|
@@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
|
|
|
pass
|
|
|
|
|
|
if "pipeline" not in app.state.MODELS[model_id]:
|
|
|
- if "chat_id" in payload:
|
|
|
- del payload["chat_id"]
|
|
|
-
|
|
|
if "title" in payload:
|
|
|
del payload["title"]
|
|
|
|
|
@@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
|
|
|
content={"detail": e.args[1]},
|
|
|
)
|
|
|
|
|
|
+ if "chat_id" in payload:
|
|
|
+ del payload["chat_id"]
|
|
|
+
|
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
|
|
|
|
@@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
|
|
|
content={"detail": e.args[1]},
|
|
|
)
|
|
|
|
|
|
+ if "chat_id" in payload:
|
|
|
+ del payload["chat_id"]
|
|
|
+
|
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
|
|
|
|
@@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
|
|
|
content={"detail": e.args[1]},
|
|
|
)
|
|
|
|
|
|
+ if "chat_id" in payload:
|
|
|
+ del payload["chat_id"]
|
|
|
+
|
|
|
return await generate_chat_completions(form_data=payload, user=user)
|
|
|
|
|
|
|