Browse Source

enh: arena model send selected model id

Timothy J. Baek 6 tháng trước cách đây
mục cha
commit
6d52f913d2
1 tập tin đã thay đổi với 21 bổ sung5 xóa
  1. 21 5
      backend/open_webui/main.py

+ 21 - 5
backend/open_webui/main.py

@@ -1102,9 +1102,9 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
 
     if model["owned_by"] == "arena":
         model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-        model_id = None
+        selected_model_id = None
         if isinstance(model_ids, list) and model_ids:
-            model_id = random.choice(model_ids)
+            selected_model_id = random.choice(model_ids)
         else:
             model_ids = [
                 model["id"]
@@ -1112,10 +1112,26 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
                 if model.get("owned_by") != "arena"
                 and not model.get("info", {}).get("meta", {}).get("hidden", False)
             ]
-            model_id = random.choice(model_ids)
+            selected_model_id = random.choice(model_ids)
 
-        form_data["model"] = model_id
-        return await generate_chat_completions(form_data, user)
+        form_data["model"] = selected_model_id
+
+        if form_data.get("stream") == True:
+
+            async def stream_wrapper(stream):
+                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                async for chunk in stream:
+                    yield chunk
+
+            response = await generate_chat_completions(form_data, user)
+            return StreamingResponse(
+                stream_wrapper(response.body_iterator), media_type="text/event-stream"
+            )
+        else:
+            return {
+                **(await generate_chat_completions(form_data, user)),
+                "selected_model_id": selected_model_id,
+            }
     if model.get("pipe"):
         return await generate_function_chat_completion(form_data, user=user)
     if model["owned_by"] == "ollama":