@@ -7,6 +7,8 @@ from typing import Any, Optional
 import random
 import json
 import inspect
+import uuid
+import asyncio
 
 from fastapi import Request
 from starlette.responses import Response, StreamingResponse
@@ -15,6 +17,7 @@ from starlette.responses import Response, StreamingResponse
 from open_webui.models.users import UserModel
 
 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_emitter,
 )
@@ -57,6 +60,93 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_emitter = get_event_emitter(metadata)
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        # Define a generator to stream responses
+        async def event_generator():
+            nonlocal q
+
+            async def message_listener(sid, data):
+                """
+                Handle received socket messages and push them into the queue.
+                """
+                await q.put(data)
+
+            # Register the listener
+            sio.on(channel, message_listener)
+
+            # Start processing chat completion in background
+            await event_emitter(
+                {
+                    "type": "request:chat:completion",
+                    "data": {
+                        "form_data": form_data,
+                        "model": models[form_data["model"]],
+                        "channel": channel,
+                        "session_id": session_id,
+                    },
+                }
+            )
+
+            try:
+                while True:
+                    data = await q.get()  # Wait for new messages
+                    if isinstance(data, dict):
+                        if "error" in data:
+                            raise Exception(data["error"])
+
+                        if "done" in data and data["done"]:
+                            break  # Stop streaming when 'done' is received
+
+                        yield f"data: {json.dumps(data)}\n\n"
+                    elif isinstance(data, str):
+                        yield data
+            finally:
+                del sio.handlers["/"][channel]  # Remove the listener
+
+        # Return the streaming response
+        return StreamingResponse(event_generator(), media_type="text/event-stream")
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        print(res)
+
+        if "error" in res:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
     request: Request,
     form_data: dict,
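
For context on the hunk above: generate_direct_chat_completion hands the actual inference off to a connected client session. It emits a "request:chat:completion" event (through the event emitter/caller built from the request metadata) carrying the form data, the resolved model entry, and a per-request reply channel named "{user_id}:{session_id}:{request_id}", then, in the streaming case, consumes whatever the client emits back on that channel and re-frames it as server-sent events. A minimal sketch of the reply side follows; the python-socketio AsyncClient, the run_model placeholder, and the exact chunk payloads are illustrative assumptions; only the channel semantics and the "done"/"error" keys come from the code above.

    import socketio

    sio = socketio.AsyncClient()  # hypothetical direct-connection client
    # assumes sio.connect(...) has already been made to the Open WebUI socket endpoint


    async def run_model(form_data: dict):
        # Placeholder for the client-side model call; yields OpenAI-style chunks.
        yield {"choices": [{"delta": {"content": "Hello"}}]}


    async def stream_reply(channel: str, form_data: dict):
        # Emit completion chunks back on the per-request channel the server listens on.
        try:
            async for chunk in run_model(form_data):
                await sio.emit(channel, chunk)  # forwarded into the SSE stream as-is
            await sio.emit(channel, {"done": True})  # makes event_generator() stop
        except Exception as e:
            await sio.emit(channel, {"error": str(e)})  # re-raised server-side
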
@@ -66,7 +156,12 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     model_id = form_data["model"]
     if model_id not in models:
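
With a direct connection the models map collapses to a single client-registered entry. A rough, purely illustrative shape (the id is hypothetical; the entry is whatever request.state.model carries and is later passed along as models[form_data["model"]]):

    models = {
        "my-local-model": {
            "id": "my-local-model",
            "name": "My Local Model",
        }
    }
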
@@ -87,78 +182,90 @@ async def generate_chat_completion(
     except Exception as e:
         raise e
 
-    if model["owned_by"] == "arena":
-        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
-        if model_ids and filter_mode == "exclude":
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena" and model["id"] not in model_ids
-            ]
-
-        selected_model_id = None
-        if isinstance(model_ids, list) and model_ids:
-            selected_model_id = random.choice(model_ids)
-        else:
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena"
-            ]
-            selected_model_id = random.choice(model_ids)
-
-        form_data["model"] = selected_model_id
-
-        if form_data.get("stream") == True:
+    if request.state.direct:
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
 
-            async def stream_wrapper(stream):
-                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
-                async for chunk in stream:
-                    yield chunk
+    else:
+        if model["owned_by"] == "arena":
+            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+            if model_ids and filter_mode == "exclude":
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
+                ]
+
+            selected_model_id = None
+            if isinstance(model_ids, list) and model_ids:
+                selected_model_id = random.choice(model_ids)
+            else:
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena"
+                ]
+                selected_model_id = random.choice(model_ids)
+
+            form_data["model"] = selected_model_id
+
+            if form_data.get("stream") == True:
+
+                async def stream_wrapper(stream):
+                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                    async for chunk in stream:
+                        yield chunk
+
+                response = await generate_chat_completion(
+                    request, form_data, user, bypass_filter=True
+                )
+                return StreamingResponse(
+                    stream_wrapper(response.body_iterator),
+                    media_type="text/event-stream",
+                    background=response.background,
+                )
+            else:
+                return {
+                    **(
+                        await generate_chat_completion(
+                            request, form_data, user, bypass_filter=True
+                        )
+                    ),
+                    "selected_model_id": selected_model_id,
+                }
 
-        response = await generate_chat_completion(
-            request, form_data, user, bypass_filter=True
+        if model.get("pipe"):
+            # Below does not require bypass_filter because this is the only route that uses this function and it is already bypassing the filter
+            return await generate_function_chat_completion(
+                request, form_data, user=user, models=models
             )
-        return StreamingResponse(
-            stream_wrapper(response.body_iterator),
-            media_type="text/event-stream",
-            background=response.background,
+        if model["owned_by"] == "ollama":
+            # Using /ollama/api/chat endpoint
+            form_data = convert_payload_openai_to_ollama(form_data)
+            response = await generate_ollama_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
             )
+            if form_data.get("stream"):
+                response.headers["content-type"] = "text/event-stream"
+                return StreamingResponse(
+                    convert_streaming_response_ollama_to_openai(response),
+                    headers=dict(response.headers),
+                    background=response.background,
+                )
+            else:
+                return convert_response_ollama_to_openai(response)
         else:
-            return {
-                **(
-                    await generate_chat_completion(
-                        request, form_data, user, bypass_filter=True
-                    )
-                ),
-                "selected_model_id": selected_model_id,
-            }
-
-    if model.get("pipe"):
-        # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
-        return await generate_function_chat_completion(
-            request, form_data, user=user, models=models
-        )
-    if model["owned_by"] == "ollama":
-        # Using /ollama/api/chat endpoint
-        form_data = convert_payload_openai_to_ollama(form_data)
-        response = await generate_ollama_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
-        if form_data.get("stream"):
-            response.headers["content-type"] = "text/event-stream"
-            return StreamingResponse(
-                convert_streaming_response_ollama_to_openai(response),
-                headers=dict(response.headers),
-                background=response.background,
+            return await generate_openai_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
             )
-        else:
-            return convert_response_ollama_to_openai(response)
-    else:
-        return await generate_openai_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
 
 
 chat_completion = generate_chat_completion
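
On the wire, the arena streaming path above now prepends one extra SSE event carrying the randomly selected model id before relaying the chosen model's own stream unchanged through stream_wrapper. A hypothetical consumer sketch (the "data: " framing and the selected_model_id key come from the hunk above; the parser itself and the "[DONE]" terminator are assumptions about an OpenAI-style stream):

    import json


    def parse_arena_sse(lines):
        """Yield (selected_model_id, chunk) pairs from an arena SSE stream."""
        selected_model_id = None
        for line in lines:
            if not line.startswith("data: "):
                continue
            payload = line[len("data: "):].strip()
            if payload == "[DONE]":
                break
            event = json.loads(payload)
            if selected_model_id is None and "selected_model_id" in event:
                selected_model_id = event["selected_model_id"]
                continue
            yield selected_model_id, event
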
@@ -167,7 +274,13 @@ chat_completion = generate_chat_completion
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     data = form_data
     model_id = data["model"]
@@ -227,7 +340,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
 
     if not request.app.state.MODELS:
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
     data = form_data
     model_id = data["model"]
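
A possible follow-up, not part of this diff: the direct-connection model lookup is now repeated verbatim in generate_chat_completion, chat_completed, and chat_action, so it could be factored into a small helper along these lines (name and placement are hypothetical):

    def get_models_for_request(request) -> dict:
        # Single client-registered model for direct connections, otherwise the global map.
        if request.state.direct and request.state.model:
            return {request.state.model["id"]: request.state.model}
        return request.app.state.MODELS
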