浏览代码

feat: direct connections integration

Timothy Jaeryang Baek 2 月之前
父节点
当前提交
c83e68282d

+ 37 - 12
backend/open_webui/main.py

@@ -900,20 +900,30 @@ async def chat_completion(
     if not request.app.state.MODELS:
     if not request.app.state.MODELS:
         await get_all_models(request)
         await get_all_models(request)
 
 
+    model_item = form_data.pop("model_item", {})
     tasks = form_data.pop("background_tasks", None)
     tasks = form_data.pop("background_tasks", None)
+
     try:
     try:
-        model_id = form_data.get("model", None)
-        if model_id not in request.app.state.MODELS:
-            raise Exception("Model not found")
-        model = request.app.state.MODELS[model_id]
-        model_info = Models.get_model_by_id(model_id)
-
-        # Check if user has access to the model
-        if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
-            try:
-                check_model_access(user, model)
-            except Exception as e:
-                raise e
+        if not model_item.get("direct", False):
+            model_id = form_data.get("model", None)
+            if model_id not in request.app.state.MODELS:
+                raise Exception("Model not found")
+
+            model = request.app.state.MODELS[model_id]
+            model_info = Models.get_model_by_id(model_id)
+
+            # Check if user has access to the model
+            if not BYPASS_MODEL_ACCESS_CONTROL and user.role == "user":
+                try:
+                    check_model_access(user, model)
+                except Exception as e:
+                    raise e
+        else:
+            model = model_item
+            model_info = None
+
+            request.state.direct = True
+            request.state.model = model
 
 
         metadata = {
         metadata = {
             "user_id": user.id,
             "user_id": user.id,
@@ -925,6 +935,7 @@ async def chat_completion(
             "features": form_data.get("features", None),
             "features": form_data.get("features", None),
             "variables": form_data.get("variables", None),
             "variables": form_data.get("variables", None),
             "model": model_info,
             "model": model_info,
+            "direct": model_item.get("direct", False),
             **(
             **(
                 {"function_calling": "native"}
                 {"function_calling": "native"}
                 if form_data.get("params", {}).get("function_calling") == "native"
                 if form_data.get("params", {}).get("function_calling") == "native"
@@ -936,6 +947,7 @@ async def chat_completion(
                 else {}
                 else {}
             ),
             ),
         }
         }
+        request.state.metadata = metadata
         form_data["metadata"] = metadata
         form_data["metadata"] = metadata
 
 
         form_data, metadata, events = await process_chat_payload(
         form_data, metadata, events = await process_chat_payload(
@@ -943,6 +955,7 @@ async def chat_completion(
         )
         )
 
 
     except Exception as e:
     except Exception as e:
+        log.debug(f"Error processing chat payload: {e}")
         raise HTTPException(
         raise HTTPException(
             status_code=status.HTTP_400_BAD_REQUEST,
             status_code=status.HTTP_400_BAD_REQUEST,
             detail=str(e),
             detail=str(e),
@@ -971,6 +984,12 @@ async def chat_completed(
     request: Request, form_data: dict, user=Depends(get_verified_user)
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 ):
     try:
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_completed_handler(request, form_data, user)
         return await chat_completed_handler(request, form_data, user)
     except Exception as e:
     except Exception as e:
         raise HTTPException(
         raise HTTPException(
@@ -984,6 +1003,12 @@ async def chat_action(
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
     request: Request, action_id: str, form_data: dict, user=Depends(get_verified_user)
 ):
 ):
     try:
     try:
+        model_item = form_data.pop("model_item", {})
+
+        if model_item.get("direct", False):
+            request.state.direct = True
+            request.state.model = model_item
+
         return await chat_action_handler(request, action_id, form_data, user)
         return await chat_action_handler(request, action_id, form_data, user)
     except Exception as e:
     except Exception as e:
         raise HTTPException(
         raise HTTPException(

+ 54 - 8
backend/open_webui/routers/tasks.py

@@ -139,7 +139,12 @@ async def update_task_config(
 async def generate_title(
 async def generate_title(
     request: Request, form_data: dict, user=Depends(get_verified_user)
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -198,6 +203,7 @@ async def generate_title(
             }
             }
         ),
         ),
         "metadata": {
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TITLE_GENERATION),
             "task": str(TASKS.TITLE_GENERATION),
             "task_body": form_data,
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
             "chat_id": form_data.get("chat_id", None),
@@ -225,7 +231,12 @@ async def generate_chat_tags(
             content={"detail": "Tags generation is disabled"},
             content={"detail": "Tags generation is disabled"},
         )
         )
 
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -261,6 +272,7 @@ async def generate_chat_tags(
         "messages": [{"role": "user", "content": content}],
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "stream": False,
         "metadata": {
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.TAGS_GENERATION),
             "task": str(TASKS.TAGS_GENERATION),
             "task_body": form_data,
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
             "chat_id": form_data.get("chat_id", None),
@@ -281,7 +293,12 @@ async def generate_chat_tags(
 async def generate_image_prompt(
 async def generate_image_prompt(
     request: Request, form_data: dict, user=Depends(get_verified_user)
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 ):
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -321,6 +338,7 @@ async def generate_image_prompt(
         "messages": [{"role": "user", "content": content}],
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "stream": False,
         "metadata": {
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task": str(TASKS.IMAGE_PROMPT_GENERATION),
             "task_body": form_data,
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
             "chat_id": form_data.get("chat_id", None),
@@ -356,7 +374,12 @@ async def generate_queries(
                 detail=f"Query generation is disabled",
                 detail=f"Query generation is disabled",
             )
             )
 
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -392,6 +415,7 @@ async def generate_queries(
         "messages": [{"role": "user", "content": content}],
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "stream": False,
         "metadata": {
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.QUERY_GENERATION),
             "task": str(TASKS.QUERY_GENERATION),
             "task_body": form_data,
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
             "chat_id": form_data.get("chat_id", None),
@@ -431,7 +455,12 @@ async def generate_autocompletion(
                 detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
                 detail=f"Input prompt exceeds maximum length of {request.app.state.config.AUTOCOMPLETE_GENERATION_INPUT_MAX_LENGTH}",
             )
             )
 
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -467,6 +496,7 @@ async def generate_autocompletion(
         "messages": [{"role": "user", "content": content}],
         "messages": [{"role": "user", "content": content}],
         "stream": False,
         "stream": False,
         "metadata": {
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task": str(TASKS.AUTOCOMPLETE_GENERATION),
             "task_body": form_data,
             "task_body": form_data,
             "chat_id": form_data.get("chat_id", None),
             "chat_id": form_data.get("chat_id", None),
@@ -488,7 +518,12 @@ async def generate_emoji(
     request: Request, form_data: dict, user=Depends(get_verified_user)
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 ):
 
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -531,7 +566,11 @@ async def generate_emoji(
             }
             }
         ),
         ),
         "chat_id": form_data.get("chat_id", None),
         "chat_id": form_data.get("chat_id", None),
-        "metadata": {"task": str(TASKS.EMOJI_GENERATION), "task_body": form_data},
+        "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
+            "task": str(TASKS.EMOJI_GENERATION),
+            "task_body": form_data,
+        },
     }
     }
 
 
     try:
     try:
@@ -548,7 +587,13 @@ async def generate_moa_response(
     request: Request, form_data: dict, user=Depends(get_verified_user)
     request: Request, form_data: dict, user=Depends(get_verified_user)
 ):
 ):
 
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
     model_id = form_data["model"]
     model_id = form_data["model"]
 
 
     if model_id not in models:
     if model_id not in models:
@@ -581,6 +626,7 @@ async def generate_moa_response(
         "messages": [{"role": "user", "content": content}],
         "messages": [{"role": "user", "content": content}],
         "stream": form_data.get("stream", False),
         "stream": form_data.get("stream", False),
         "metadata": {
         "metadata": {
+            **(request.state.metadata if request.state.metadata else {}),
             "chat_id": form_data.get("chat_id", None),
             "chat_id": form_data.get("chat_id", None),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task": str(TASKS.MOA_RESPONSE_GENERATION),
             "task_body": form_data,
             "task_body": form_data,

+ 188 - 69
backend/open_webui/utils/chat.py

@@ -7,6 +7,8 @@ from typing import Any, Optional
 import random
 import random
 import json
 import json
 import inspect
 import inspect
+import uuid
+import asyncio
 
 
 from fastapi import Request
 from fastapi import Request
 from starlette.responses import Response, StreamingResponse
 from starlette.responses import Response, StreamingResponse
@@ -15,6 +17,7 @@ from starlette.responses import Response, StreamingResponse
 from open_webui.models.users import UserModel
 from open_webui.models.users import UserModel
 
 
 from open_webui.socket.main import (
 from open_webui.socket.main import (
+    sio,
     get_event_call,
     get_event_call,
     get_event_emitter,
     get_event_emitter,
 )
 )
@@ -57,6 +60,93 @@ log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 log.setLevel(SRC_LOG_LEVELS["MAIN"])
 
 
 
 
+async def generate_direct_chat_completion(
+    request: Request,
+    form_data: dict,
+    user: Any,
+    models: dict,
+):
+    print("generate_direct_chat_completion")
+
+    metadata = form_data.pop("metadata", {})
+
+    user_id = metadata.get("user_id")
+    session_id = metadata.get("session_id")
+    request_id = str(uuid.uuid4())  # Generate a unique request ID
+
+    event_emitter = get_event_emitter(metadata)
+    event_caller = get_event_call(metadata)
+
+    channel = f"{user_id}:{session_id}:{request_id}"
+
+    if form_data.get("stream"):
+        q = asyncio.Queue()
+
+        # Define a generator to stream responses
+        async def event_generator():
+            nonlocal q
+
+            async def message_listener(sid, data):
+                """
+                Handle received socket messages and push them into the queue.
+                """
+                await q.put(data)
+
+            # Register the listener
+            sio.on(channel, message_listener)
+
+            # Start processing chat completion in background
+            await event_emitter(
+                {
+                    "type": "request:chat:completion",
+                    "data": {
+                        "form_data": form_data,
+                        "model": models[form_data["model"]],
+                        "channel": channel,
+                        "session_id": session_id,
+                    },
+                }
+            )
+
+            try:
+                while True:
+                    data = await q.get()  # Wait for new messages
+                    if isinstance(data, dict):
+                        if "error" in data:
+                            raise Exception(data["error"])
+
+                        if "done" in data and data["done"]:
+                            break  # Stop streaming when 'done' is received
+
+                        yield f"data: {json.dumps(data)}\n\n"
+                    elif isinstance(data, str):
+                        yield data
+            finally:
+                del sio.handlers["/"][channel]  # Remove the listener
+
+        # Return the streaming response
+        return StreamingResponse(event_generator(), media_type="text/event-stream")
+    else:
+        res = await event_caller(
+            {
+                "type": "request:chat:completion",
+                "data": {
+                    "form_data": form_data,
+                    "model": models[form_data["model"]],
+                    "channel": channel,
+                    "session_id": session_id,
+                },
+            }
+        )
+
+        print(res)
+
+        if "error" in res:
+            raise Exception(res["error"])
+
+        return res
+
+
 async def generate_chat_completion(
 async def generate_chat_completion(
     request: Request,
     request: Request,
     form_data: dict,
     form_data: dict,
@@ -66,7 +156,12 @@ async def generate_chat_completion(
     if BYPASS_MODEL_ACCESS_CONTROL:
     if BYPASS_MODEL_ACCESS_CONTROL:
         bypass_filter = True
         bypass_filter = True
 
 
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     model_id = form_data["model"]
     model_id = form_data["model"]
     if model_id not in models:
     if model_id not in models:
@@ -87,78 +182,90 @@ async def generate_chat_completion(
         except Exception as e:
         except Exception as e:
             raise e
             raise e
 
 
-    if model["owned_by"] == "arena":
-        model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
-        filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
-        if model_ids and filter_mode == "exclude":
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena" and model["id"] not in model_ids
-            ]
-
-        selected_model_id = None
-        if isinstance(model_ids, list) and model_ids:
-            selected_model_id = random.choice(model_ids)
-        else:
-            model_ids = [
-                model["id"]
-                for model in list(request.app.state.MODELS.values())
-                if model.get("owned_by") != "arena"
-            ]
-            selected_model_id = random.choice(model_ids)
-
-        form_data["model"] = selected_model_id
-
-        if form_data.get("stream") == True:
+    if request.state.direct:
+        return await generate_direct_chat_completion(
+            request, form_data, user=user, models=models
+        )
 
 
-            async def stream_wrapper(stream):
-                yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
-                async for chunk in stream:
-                    yield chunk
+    else:
+        if model["owned_by"] == "arena":
+            model_ids = model.get("info", {}).get("meta", {}).get("model_ids")
+            filter_mode = model.get("info", {}).get("meta", {}).get("filter_mode")
+            if model_ids and filter_mode == "exclude":
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena" and model["id"] not in model_ids
+                ]
+
+            selected_model_id = None
+            if isinstance(model_ids, list) and model_ids:
+                selected_model_id = random.choice(model_ids)
+            else:
+                model_ids = [
+                    model["id"]
+                    for model in list(request.app.state.MODELS.values())
+                    if model.get("owned_by") != "arena"
+                ]
+                selected_model_id = random.choice(model_ids)
+
+            form_data["model"] = selected_model_id
+
+            if form_data.get("stream") == True:
+
+                async def stream_wrapper(stream):
+                    yield f"data: {json.dumps({'selected_model_id': selected_model_id})}\n\n"
+                    async for chunk in stream:
+                        yield chunk
+
+                response = await generate_chat_completion(
+                    request, form_data, user, bypass_filter=True
+                )
+                return StreamingResponse(
+                    stream_wrapper(response.body_iterator),
+                    media_type="text/event-stream",
+                    background=response.background,
+                )
+            else:
+                return {
+                    **(
+                        await generate_chat_completion(
+                            request, form_data, user, bypass_filter=True
+                        )
+                    ),
+                    "selected_model_id": selected_model_id,
+                }
 
 
-            response = await generate_chat_completion(
-                request, form_data, user, bypass_filter=True
+        if model.get("pipe"):
+            # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
+            return await generate_function_chat_completion(
+                request, form_data, user=user, models=models
             )
             )
-            return StreamingResponse(
-                stream_wrapper(response.body_iterator),
-                media_type="text/event-stream",
-                background=response.background,
+        if model["owned_by"] == "ollama":
+            # Using /ollama/api/chat endpoint
+            form_data = convert_payload_openai_to_ollama(form_data)
+            response = await generate_ollama_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
             )
             )
+            if form_data.get("stream"):
+                response.headers["content-type"] = "text/event-stream"
+                return StreamingResponse(
+                    convert_streaming_response_ollama_to_openai(response),
+                    headers=dict(response.headers),
+                    background=response.background,
+                )
+            else:
+                return convert_response_ollama_to_openai(response)
         else:
         else:
-            return {
-                **(
-                    await generate_chat_completion(
-                        request, form_data, user, bypass_filter=True
-                    )
-                ),
-                "selected_model_id": selected_model_id,
-            }
-
-    if model.get("pipe"):
-        # Below does not require bypass_filter because this is the only route the uses this function and it is already bypassing the filter
-        return await generate_function_chat_completion(
-            request, form_data, user=user, models=models
-        )
-    if model["owned_by"] == "ollama":
-        # Using /ollama/api/chat endpoint
-        form_data = convert_payload_openai_to_ollama(form_data)
-        response = await generate_ollama_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
-        if form_data.get("stream"):
-            response.headers["content-type"] = "text/event-stream"
-            return StreamingResponse(
-                convert_streaming_response_ollama_to_openai(response),
-                headers=dict(response.headers),
-                background=response.background,
+            return await generate_openai_chat_completion(
+                request=request,
+                form_data=form_data,
+                user=user,
+                bypass_filter=bypass_filter,
             )
             )
-        else:
-            return convert_response_ollama_to_openai(response)
-    else:
-        return await generate_openai_chat_completion(
-            request=request, form_data=form_data, user=user, bypass_filter=bypass_filter
-        )
 
 
 
 
 chat_completion = generate_chat_completion
 chat_completion = generate_chat_completion
@@ -167,7 +274,13 @@ chat_completion = generate_chat_completion
 async def chat_completed(request: Request, form_data: dict, user: Any):
 async def chat_completed(request: Request, form_data: dict, user: Any):
     if not request.app.state.MODELS:
     if not request.app.state.MODELS:
         await get_all_models(request)
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     data = form_data
     data = form_data
     model_id = data["model"]
     model_id = data["model"]
@@ -227,7 +340,13 @@ async def chat_action(request: Request, action_id: str, form_data: dict, user: A
 
 
     if not request.app.state.MODELS:
     if not request.app.state.MODELS:
         await get_all_models(request)
         await get_all_models(request)
-    models = request.app.state.MODELS
+
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
 
 
     data = form_data
     data = form_data
     model_id = data["model"]
     model_id = data["model"]

+ 10 - 1
backend/open_webui/utils/middleware.py

@@ -622,7 +622,13 @@ async def process_chat_payload(request, form_data, metadata, user, model):
 
 
     # Initialize events to store additional event to be sent to the client
     # Initialize events to store additional event to be sent to the client
     # Initialize contexts and citation
     # Initialize contexts and citation
-    models = request.app.state.MODELS
+    if request.state.direct and request.state.model:
+        models = {
+            request.state.model["id"]: request.state.model,
+        }
+    else:
+        models = request.app.state.MODELS
+
     task_model_id = get_task_model_id(
     task_model_id = get_task_model_id(
         form_data["model"],
         form_data["model"],
         request.app.state.config.TASK_MODEL,
         request.app.state.config.TASK_MODEL,
@@ -1677,6 +1683,9 @@ async def process_chat_response(
                                             "data": {
                                             "data": {
                                                 "id": str(uuid4()),
                                                 "id": str(uuid4()),
                                                 "code": code,
                                                 "code": code,
+                                                "session_id": metadata.get(
+                                                    "session_id", None
+                                                ),
                                             },
                                             },
                                         }
                                         }
                                     )
                                     )

+ 3 - 0
src/lib/components/chat/Chat.svelte

@@ -838,6 +838,7 @@
 				timestamp: m.timestamp,
 				timestamp: m.timestamp,
 				...(m.sources ? { sources: m.sources } : {})
 				...(m.sources ? { sources: m.sources } : {})
 			})),
 			})),
+			model_item: $models.find((m) => m.id === modelId),
 			chat_id: chatId,
 			chat_id: chatId,
 			session_id: $socket?.id,
 			session_id: $socket?.id,
 			id: responseMessageId
 			id: responseMessageId
@@ -896,6 +897,7 @@
 				...(m.sources ? { sources: m.sources } : {})
 				...(m.sources ? { sources: m.sources } : {})
 			})),
 			})),
 			...(event ? { event: event } : {}),
 			...(event ? { event: event } : {}),
+			model_item: $models.find((m) => m.id === modelId),
 			chat_id: chatId,
 			chat_id: chatId,
 			session_id: $socket?.id,
 			session_id: $socket?.id,
 			id: responseMessageId
 			id: responseMessageId
@@ -1574,6 +1576,7 @@
 						$settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined
 						$settings?.userLocation ? await getAndUpdateUserLocation(localStorage.token) : undefined
 					)
 					)
 				},
 				},
+				model_item: $models.find((m) => m.id === model.id),
 
 
 				session_id: $socket?.id,
 				session_id: $socket?.id,
 				chat_id: $chatId,
 				chat_id: $chatId,

+ 92 - 1
src/routes/+layout.svelte

@@ -45,6 +45,7 @@
 	import { getAllTags, getChatList } from '$lib/apis/chats';
 	import { getAllTags, getChatList } from '$lib/apis/chats';
 	import NotificationToast from '$lib/components/NotificationToast.svelte';
 	import NotificationToast from '$lib/components/NotificationToast.svelte';
 	import AppSidebar from '$lib/components/app/AppSidebar.svelte';
 	import AppSidebar from '$lib/components/app/AppSidebar.svelte';
+	import { chatCompletion } from '$lib/apis/openai';
 
 
 	setContext('i18n', i18n);
 	setContext('i18n', i18n);
 
 
@@ -251,10 +252,100 @@
 			} else if (type === 'chat:tags') {
 			} else if (type === 'chat:tags') {
 				tags.set(await getAllTags(localStorage.token));
 				tags.set(await getAllTags(localStorage.token));
 			}
 			}
-		} else {
+		} else if (data?.session_id === $socket.id) {
 			if (type === 'execute:python') {
 			if (type === 'execute:python') {
 				console.log('execute:python', data);
 				console.log('execute:python', data);
 				executePythonAsWorker(data.id, data.code, cb);
 				executePythonAsWorker(data.id, data.code, cb);
+			} else if (type === 'request:chat:completion') {
+				console.log(data, $socket.id);
+				const { session_id, channel, form_data, model } = data;
+
+				try {
+					const directConnections = $settings?.directConnections ?? {};
+
+					if (directConnections) {
+						const urlIdx = model?.urlIdx;
+
+						console.log(model, directConnections);
+
+						const OPENAI_API_URL = directConnections.OPENAI_API_BASE_URLS[urlIdx];
+						const OPENAI_API_KEY = directConnections.OPENAI_API_KEYS[urlIdx];
+						const API_CONFIG = directConnections.OPENAI_API_CONFIGS[urlIdx];
+
+						try {
+							const [res, controller] = await chatCompletion(
+								OPENAI_API_KEY,
+								form_data,
+								OPENAI_API_URL
+							);
+
+							if (res && res.ok) {
+								if (form_data?.stream ?? false) {
+									// res will either be SSE or JSON
+									const reader = res.body.getReader();
+									const decoder = new TextDecoder();
+
+									const processStream = async () => {
+										while (true) {
+											// Read data chunks from the response stream
+											const { done, value } = await reader.read();
+											if (done) {
+												break;
+											}
+
+											// Decode the received chunk
+											const chunk = decoder.decode(value, { stream: true });
+
+											// Process lines within the chunk
+											const lines = chunk.split('\n').filter((line) => line.trim() !== '');
+
+											for (const line of lines) {
+												$socket?.emit(channel, line);
+											}
+										}
+									};
+
+									// Process the stream in the background
+									await processStream();
+								} else {
+									const data = await res.json();
+									cb(data);
+								}
+							} else {
+								throw new Error('An error occurred while fetching the completion');
+							}
+						} catch (error) {
+							console.error('chatCompletion', error);
+
+							if (form_data?.stream ?? false) {
+								$socket.emit(channel, {
+									error: error
+								});
+							} else {
+								cb({
+									error: error
+								});
+							}
+						}
+					}
+				} catch (error) {
+					console.error('chatCompletion', error);
+					if (form_data?.stream ?? false) {
+						$socket.emit(channel, {
+							error: error
+						});
+					} else {
+						cb({
+							error: error
+						});
+					}
+				} finally {
+					$socket.emit(channel, {
+						done: true
+					});
+				}
+			} else {
+				console.log('chatEventHandler', event);
 			}
 			}
 		}
 		}
 	};
 	};