Timothy J. Baek 9 месяцев назад
Родитель
Сommit
23e69bcdb4
1 измененных файлов с 24 добавлено и 2 удалено
  1. 24 2
      backend/apps/webui/main.py

+ 24 - 2
backend/apps/webui/main.py

@@ -54,7 +54,7 @@ import uuid
 import time
 import json
 
-from typing import Iterator, Generator, Optional
+from typing import Iterator, Generator, AsyncGenerator, Optional
 from pydantic import BaseModel
 
 app = FastAPI()
@@ -411,6 +411,25 @@ async def generate_function_chat_completion(form_data, user):
                     yield f"data: {json.dumps(finish_message)}\n\n"
                     yield f"data: [DONE]"
 
+                if isinstance(res, AsyncGenerator):
+                    async for line in res:
+                        if isinstance(line, BaseModel):
+                            line = line.model_dump_json()
+                            line = f"data: {line}"
+                        if isinstance(line, dict):
+                            line = f"data: {json.dumps(line)}"
+
+                        try:
+                            line = line.decode("utf-8")
+                        except:
+                            pass
+
+                        if line.startswith("data:"):
+                            yield f"{line}\n\n"
+                        else:
+                            line = stream_message_template(form_data["model"], line)
+                            yield f"data: {json.dumps(line)}\n\n"
+
             return StreamingResponse(stream_content(), media_type="text/event-stream")
         else:
 
@@ -434,9 +453,12 @@ async def generate_function_chat_completion(form_data, user):
                 message = ""
                 if isinstance(res, str):
                     message = res
-                if isinstance(res, Generator):
+                elif isinstance(res, Generator):
                     for stream in res:
                         message = f"{message}{stream}"
+                elif isinstance(res, AsyncGenerator):
+                    async for stream in res:
+                        message = f"{message}{stream}"
 
                 return {
                     "id": f"{form_data['model']}-{str(uuid.uuid4())}",