
refac: task ollama stream support

Timothy J. Baek, 7 months ago
Commit 3a0a1aca11
3 files changed, 69 insertions(+), 12 deletions(-)
  1. backend/open_webui/main.py (+36, -5)
  2. backend/open_webui/utils/misc.py (+12, -4)
  3. backend/open_webui/utils/response.py (+21, -3)

backend/open_webui/main.py (+36, -5)

@@ -138,7 +138,10 @@ from open_webui.utils.utils import (
 from open_webui.utils.webhook import post_webhook
 
 from open_webui.utils.payload import convert_payload_openai_to_ollama
-from open_webui.utils.response import convert_response_ollama_to_openai
+from open_webui.utils.response import (
+    convert_response_ollama_to_openai,
+    convert_streaming_response_ollama_to_openai,
+)
 
 if SAFE_MODE:
     print("SAFE MODE ENABLED")
@@ -1470,7 +1473,14 @@ Prompt: {{prompt:middletruncate:8000}}"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1554,7 +1564,14 @@ Search Query:"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1629,7 +1646,14 @@ Message: """{{prompt}}"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 
@@ -1694,7 +1718,14 @@ Responses from models: {{responses}}"""
         payload = convert_payload_openai_to_ollama(payload)
         form_data = GenerateChatCompletionForm(**payload)
         response = await generate_ollama_chat_completion(form_data=form_data, user=user)
-        return convert_response_ollama_to_openai(response)
+        if form_data.stream:
+            response.headers["content-type"] = "text/event-stream"
+            return StreamingResponse(
+                convert_streaming_response_ollama_to_openai(response),
+                headers=dict(response.headers),
+            )
+        else:
+            return convert_response_ollama_to_openai(response)
     else:
         return await generate_chat_completions(form_data=payload, user=user)
 

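Each of the four task endpoints above now branches on form_data.stream: a non-streaming Ollama response is still converted into a single OpenAI-style completion dict, while a streaming one is re-wrapped in a StreamingResponse whose body is rewritten chunk by chunk. The content-type header is forced to text/event-stream because the rewritten body is SSE rather than Ollama's newline-delimited JSON. A minimal sketch of the wrapping step in isolation (rewrite_body is an illustrative stand-in for convert_streaming_response_ollama_to_openai, not code from the commit):

import json

from fastapi.responses import StreamingResponse


async def rewrite_body(body_iterator):
    # Stand-in converter: turn each upstream NDJSON chunk into an SSE event.
    async for raw in body_iterator:
        chunk = json.loads(raw)
        yield f"data: {json.dumps(chunk)}\n\n"


def wrap_ollama_stream(response: StreamingResponse) -> StreamingResponse:
    # Same shape as the endpoint branch above: force the SSE content type,
    # then re-wrap the rewritten body while carrying the headers over.
    response.headers["content-type"] = "text/event-stream"
    return StreamingResponse(
        rewrite_body(response.body_iterator),
        headers=dict(response.headers),
    )
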
backend/open_webui/utils/misc.py (+12, -4)

@@ -105,17 +105,25 @@ def openai_chat_message_template(model: str):
     }
 
 
-def openai_chat_chunk_message_template(model: str, message: str) -> dict:
+def openai_chat_chunk_message_template(
+    model: str, message: Optional[str] = None
+) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion.chunk"
-    template["choices"][0]["delta"] = {"content": message}
+    if message:
+        template["choices"][0]["delta"] = {"content": message}
+    else:
+        template["choices"][0]["finish_reason"] = "stop"
     return template
 
 
-def openai_chat_completion_message_template(model: str, message: str) -> dict:
+def openai_chat_completion_message_template(
+    model: str, message: Optional[str] = None
+) -> dict:
     template = openai_chat_message_template(model)
     template["object"] = "chat.completion"
-    template["choices"][0]["message"] = {"content": message, "role": "assistant"}
+    if message:
+        template["choices"][0]["message"] = {"content": message, "role": "assistant"}
     template["choices"][0]["finish_reason"] = "stop"
     return template
 

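The message argument to both templates is now optional. For the chunk template this is what lets one helper produce both kinds of stream events: a truthy message becomes a delta chunk, while passing None (or nothing) yields the terminating chunk carrying finish_reason "stop". Roughly, with an illustrative model name and every field that comes from openai_chat_message_template elided as "...":

openai_chat_chunk_message_template("llama3.1", "Hel")
# -> {..., "object": "chat.completion.chunk",
#     "choices": [{..., "delta": {"content": "Hel"}}]}

openai_chat_chunk_message_template("llama3.1")  # done, no content
# -> {..., "object": "chat.completion.chunk",
#     "choices": [{..., "finish_reason": "stop"}]}

Note that the check is plain truthiness, so an empty string is treated the same as None; the streaming converter below relies on this by passing None explicitly once the final Ollama chunk arrives.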
backend/open_webui/utils/response.py (+21, -3)

@@ -1,10 +1,9 @@
-from open_webui.utils.task import prompt_template
+import json
 from open_webui.utils.misc import (
+    openai_chat_chunk_message_template,
     openai_chat_completion_message_template,
 )
 
-from typing import Callable, Optional
-
 
 def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
     model = ollama_response.get("model", "ollama")
@@ -12,3 +11,22 @@ def convert_response_ollama_to_openai(ollama_response: dict) -> dict:
 
     response = openai_chat_completion_message_template(model, message_content)
     return response
+
+
+async def convert_streaming_response_ollama_to_openai(ollama_streaming_response):
+    async for data in ollama_streaming_response.body_iterator:
+        data = json.loads(data)
+
+        model = data.get("model", "ollama")
+        message_content = data.get("message", {}).get("content", "")
+        done = data.get("done", False)
+
+        data = openai_chat_chunk_message_template(
+            model, message_content if not done else None
+        )
+
+        line = f"data: {json.dumps(data)}\n\n"
+        if done:
+            line += "data: [DONE]\n\n"
+
+        yield line
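The new converter reads the upstream body as newline-delimited JSON, reshapes each Ollama chunk with openai_chat_chunk_message_template, and yields it as a Server-Sent Events data: line, appending the OpenAI-style data: [DONE] terminator after the chunk marked done. A rough, self-contained way to exercise it; FakeOllamaStream is a stand-in for the real upstream StreamingResponse and only provides the one attribute the converter touches, and the model name and chunk contents are made up:

import asyncio
import json

from open_webui.utils.response import convert_streaming_response_ollama_to_openai


class FakeOllamaStream:
    # Stand-in upstream response: only body_iterator, yielding
    # NDJSON-encoded Ollama chat chunks.
    def __init__(self, chunks):
        async def gen():
            for chunk in chunks:
                yield json.dumps(chunk)

        self.body_iterator = gen()


async def main():
    fake = FakeOllamaStream([
        {"model": "llama3.1", "message": {"content": "Hel"}, "done": False},
        {"model": "llama3.1", "message": {"content": "lo"}, "done": False},
        {"model": "llama3.1", "message": {"content": ""}, "done": True},
    ])
    async for line in convert_streaming_response_ollama_to_openai(fake):
        print(line, end="")
    # Prints three "data: {...}" events carrying chat.completion.chunk dicts;
    # the last has finish_reason "stop" and is followed by "data: [DONE]".


asyncio.run(main())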