Przeglądaj źródła

feat: pipe function

Timothy J. Baek 10 miesięcy temu
rodzic
commit
d6e4aef607
3 zmienionych plików z 231 dodań i 61 usunięć
  1. 58 0
      backend/apps/webui/main.py
  2. 154 61
      backend/main.py
  3. 19 0
      backend/utils/misc.py

+ 58 - 0
backend/apps/webui/main.py

@@ -15,6 +15,9 @@ from apps.webui.routers import (
     files,
     functions,
 )
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
+
 from config import (
     WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
@@ -97,3 +100,58 @@ async def get_status():
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
     }
+
+
+async def get_pipe_models():
+    pipes = Functions.get_functions_by_type("pipe")
+    pipe_models = []
+
+    for pipe in pipes:
+        # Check if function is already loaded
+        if pipe.id not in app.state.FUNCTIONS:
+            function_module, function_type = load_function_module_by_id(pipe.id)
+            app.state.FUNCTIONS[pipe.id] = function_module
+        else:
+            function_module = app.state.FUNCTIONS[pipe.id]
+
+        # Check if function is a manifold
+        if hasattr(function_module, "type"):
+            if function_module.type == "manifold":
+                manifold_pipes = []
+
+                # Check if pipes is a function or a list
+                if callable(pipe.pipes):
+                    manifold_pipes = pipe.pipes()
+                else:
+                    manifold_pipes = pipe.pipes
+
+                for p in manifold_pipes:
+                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+                    manifold_pipe_name = p["name"]
+
+                    if hasattr(pipe, "name"):
+                        manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
+
+                    pipe_models.append(
+                        {
+                            "id": manifold_pipe_id,
+                            "name": manifold_pipe_name,
+                            "object": "model",
+                            "created": pipe.created_at,
+                            "owned_by": "openai",
+                            "pipe": {"type": pipe.type},
+                        }
+                    )
+        else:
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": {"type": "pipe"},
+                }
+            )
+
+    return pipe_models

+ 154 - 61
backend/main.py

@@ -15,6 +15,7 @@ import uuid
 import inspect
 import asyncio
 
+from fastapi.concurrency import run_in_threadpool
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import JSONResponse
@@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
 
 
 from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
 
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.tools import Tools
@@ -66,7 +67,11 @@ from utils.task import (
     search_query_generation_template,
     tools_function_calling_generation_template,
 )
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)
 
 from apps.rag.utils import get_rag_context, rag_template
 
@@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             model = app.state.MODELS[model_id]
 
             # Check if the model has any filters
-            for filter_id in model["info"]["meta"].get("filterIds", []):
-                filter = Functions.get_function_by_id(filter_id)
-                if filter:
-                    if filter_id in webui_app.state.FUNCTIONS:
-                        function_module = webui_app.state.FUNCTIONS[filter_id]
-                    else:
-                        function_module, function_type = load_function_module_by_id(
-                            filter_id
-                        )
-                        webui_app.state.FUNCTIONS[filter_id] = function_module
+            if "info" in model and "meta" in model["info"]:
+                for filter_id in model["info"]["meta"].get("filterIds", []):
+                    filter = Functions.get_function_by_id(filter_id)
+                    if filter:
+                        if filter_id in webui_app.state.FUNCTIONS:
+                            function_module = webui_app.state.FUNCTIONS[filter_id]
+                        else:
+                            function_module, function_type = load_function_module_by_id(
+                                filter_id
+                            )
+                            webui_app.state.FUNCTIONS[filter_id] = function_module
 
-                    # Check if the function has a file_handler variable
-                    if getattr(function_module, "file_handler"):
-                        skip_files = True
+                        # Check if the function has a file_handler variable
+                        if getattr(function_module, "file_handler"):
+                            skip_files = True
 
-                    try:
-                        if hasattr(function_module, "inlet"):
-                            data = function_module.inlet(
-                                data,
-                                {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                },
+                        try:
+                            if hasattr(function_module, "inlet"):
+                                data = function_module.inlet(
+                                    data,
+                                    {
+                                        "id": user.id,
+                                        "email": user.email,
+                                        "name": user.name,
+                                        "role": user.role,
+                                    },
+                                )
+                        except Exception as e:
+                            print(f"Error: {e}")
+                            return JSONResponse(
+                                status_code=status.HTTP_400_BAD_REQUEST,
+                                content={"detail": str(e)},
                             )
-                    except Exception as e:
-                        print(f"Error: {e}")
-                        return JSONResponse(
-                            status_code=status.HTTP_400_BAD_REQUEST,
-                            content={"detail": str(e)},
-                        )
 
             # Set the task model
             task_model_id = data["model"]
@@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
     model = app.state.MODELS[model_id]
     print(model)
 
-    
+    pipe = model.get("pipe")
+    if pipe:
+
+        def job():
+            pipe_id = form_data["model"]
+            if "." in pipe_id:
+                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+            print(pipe_id)
+
+            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            if form_data["stream"]:
+
+                def stream_content():
+                    res = pipe(body=form_data)
+
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data["model"], res)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    if isinstance(res, Iterator):
+                        for line in res:
+                            if isinstance(line, BaseModel):
+                                line = line.model_dump_json()
+                                line = f"data: {line}"
+                            try:
+                                line = line.decode("utf-8")
+                            except:
+                                pass
+
+                            if line.startswith("data:"):
+                                yield f"{line}\n\n"
+                            else:
+                                line = stream_message_template(form_data["model"], line)
+                                yield f"data: {json.dumps(line)}\n\n"
+
+                    if isinstance(res, str) or isinstance(res, Generator):
+                        finish_message = {
+                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": form_data["model"],
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {},
+                                    "logprobs": None,
+                                    "finish_reason": "stop",
+                                }
+                            ],
+                        }
+
+                        yield f"data: {json.dumps(finish_message)}\n\n"
+                        yield f"data: [DONE]"
 
-    if model.get('pipe') == True:
-        print('hi')
-    
-    
-    
+                return StreamingResponse(
+                    stream_content(), media_type="text/event-stream"
+                )
+            else:
+                res = pipe(body=form_data)
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            message = f"{message}{stream}"
+
+                    return {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await run_in_threadpool(job)
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
@@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                 pass
 
     # Check if the model has any filters
-    for filter_id in model["info"]["meta"].get("filterIds", []):
-        filter = Functions.get_function_by_id(filter_id)
-        if filter:
-            if filter_id in webui_app.state.FUNCTIONS:
-                function_module = webui_app.state.FUNCTIONS[filter_id]
-            else:
-                function_module, function_type = load_function_module_by_id(filter_id)
-                webui_app.state.FUNCTIONS[filter_id] = function_module
+    if "info" in model and "meta" in model["info"]:
+        for filter_id in model["info"]["meta"].get("filterIds", []):
+            filter = Functions.get_function_by_id(filter_id)
+            if filter:
+                if filter_id in webui_app.state.FUNCTIONS:
+                    function_module = webui_app.state.FUNCTIONS[filter_id]
+                else:
+                    function_module, function_type = load_function_module_by_id(
+                        filter_id
+                    )
+                    webui_app.state.FUNCTIONS[filter_id] = function_module
 
-            try:
-                if hasattr(function_module, "outlet"):
-                    data = function_module.outlet(
-                        data,
-                        {
-                            "id": user.id,
-                            "email": user.email,
-                            "name": user.name,
-                            "role": user.role,
-                        },
+                try:
+                    if hasattr(function_module, "outlet"):
+                        data = function_module.outlet(
+                            data,
+                            {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        )
+                except Exception as e:
+                    print(f"Error: {e}")
+                    return JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content={"detail": str(e)},
                     )
-            except Exception as e:
-                print(f"Error: {e}")
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
 
     return data
 

+ 19 - 0
backend/utils/misc.py

@@ -4,6 +4,8 @@ import json
 import re
 from datetime import timedelta
 from typing import Optional, List, Tuple
+import uuid
+import time
 
 
 def get_last_user_message(messages: List[dict]) -> str:
@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
+def stream_message_template(model: str, message: str):
+    return {
+        "id": f"{model}-{str(uuid.uuid4())}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": message},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters