浏览代码

refac: chat completion middleware

Timothy J. Baek 10 月之前
父节点
当前提交
c7a9b5ccfa
共有 3 个文件被更改,包括 304 次插入223 次删除
  1. 7 7
      backend/apps/rag/utils.py
  2. 291 210
      backend/main.py
  3. 6 6
      src/lib/components/chat/Chat.svelte

+ 7 - 7
backend/apps/rag/utils.py

@@ -294,14 +294,16 @@ def get_rag_context(
 
         extracted_collections.extend(collection_names)
 
-    context_string = ""
-
+    contexts = []
     citations = []
+
     for context in relevant_contexts:
         try:
             if "documents" in context:
-                context_string += "\n\n".join(
-                    [text for text in context["documents"][0] if text is not None]
+                contexts.append(
+                    "\n\n".join(
+                        [text for text in context["documents"][0] if text is not None]
+                    )
                 )
 
                 if "metadatas" in context:
@@ -315,9 +317,7 @@ def get_rag_context(
         except Exception as e:
             log.exception(e)
 
-    context_string = context_string.strip()
-
-    return context_string, citations
+    return contexts, citations
 
 
 def get_model_path(model: str, update_model: bool = False):

+ 291 - 210
backend/main.py

@@ -213,7 +213,7 @@ origins = ["*"]
 
 
 async def get_function_call_response(
-    messages, files, tool_id, template, task_model_id, user
+    messages, files, tool_id, template, task_model_id, user, model
 ):
     tool = Tools.get_tool_by_id(tool_id)
     tools_specs = json.dumps(tool.specs, indent=2)
@@ -373,233 +373,308 @@ async def get_function_call_response(
     return None, None, False
 
 
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        data_items = []
+def get_task_model_id(default_model_id):
+    # Set the task model
+    task_model_id = default_model_id
+    # Check if the user has a custom task model and use that model
+    if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+        if (
+            app.state.config.TASK_MODEL
+            and app.state.config.TASK_MODEL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL
+    else:
+        if (
+            app.state.config.TASK_MODEL_EXTERNAL
+            and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+        ):
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
-        show_citations = False
-        citations = []
+    return task_model_id
 
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
 
-            # Read the original request body
-            body = await request.body()
-            body_str = body.decode("utf-8")
-            data = json.loads(body_str) if body_str else {}
+def get_filter_function_ids(model):
+    def get_priority(function_id):
+        function = Functions.get_function_by_id(function_id)
+        if function is not None and hasattr(function, "valves"):
+            return (function.valves if function.valves else {}).get("priority", 0)
+        return 0
 
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
-            # Flag to skip RAG completions if file_handler is present in tools/functions
-            skip_files = False
-            if data.get("citations"):
-                show_citations = True
-                del data["citations"]
-
-            model_id = data["model"]
-            if model_id not in app.state.MODELS:
-                raise HTTPException(
-                    status_code=status.HTTP_404_NOT_FOUND,
-                    detail="Model not found",
+    filter_ids = [function.id for function in Functions.get_global_filter_functions()]
+    if "info" in model and "meta" in model["info"]:
+        filter_ids.extend(model["info"]["meta"].get("filterIds", []))
+        filter_ids = list(set(filter_ids))
+
+    enabled_filter_ids = [
+        function.id
+        for function in Functions.get_functions_by_type("filter", active_only=True)
+    ]
+
+    filter_ids = [
+        filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
+    ]
+
+    filter_ids.sort(key=get_priority)
+    return filter_ids
+
+
+async def chat_completion_functions_handler(body, model, user):
+    skip_files = None
+
+    filter_ids = get_filter_function_ids(model)
+    for filter_id in filter_ids:
+        filter = Functions.get_function_by_id(filter_id)
+        if filter:
+            if filter_id in webui_app.state.FUNCTIONS:
+                function_module = webui_app.state.FUNCTIONS[filter_id]
+            else:
+                function_module, function_type, frontmatter = (
+                    load_function_module_by_id(filter_id)
                 )
-            model = app.state.MODELS[model_id]
+                webui_app.state.FUNCTIONS[filter_id] = function_module
 
-            def get_priority(function_id):
-                function = Functions.get_function_by_id(function_id)
-                if function is not None and hasattr(function, "valves"):
-                    return (function.valves if function.valves else {}).get(
-                        "priority", 0
-                    )
-                return 0
+            # Check if the function has a file_handler variable
+            if hasattr(function_module, "file_handler"):
+                skip_files = function_module.file_handler
 
-            filter_ids = [
-                function.id for function in Functions.get_global_filter_functions()
-            ]
-            if "info" in model and "meta" in model["info"]:
-                filter_ids.extend(model["info"]["meta"].get("filterIds", []))
-                filter_ids = list(set(filter_ids))
-
-            enabled_filter_ids = [
-                function.id
-                for function in Functions.get_functions_by_type(
-                    "filter", active_only=True
+            if hasattr(function_module, "valves") and hasattr(
+                function_module, "Valves"
+            ):
+                valves = Functions.get_function_valves_by_id(filter_id)
+                function_module.valves = function_module.Valves(
+                    **(valves if valves else {})
                 )
-            ]
-            filter_ids = [
-                filter_id for filter_id in filter_ids if filter_id in enabled_filter_ids
-            ]
 
-            filter_ids.sort(key=get_priority)
-            for filter_id in filter_ids:
-                filter = Functions.get_function_by_id(filter_id)
-                if filter:
-                    if filter_id in webui_app.state.FUNCTIONS:
-                        function_module = webui_app.state.FUNCTIONS[filter_id]
+            try:
+                if hasattr(function_module, "inlet"):
+                    inlet = function_module.inlet
+
+                    # Get the signature of the function
+                    sig = inspect.signature(inlet)
+                    params = {"body": body}
+
+                    if "__user__" in sig.parameters:
+                        __user__ = {
+                            "id": user.id,
+                            "email": user.email,
+                            "name": user.name,
+                            "role": user.role,
+                        }
+
+                        try:
+                            if hasattr(function_module, "UserValves"):
+                                __user__["valves"] = function_module.UserValves(
+                                    **Functions.get_user_valves_by_id_and_user_id(
+                                        filter_id, user.id
+                                    )
+                                )
+                        except Exception as e:
+                            print(e)
+
+                        params = {**params, "__user__": __user__}
+
+                    if "__id__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__id__": filter_id,
+                        }
+
+                    if "__model__" in sig.parameters:
+                        params = {
+                            **params,
+                            "__model__": model,
+                        }
+
+                    if inspect.iscoroutinefunction(inlet):
+                        body = await inlet(**params)
                     else:
-                        function_module, function_type, frontmatter = (
-                            load_function_module_by_id(filter_id)
-                        )
-                        webui_app.state.FUNCTIONS[filter_id] = function_module
-
-                    # Check if the function has a file_handler variable
-                    if hasattr(function_module, "file_handler"):
-                        skip_files = function_module.file_handler
-
-                    if hasattr(function_module, "valves") and hasattr(
-                        function_module, "Valves"
-                    ):
-                        valves = Functions.get_function_valves_by_id(filter_id)
-                        function_module.valves = function_module.Valves(
-                            **(valves if valves else {})
-                        )
+                        body = inlet(**params)
 
-                    try:
-                        if hasattr(function_module, "inlet"):
-                            inlet = function_module.inlet
-
-                            # Get the signature of the function
-                            sig = inspect.signature(inlet)
-                            params = {"body": data}
-
-                            if "__user__" in sig.parameters:
-                                __user__ = {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                }
-
-                                try:
-                                    if hasattr(function_module, "UserValves"):
-                                        __user__["valves"] = function_module.UserValves(
-                                            **Functions.get_user_valves_by_id_and_user_id(
-                                                filter_id, user.id
-                                            )
-                                        )
-                                except Exception as e:
-                                    print(e)
-
-                                params = {**params, "__user__": __user__}
-
-                            if "__id__" in sig.parameters:
-                                params = {
-                                    **params,
-                                    "__id__": filter_id,
-                                }
-
-                            if "__model__" in sig.parameters:
-                                params = {
-                                    **params,
-                                    "__model__": model,
-                                }
-
-                            if inspect.iscoroutinefunction(inlet):
-                                data = await inlet(**params)
-                            else:
-                                data = inlet(**params)
-
-                    except Exception as e:
-                        print(f"Error: {e}")
-                        return JSONResponse(
-                            status_code=status.HTTP_400_BAD_REQUEST,
-                            content={"detail": str(e)},
-                        )
+            except Exception as e:
+                print(f"Error: {e}")
+                raise e
 
-            # Set the task model
-            task_model_id = data["model"]
-            # Check if the user has a custom task model and use that model
-            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
-                if (
-                    app.state.config.TASK_MODEL
-                    and app.state.config.TASK_MODEL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL
-            else:
-                if (
-                    app.state.config.TASK_MODEL_EXTERNAL
-                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
-                ):
-                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
-
-            prompt = get_last_user_message(data["messages"])
-            context = ""
-
-            # If tool_ids field is present, call the functions
-            if "tool_ids" in data:
-                print(data["tool_ids"])
-                for tool_id in data["tool_ids"]:
-                    print(tool_id)
-                    try:
-                        response, citation, file_handler = (
-                            await get_function_call_response(
-                                messages=data["messages"],
-                                files=data.get("files", []),
-                                tool_id=tool_id,
-                                template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
-                                task_model_id=task_model_id,
-                                user=user,
-                            )
-                        )
+    if skip_files:
+        if "files" in body:
+            del body["files"]
 
-                        print(file_handler)
-                        if isinstance(response, str):
-                            context += ("\n" if context != "" else "") + response
-
-                        if citation:
-                            citations.append(citation)
-                            show_citations = True
-
-                        if file_handler:
-                            skip_files = True
-
-                    except Exception as e:
-                        print(f"Error: {e}")
-                del data["tool_ids"]
-
-                print(f"tool_context: {context}")
-
-            # If files field is present, generate RAG completions
-            # If skip_files is True, skip the RAG completions
-            if "files" in data:
-                if not skip_files:
-                    data = {**data}
-                    rag_context, rag_citations = get_rag_context(
-                        files=data["files"],
-                        messages=data["messages"],
-                        embedding_function=rag_app.state.EMBEDDING_FUNCTION,
-                        k=rag_app.state.config.TOP_K,
-                        reranking_function=rag_app.state.sentence_transformer_rf,
-                        r=rag_app.state.config.RELEVANCE_THRESHOLD,
-                        hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
-                    )
-                    if rag_context:
-                        context += ("\n" if context != "" else "") + rag_context
+    return body, {}
 
-                    log.debug(f"rag_context: {rag_context}, citations: {citations}")
 
-                    if rag_citations:
-                        citations.extend(rag_citations)
+async def chat_completion_tools_handler(body, model, user):
+    skip_files = None
 
-                del data["files"]
+    contexts = []
+    citations = None
 
-            if show_citations and len(citations) > 0:
-                data_items.append({"citations": citations})
+    task_model_id = get_task_model_id(body["model"])
+
+    # If tool_ids field is present, call the functions
+    if "tool_ids" in body:
+        print(body["tool_ids"])
+        for tool_id in body["tool_ids"]:
+            print(tool_id)
+            try:
+                response, citation, file_handler = await get_function_call_response(
+                    messages=body["messages"],
+                    files=body.get("files", []),
+                    tool_id=tool_id,
+                    template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                    task_model_id=task_model_id,
+                    user=user,
+                    model=model,
+                )
+
+                print(file_handler)
+                if isinstance(response, str):
+                    contexts.append(response)
+
+                if citation:
+                    if citations is None:
+                        citations = [citation]
+                    else:
+                        citations.append(citation)
+
+                if file_handler:
+                    skip_files = True
+
+            except Exception as e:
+                print(f"Error: {e}")
+        del body["tool_ids"]
+        print(f"tool_contexts: {contexts}")
+
+    if skip_files:
+        if "files" in body:
+            del body["files"]
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
+
+
+async def chat_completion_files_handler(body):
+    contexts = []
+    citations = None
+
+    if "files" in body:
+        files = body["files"]
+        del body["files"]
+
+        contexts, citations = get_rag_context(
+            files=files,
+            messages=body["messages"],
+            embedding_function=rag_app.state.EMBEDDING_FUNCTION,
+            k=rag_app.state.config.TOP_K,
+            reranking_function=rag_app.state.sentence_transformer_rf,
+            r=rag_app.state.config.RELEVANCE_THRESHOLD,
+            hybrid_search=rag_app.state.config.ENABLE_RAG_HYBRID_SEARCH,
+        )
+
+        log.debug(f"rag_contexts: {contexts}, citations: {citations}")
+
+    return body, {
+        **({"contexts": contexts} if contexts is not None else {}),
+        **({"citations": citations} if citations is not None else {}),
+    }
+
+
+async def get_body_and_model_and_user(request):
+    # Read the original request body
+    body = await request.body()
+    body_str = body.decode("utf-8")
+    body = json.loads(body_str) if body_str else {}
+
+    model_id = body["model"]
+    if model_id not in app.state.MODELS:
+        raise "Model not found"
+    model = app.state.MODELS[model_id]
 
-            if context != "":
-                system_prompt = rag_template(
-                    rag_app.state.config.RAG_TEMPLATE, context, prompt
+    user = get_current_user(
+        request,
+        get_http_authorization_cred(request.headers.get("Authorization")),
+    )
+
+    return body, model, user
+
+
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if request.method == "POST" and any(
+            endpoint in request.url.path
+            for endpoint in ["/ollama/api/chat", "/chat/completions"]
+        ):
+            log.debug(f"request.url.path: {request.url.path}")
+
+            try:
+                body, model, user = await get_body_and_model_and_user(request)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
+                )
+
+            # Extract chat_id and message_id from the request body
+            chat_id = None
+            if "chat_id" in body:
+                chat_id = body["chat_id"]
+                del body["chat_id"]
+            message_id = None
+            if "id" in body:
+                message_id = body["id"]
+                del body["id"]
+
+            # Initialize data_items to store additional data to be sent to the client
+            data_items = []
+
+            # Initialize context, and citations
+            contexts = []
+            citations = []
+
+            print(body)
+
+            try:
+                body, flags = await chat_completion_functions_handler(body, model, user)
+            except Exception as e:
+                return JSONResponse(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    content={"detail": str(e)},
                 )
-                print(system_prompt)
-                data["messages"] = add_or_update_system_message(
-                    system_prompt, data["messages"]
+
+            try:
+                body, flags = await chat_completion_tools_handler(body, model, user)
+
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
+
+            try:
+                body, flags = await chat_completion_files_handler(body)
+
+                contexts.extend(flags.get("contexts", []))
+                citations.extend(flags.get("citations", []))
+            except Exception as e:
+                print(e)
+                pass
+
+            # If context is not empty, insert it into the messages
+            if len(contexts) > 0:
+                context_string = "/n".join(contexts).strip()
+                prompt = get_last_user_message(body["messages"])
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
                 )
 
-            modified_body_bytes = json.dumps(data).encode("utf-8")
+            # If there are citations, add them to the data_items
+            if len(citations) > 0:
+                data_items.append({"citations": citations})
+
+            modified_body_bytes = json.dumps(body).encode("utf-8")
             # Replace the request body with the modified one
             request._body = modified_body_bytes
             # Set custom header to ensure content-length matches new body length
@@ -721,9 +796,6 @@ def filter_pipeline(payload, user):
                 pass
 
     if "pipeline" not in app.state.MODELS[model_id]:
-        if "chat_id" in payload:
-            del payload["chat_id"]
-
         if "title" in payload:
             del payload["title"]
 
@@ -1225,6 +1297,9 @@ async def generate_title(form_data: dict, user=Depends(get_verified_user)):
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
 
@@ -1285,6 +1360,9 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
 
@@ -1349,6 +1427,9 @@ Message: """{{prompt}}"""
             content={"detail": e.args[1]},
         )
 
+    if "chat_id" in payload:
+        del payload["chat_id"]
+
     return await generate_chat_completions(form_data=payload, user=user)
 
 

+ 6 - 6
src/lib/components/chat/Chat.svelte

@@ -665,6 +665,7 @@
 		await tick();
 
 		const [res, controller] = await generateChatCompletion(localStorage.token, {
+			stream: true,
 			model: model.id,
 			messages: messagesBody,
 			options: {
@@ -682,8 +683,8 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0 ? true : undefined,
-			chat_id: $chatId
+			chat_id: $chatId,
+			id: responseMessageId
 		});
 
 		if (res && res.ok) {
@@ -912,8 +913,8 @@
 			const [res, controller] = await generateOpenAIChatCompletion(
 				localStorage.token,
 				{
-					model: model.id,
 					stream: true,
+					model: model.id,
 					stream_options:
 						model.info?.meta?.capabilities?.usage ?? false
 							? {
@@ -983,9 +984,8 @@
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
-					citations: files.length > 0 ? true : undefined,
-
-					chat_id: $chatId
+					chat_id: $chatId,
+					id: responseMessageId
 				},
 				`${WEBUI_BASE_URL}/api`
 			);