|
@@ -54,7 +54,7 @@ import uuid
|
|
import time
|
|
import time
|
|
import json
|
|
import json
|
|
|
|
|
|
-from typing import Iterator, Generator, Optional
|
|
|
|
|
|
+from typing import Iterator, Generator, AsyncGenerator, Optional
|
|
from pydantic import BaseModel
|
|
from pydantic import BaseModel
|
|
|
|
|
|
app = FastAPI()
|
|
app = FastAPI()
|
|
@@ -411,6 +411,25 @@ async def generate_function_chat_completion(form_data, user):
|
|
yield f"data: {json.dumps(finish_message)}\n\n"
|
|
yield f"data: {json.dumps(finish_message)}\n\n"
|
|
yield f"data: [DONE]"
|
|
yield f"data: [DONE]"
|
|
|
|
|
|
|
|
+ if isinstance(res, AsyncGenerator):
|
|
|
|
+ async for line in res:
|
|
|
|
+ if isinstance(line, BaseModel):
|
|
|
|
+ line = line.model_dump_json()
|
|
|
|
+ line = f"data: {line}"
|
|
|
|
+ if isinstance(line, dict):
|
|
|
|
+ line = f"data: {json.dumps(line)}"
|
|
|
|
+
|
|
|
|
+ try:
|
|
|
|
+ line = line.decode("utf-8")
|
|
|
|
+ except:
|
|
|
|
+ pass
|
|
|
|
+
|
|
|
|
+ if line.startswith("data:"):
|
|
|
|
+ yield f"{line}\n\n"
|
|
|
|
+ else:
|
|
|
|
+ line = stream_message_template(form_data["model"], line)
|
|
|
|
+ yield f"data: {json.dumps(line)}\n\n"
|
|
|
|
+
|
|
return StreamingResponse(stream_content(), media_type="text/event-stream")
|
|
return StreamingResponse(stream_content(), media_type="text/event-stream")
|
|
else:
|
|
else:
|
|
|
|
|
|
@@ -434,9 +453,12 @@ async def generate_function_chat_completion(form_data, user):
|
|
message = ""
|
|
message = ""
|
|
if isinstance(res, str):
|
|
if isinstance(res, str):
|
|
message = res
|
|
message = res
|
|
- if isinstance(res, Generator):
|
|
|
|
|
|
+ elif isinstance(res, Generator):
|
|
for stream in res:
|
|
for stream in res:
|
|
message = f"{message}{stream}"
|
|
message = f"{message}{stream}"
|
|
|
|
+ elif isinstance(res, AsyncGenerator):
|
|
|
|
+ async for stream in res:
|
|
|
|
+ message = f"{message}{stream}"
|
|
|
|
|
|
return {
|
|
return {
|
|
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
|
|
"id": f"{form_data['model']}-{str(uuid.uuid4())}",
|