```diff
@@ -680,26 +680,26 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
         ]
 
         response = await call_next(request)
-        if isinstance(response, StreamingResponse):
-            content_type = response.headers["Content-Type"]
-            is_openai = "text/event-stream" in content_type
-            is_ollama = "application/x-ndjson" in content_type
-            if not is_openai and not is_ollama:
-                return response
+        if not isinstance(response, StreamingResponse):
+            return response
 
-            def wrap_item(item):
-                return f"data: {item}\n\n" if is_openai else f"{item}\n"
+        content_type = response.headers["Content-Type"]
+        is_openai = "text/event-stream" in content_type
+        is_ollama = "application/x-ndjson" in content_type
+        if not is_openai and not is_ollama:
+            return response
 
-            async def stream_wrapper(original_generator, data_items):
-                for item in data_items:
-                    yield wrap_item(json.dumps(item))
+        def wrap_item(item):
+            return f"data: {item}\n\n" if is_openai else f"{item}\n"
 
-            async for data in original_generator:
-                yield data
+        async def stream_wrapper(original_generator, data_items):
+            for item in data_items:
+                yield wrap_item(json.dumps(item))
 
-            return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
+            async for data in original_generator:
+                yield data
 
-        return response
+        return StreamingResponse(stream_wrapper(response.body_iterator, data_items))
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
```
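The refactor replaces the nested `isinstance` branch with an early-return guard, so the wrapping logic reads at a single indentation level. For reference, here is a minimal, self-contained sketch of the prepend-then-passthrough pattern the hunk rearranges. The `fake_upstream` and `demo` names are illustrative stand-ins, not part of the change; the wrapper mirrors `stream_wrapper` in the diff, emitting injected JSON items in SSE framing first and then forwarding the original stream untouched.

```python
import asyncio
import json


async def fake_upstream():
    # Stands in for response.body_iterator: an async stream of SSE chunks.
    yield 'data: {"delta": "hello"}\n\n'
    yield "data: [DONE]\n\n"


async def stream_wrapper(original_generator, data_items):
    # Emit the injected items first, SSE-framed, ...
    for item in data_items:
        yield f"data: {json.dumps(item)}\n\n"
    # ... then pass the original stream through unchanged.
    async for data in original_generator:
        yield data


async def demo():
    async for chunk in stream_wrapper(fake_upstream(), [{"injected": True}]):
        print(repr(chunk))


asyncio.run(demo())
```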