
is_chat_completion_request helper, remove nesting

Michael Poluektov, 8 months ago
Parent commit: 589efcdc5f
1 changed file with 136 additions and 144 deletions

backend/main.py (+136, -144)

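The change itself is a guard-clause refactor: the POST-plus-path check that both middlewares previously duplicated inline becomes a named predicate, and dispatch now returns early for non-chat requests instead of wrapping its entire body in an if. A quick spot-check of the extracted helper; the SimpleNamespace stand-in for Starlette's Request is scaffolding for illustration, not project code:

from types import SimpleNamespace


def is_chat_completion_request(request) -> bool:
    # Same predicate the commit extracts: a POST to an Ollama- or
    # OpenAI-style chat endpoint.
    return request.method == "POST" and any(
        endpoint in request.url.path
        for endpoint in ["/ollama/api/chat", "/chat/completions"]
    )


def fake_request(method: str, path: str):
    # Minimal stand-in: the predicate only reads .method and .url.path.
    return SimpleNamespace(method=method, url=SimpleNamespace(path=path))


assert is_chat_completion_request(fake_request("POST", "/ollama/api/chat"))
assert is_chat_completion_request(fake_request("POST", "/openai/chat/completions"))
assert not is_chat_completion_request(fake_request("GET", "/chat/completions"))
assert not is_chat_completion_request(fake_request("POST", "/api/v1/models"))

Because the helper touches nothing beyond request.method and request.url.path, any object exposing those attributes is enough to exercise it.
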
@@ -605,129 +605,126 @@ async def chat_completion_files_handler(body):
     }
 
 
-class ChatCompletionMiddleware(BaseHTTPMiddleware):
-    async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and any(
-            endpoint in request.url.path
-            for endpoint in ["/ollama/api/chat", "/chat/completions"]
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
+def is_chat_completion_request(request):
+    return request.method == "POST" and any(
+        endpoint in request.url.path
+        for endpoint in ["/ollama/api/chat", "/chat/completions"]
+    )
 
-            try:
-                body, model, user = await get_body_and_model_and_user(request)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
 
-            metadata = {
-                "chat_id": body.pop("chat_id", None),
-                "message_id": body.pop("id", None),
-                "session_id": body.pop("session_id", None),
-                "valves": body.pop("valves", None),
-            }
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
+    async def dispatch(self, request: Request, call_next):
+        if not is_chat_completion_request(request):
+            return await call_next(request)
+        log.debug(f"request.url.path: {request.url.path}")
 
-            __event_emitter__ = get_event_emitter(metadata)
-            __event_call__ = get_event_call(metadata)
+        try:
+            body, model, user = await get_body_and_model_and_user(request)
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
-            # Initialize data_items to store additional data to be sent to the client
-            data_items = []
+        metadata = {
+            "chat_id": body.pop("chat_id", None),
+            "message_id": body.pop("id", None),
+            "session_id": body.pop("session_id", None),
+            "valves": body.pop("valves", None),
+        }
 
-            # Initialize context, and citations
-            contexts = []
-            citations = []
+        __event_emitter__ = get_event_emitter(metadata)
+        __event_call__ = get_event_call(metadata)
 
-            try:
-                body, flags = await chat_completion_functions_handler(
-                    body, model, user, __event_emitter__, __event_call__
-                )
-            except Exception as e:
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
+        # Initialize data_items to store additional data to be sent to the client
+        data_items = []
 
-            try:
-                body, flags = await chat_completion_tools_handler(
-                    body, user, __event_emitter__, __event_call__
-                )
+        # Initialize context, and citations
+        contexts = []
+        citations = []
 
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+        try:
+            body, flags = await chat_completion_functions_handler(
+                body, model, user, __event_emitter__, __event_call__
+            )
+        except Exception as e:
+            return JSONResponse(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                content={"detail": str(e)},
+            )
 
-            try:
-                body, flags = await chat_completion_files_handler(body)
+        try:
+            body, flags = await chat_completion_tools_handler(
+                body, user, __event_emitter__, __event_call__
+            )
 
-                contexts.extend(flags.get("contexts", []))
-                citations.extend(flags.get("citations", []))
-            except Exception as e:
-                print(e)
-                pass
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            print(e)
+            pass
 
-            # If context is not empty, insert it into the messages
-            if len(contexts) > 0:
-                context_string = "\n".join(contexts).strip()
-                prompt = get_last_user_message(body["messages"])
-
-                # Workaround for Ollama 2.0+ system prompt issue
-                # TODO: replace with add_or_update_system_message
-                if model["owned_by"] == "ollama":
-                    body["messages"] = prepend_to_first_user_message_content(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-                else:
-                    body["messages"] = add_or_update_system_message(
-                        rag_template(
-                            rag_app.state.config.RAG_TEMPLATE, context_string, prompt
-                        ),
-                        body["messages"],
-                    )
-
-            # If there are citations, add them to the data_items
-            if len(citations) > 0:
-                data_items.append({"citations": citations})
-
-            body["metadata"] = metadata
-            modified_body_bytes = json.dumps(body).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        try:
+            body, flags = await chat_completion_files_handler(body)
 
-            response = await call_next(request)
-            if isinstance(response, StreamingResponse):
-                # If it's a streaming response, inject it as SSE event or NDJSON line
-                content_type = response.headers.get("Content-Type")
-                if "text/event-stream" in content_type:
-                    return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, data_items),
-                    )
-                if "application/x-ndjson" in content_type:
-                    return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, data_items),
-                    )
-
-                return response
+            contexts.extend(flags.get("contexts", []))
+            citations.extend(flags.get("citations", []))
+        except Exception as e:
+            print(e)
+            pass
+
+        # If context is not empty, insert it into the messages
+        if len(contexts) > 0:
+            context_string = "\n".join(contexts).strip()
+            prompt = get_last_user_message(body["messages"])
+
+            # Workaround for Ollama 2.0+ system prompt issue
+            # TODO: replace with add_or_update_system_message
+            if model["owned_by"] == "ollama":
+                body["messages"] = prepend_to_first_user_message_content(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
             else:
-                return response
+                body["messages"] = add_or_update_system_message(
+                    rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context_string, prompt
+                    ),
+                    body["messages"],
+                )
+
+        # If there are citations, add them to the data_items
+        if len(citations) > 0:
+            data_items.append({"citations": citations})
+
+        body["metadata"] = metadata
+        modified_body_bytes = json.dumps(body).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
-        # If it's not a chat completion request, just pass it through
         response = await call_next(request)
-        return response
+        if isinstance(response, StreamingResponse):
+            # If it's a streaming response, inject it as SSE event or NDJSON line
+            content_type = response.headers.get("Content-Type")
+            if "text/event-stream" in content_type:
+                return StreamingResponse(
+                    self.openai_stream_wrapper(response.body_iterator, data_items),
+                )
+            if "application/x-ndjson" in content_type:
+                return StreamingResponse(
+                    self.ollama_stream_wrapper(response.body_iterator, data_items),
+                )
+
+            return response
+        else:
+            return response
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
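
Both middlewares finish with the same body-swap trick: overwrite the cached request._body and splice a fresh content-length into the header list. This relies on a Starlette implementation detail; in recent versions (roughly 0.28+), BaseHTTPMiddleware replays a cached _body to the downstream app, while on versions without that behavior you would install a custom receive callable instead, which is what the _receive helper above is shaped for. A self-contained sketch of the pattern, under that Starlette assumption and with illustrative names throughout:

import json

from starlette.applications import Starlette
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import Route


class BodyRewriteMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        if request.method != "POST":
            return await call_next(request)

        data = json.loads(await request.body() or b"{}")
        data["metadata"] = {"injected": True}

        modified = json.dumps(data).encode("utf-8")
        # request.body() above cached the original bytes on request._body;
        # overwriting the cache changes what gets replayed downstream.
        request._body = modified
        # Rebuild this request's header list so content-length matches
        # the new body (the same maneuver as in the diff above).
        request.headers.__dict__["_list"] = [
            (b"content-length", str(len(modified)).encode("utf-8")),
            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
        ]
        return await call_next(request)


async def echo(request: Request):
    # With the middleware in place, this returns the *modified* body.
    return JSONResponse(await request.json())


app = Starlette(routes=[Route("/chat/completions", echo, methods=["POST"])])
app.add_middleware(BodyRewriteMiddleware)

Run it with uvicorn and POST any JSON to /chat/completions; the echoed payload carries the injected metadata key, showing the downstream handler saw the rewritten body.
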
@@ -820,44 +817,39 @@ def filter_pipeline(payload, user):
 
 class PipelineMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
-        ):
-            log.debug(f"request.url.path: {request.url.path}")
-
-            # Read the original request body
-            body = await request.body()
-            # Decode body to string
-            body_str = body.decode("utf-8")
-            # Parse string to JSON
-            data = json.loads(body_str) if body_str else {}
-
-            user = get_current_user(
-                request,
-                get_http_authorization_cred(request.headers.get("Authorization")),
-            )
+        if not is_chat_completion_request(request):
+            return await call_next(request)
 
-            try:
-                data = filter_pipeline(data, user)
-            except Exception as e:
-                return JSONResponse(
-                    status_code=e.args[0],
-                    content={"detail": e.args[1]},
-                )
+        log.debug(f"request.url.path: {request.url.path}")
 
-            modified_body_bytes = json.dumps(data).encode("utf-8")
-            # Replace the request body with the modified one
-            request._body = modified_body_bytes
-            # Set custom header to ensure content-length matches new body length
-            request.headers.__dict__["_list"] = [
-                (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
-                *[
-                    (k, v)
-                    for k, v in request.headers.raw
-                    if k.lower() != b"content-length"
-                ],
-            ]
+        # Read the original request body
+        body = await request.body()
+        # Decode body to string
+        body_str = body.decode("utf-8")
+        # Parse string to JSON
+        data = json.loads(body_str) if body_str else {}
+
+        user = get_current_user(
+            request,
+            get_http_authorization_cred(request.headers.get("Authorization")),
+        )
+
+        try:
+            data = filter_pipeline(data, user)
+        except Exception as e:
+            return JSONResponse(
+                status_code=e.args[0],
+                content={"detail": e.args[1]},
+            )
+
+        modified_body_bytes = json.dumps(data).encode("utf-8")
+        # Replace the request body with the modified one
+        request._body = modified_body_bytes
+        # Set custom header to ensure content-length matches new body length
+        request.headers.__dict__["_list"] = [
+            (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
+            *[(k, v) for k, v in request.headers.raw if k.lower() != b"content-length"],
+        ]
 
         response = await call_next(request)
         return response
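
For context, the openai_stream_wrapper / ollama_stream_wrapper methods that ChatCompletionMiddleware.dispatch hands the data_items to sit outside this hunk. A plausible shape for them, written as free functions here and offered as a hedged sketch rather than the repository's actual implementation: emit each collected item (the citations, in this case) in the stream's own framing, then relay the upstream chunks untouched.

import json
from typing import AsyncIterator


async def openai_stream_wrapper(
    original: AsyncIterator[bytes], data_items: list[dict]
) -> AsyncIterator[bytes]:
    # SSE framing: each injected item becomes its own "data: ..." event,
    # flushed before any model output is relayed.
    for item in data_items:
        yield f"data: {json.dumps(item)}\n\n".encode("utf-8")
    async for chunk in original:
        yield chunk


async def ollama_stream_wrapper(
    original: AsyncIterator[bytes], data_items: list[dict]
) -> AsyncIterator[bytes]:
    # NDJSON framing: one JSON object per line.
    for item in data_items:
        yield (json.dumps(item) + "\n").encode("utf-8")
    async for chunk in original:
        yield chunk

Injecting the extra items up front means a client can render citations while model tokens are still streaming, which matches the order in which dispatch collects them.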