Browse Source

refac: chat completion middleware

Timothy J. Baek 10 months ago
parent
commit
6b8a7b9939
1 changed files with 20 additions and 22 deletions
  1. 20 22
      backend/main.py

+ 20 - 22
backend/main.py

@@ -316,7 +316,7 @@ async def get_function_call_response(
 
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
     async def dispatch(self, request: Request, call_next):
-        return_citations = False
+        data_items = []
 
 
         if request.method == "POST" and (
         if request.method == "POST" and (
             "/ollama/api/chat" in request.url.path
             "/ollama/api/chat" in request.url.path
@@ -326,23 +326,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
             # Read the original request body
             # Read the original request body
             body = await request.body()
             body = await request.body()
-            # Decode body to string
             body_str = body.decode("utf-8")
             body_str = body.decode("utf-8")
-            # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
             data = json.loads(body_str) if body_str else {}
 
 
+            model_id = data["model"]
             user = get_current_user(
             user = get_current_user(
                 request,
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
                 get_http_authorization_cred(request.headers.get("Authorization")),
             )
             )
 
 
-            # Remove the citations from the body
-            return_citations = data.get("citations", False)
-            if "citations" in data:
-                del data["citations"]
-
             # Set the task model
             # Set the task model
-            task_model_id = data["model"]
+            task_model_id = model_id
             if task_model_id not in app.state.MODELS:
             if task_model_id not in app.state.MODELS:
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     status_code=status.HTTP_404_NOT_FOUND,
@@ -364,12 +358,11 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ):
                 ):
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
                     task_model_id = app.state.config.TASK_MODEL_EXTERNAL
 
 
+            skip_files = False
             prompt = get_last_user_message(data["messages"])
             prompt = get_last_user_message(data["messages"])
             context = ""
             context = ""
 
 
             # If tool_ids field is present, call the functions
             # If tool_ids field is present, call the functions
-
-            skip_files = False
             if "tool_ids" in data:
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
                 for tool_id in data["tool_ids"]:
@@ -415,8 +408,9 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                         context += ("\n" if context != "" else "") + rag_context
                         context += ("\n" if context != "" else "") + rag_context
 
 
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
-                else:
-                    return_citations = False
+
+                    if citations:
+                        data_items.append({"citations": citations})
 
 
                 del data["files"]
                 del data["files"]
 
 
@@ -426,7 +420,7 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 )
                 )
                 print(system_prompt)
                 print(system_prompt)
                 data["messages"] = add_or_update_system_message(
                 data["messages"] = add_or_update_system_message(
-                    f"\n{system_prompt}", data["messages"]
+                    system_prompt, data["messages"]
                 )
                 )
 
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
             modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -444,18 +438,18 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
 
 
         response = await call_next(request)
         response = await call_next(request)
 
 
-        if return_citations:
-            # Inject the citations into the response
+        # If there are data_items to inject into the response
+        if len(data_items) > 0:
             if isinstance(response, StreamingResponse):
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
                 content_type = response.headers.get("Content-Type")
                 if "text/event-stream" in content_type:
                 if "text/event-stream" in content_type:
                     return StreamingResponse(
                     return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        self.openai_stream_wrapper(response.body_iterator, data_items),
                     )
                     )
                 if "application/x-ndjson" in content_type:
                 if "application/x-ndjson" in content_type:
                     return StreamingResponse(
                     return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
                     )
 
 
         return response
         return response
@@ -463,13 +457,17 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def _receive(self, body: bytes):
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
         return {"type": "http.request", "body": body, "more_body": False}
 
 
-    async def openai_stream_wrapper(self, original_generator, citations):
-        yield f"data: {json.dumps({'citations': citations})}\n\n"
+    async def openai_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"data: {json.dumps(item)}\n\n"
+
         async for data in original_generator:
         async for data in original_generator:
             yield data
             yield data
 
 
-    async def ollama_stream_wrapper(self, original_generator, citations):
-        yield f"{json.dumps({'citations': citations})}\n"
+    async def ollama_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"{json.dumps(item)}\n"
+
         async for data in original_generator:
         async for data in original_generator:
             yield data
             yield data