Ver Fonte

feat: pipe function

Timothy J. Baek há 10 meses atrás
pai
commit
d6e4aef607
3 ficheiros alterados com 231 adições e 61 exclusões
  1. 58 0
      backend/apps/webui/main.py
  2. 154 61
      backend/main.py
  3. 19 0
      backend/utils/misc.py

+ 58 - 0
backend/apps/webui/main.py

@@ -15,6 +15,9 @@ from apps.webui.routers import (
     files,
     files,
     functions,
     functions,
 )
 )
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
+
 from config import (
 from config import (
     WEBUI_BUILD_HASH,
     WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
     SHOW_ADMIN_DETAILS,
@@ -97,3 +100,58 @@ async def get_status():
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
     }
     }
+
+
+async def get_pipe_models():
+    pipes = Functions.get_functions_by_type("pipe")
+    pipe_models = []
+
+    for pipe in pipes:
+        # Check if function is already loaded
+        if pipe.id not in app.state.FUNCTIONS:
+            function_module, function_type = load_function_module_by_id(pipe.id)
+            app.state.FUNCTIONS[pipe.id] = function_module
+        else:
+            function_module = app.state.FUNCTIONS[pipe.id]
+
+        # Check if function is a manifold
+        if hasattr(function_module, "type"):
+            if function_module.type == "manifold":
+                manifold_pipes = []
+
+                # Check if pipes is a function or a list
+                if callable(pipe.pipes):
+                    manifold_pipes = pipe.pipes()
+                else:
+                    manifold_pipes = pipe.pipes
+
+                for p in manifold_pipes:
+                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+                    manifold_pipe_name = p["name"]
+
+                    if hasattr(pipe, "name"):
+                        manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
+
+                    pipe_models.append(
+                        {
+                            "id": manifold_pipe_id,
+                            "name": manifold_pipe_name,
+                            "object": "model",
+                            "created": pipe.created_at,
+                            "owned_by": "openai",
+                            "pipe": {"type": pipe.type},
+                        }
+                    )
+        else:
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": {"type": "pipe"},
+                }
+            )
+
+    return pipe_models

+ 154 - 61
backend/main.py

@@ -15,6 +15,7 @@ import uuid
 import inspect
 import inspect
 import asyncio
 import asyncio
 
 
+from fastapi.concurrency import run_in_threadpool
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import JSONResponse
 from fastapi.responses import JSONResponse
@@ -46,7 +47,7 @@ from apps.webui.main import app as webui_app, get_pipe_models
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
 
 
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.tools import Tools
 from apps.webui.models.tools import Tools
@@ -66,7 +67,11 @@ from utils.task import (
     search_query_generation_template,
     search_query_generation_template,
     tools_function_calling_generation_template,
     tools_function_calling_generation_template,
 )
 )
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)
 
 
 from apps.rag.utils import get_rag_context, rag_template
 from apps.rag.utils import get_rag_context, rag_template
 
 
@@ -347,38 +352,39 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             model = app.state.MODELS[model_id]
             model = app.state.MODELS[model_id]
 
 
             # Check if the model has any filters
             # Check if the model has any filters
-            for filter_id in model["info"]["meta"].get("filterIds", []):
-                filter = Functions.get_function_by_id(filter_id)
-                if filter:
-                    if filter_id in webui_app.state.FUNCTIONS:
-                        function_module = webui_app.state.FUNCTIONS[filter_id]
-                    else:
-                        function_module, function_type = load_function_module_by_id(
-                            filter_id
-                        )
-                        webui_app.state.FUNCTIONS[filter_id] = function_module
+            if "info" in model and "meta" in model["info"]:
+                for filter_id in model["info"]["meta"].get("filterIds", []):
+                    filter = Functions.get_function_by_id(filter_id)
+                    if filter:
+                        if filter_id in webui_app.state.FUNCTIONS:
+                            function_module = webui_app.state.FUNCTIONS[filter_id]
+                        else:
+                            function_module, function_type = load_function_module_by_id(
+                                filter_id
+                            )
+                            webui_app.state.FUNCTIONS[filter_id] = function_module
 
 
-                    # Check if the function has a file_handler variable
-                    if getattr(function_module, "file_handler"):
-                        skip_files = True
+                        # Check if the function has a file_handler variable
+                        if getattr(function_module, "file_handler"):
+                            skip_files = True
 
 
-                    try:
-                        if hasattr(function_module, "inlet"):
-                            data = function_module.inlet(
-                                data,
-                                {
-                                    "id": user.id,
-                                    "email": user.email,
-                                    "name": user.name,
-                                    "role": user.role,
-                                },
+                        try:
+                            if hasattr(function_module, "inlet"):
+                                data = function_module.inlet(
+                                    data,
+                                    {
+                                        "id": user.id,
+                                        "email": user.email,
+                                        "name": user.name,
+                                        "role": user.role,
+                                    },
+                                )
+                        except Exception as e:
+                            print(f"Error: {e}")
+                            return JSONResponse(
+                                status_code=status.HTTP_400_BAD_REQUEST,
+                                content={"detail": str(e)},
                             )
                             )
-                    except Exception as e:
-                        print(f"Error: {e}")
-                        return JSONResponse(
-                            status_code=status.HTTP_400_BAD_REQUEST,
-                            content={"detail": str(e)},
-                        )
 
 
             # Set the task model
             # Set the task model
             task_model_id = data["model"]
             task_model_id = data["model"]
@@ -794,13 +800,97 @@ async def generate_chat_completions(form_data: dict, user=Depends(get_verified_u
     model = app.state.MODELS[model_id]
     model = app.state.MODELS[model_id]
     print(model)
     print(model)
 
 
-    
+    pipe = model.get("pipe")
+    if pipe:
+
+        def job():
+            pipe_id = form_data["model"]
+            if "." in pipe_id:
+                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+            print(pipe_id)
+
+            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            if form_data["stream"]:
+
+                def stream_content():
+                    res = pipe(body=form_data)
+
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data["model"], res)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    if isinstance(res, Iterator):
+                        for line in res:
+                            if isinstance(line, BaseModel):
+                                line = line.model_dump_json()
+                                line = f"data: {line}"
+                            try:
+                                line = line.decode("utf-8")
+                            except:
+                                pass
+
+                            if line.startswith("data:"):
+                                yield f"{line}\n\n"
+                            else:
+                                line = stream_message_template(form_data["model"], line)
+                                yield f"data: {json.dumps(line)}\n\n"
+
+                    if isinstance(res, str) or isinstance(res, Generator):
+                        finish_message = {
+                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": form_data["model"],
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {},
+                                    "logprobs": None,
+                                    "finish_reason": "stop",
+                                }
+                            ],
+                        }
+
+                        yield f"data: {json.dumps(finish_message)}\n\n"
+                        yield f"data: [DONE]"
 
 
-    if model.get('pipe') == True:
-        print('hi')
-    
-    
-    
+                return StreamingResponse(
+                    stream_content(), media_type="text/event-stream"
+                )
+            else:
+                res = pipe(body=form_data)
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            message = f"{message}{stream}"
+
+                    return {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await run_in_threadpool(job)
     if model["owned_by"] == "ollama":
     if model["owned_by"] == "ollama":
         return await generate_ollama_chat_completion(form_data, user=user)
         return await generate_ollama_chat_completion(form_data, user=user)
     else:
     else:
@@ -877,32 +967,35 @@ async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
                 pass
                 pass
 
 
     # Check if the model has any filters
     # Check if the model has any filters
-    for filter_id in model["info"]["meta"].get("filterIds", []):
-        filter = Functions.get_function_by_id(filter_id)
-        if filter:
-            if filter_id in webui_app.state.FUNCTIONS:
-                function_module = webui_app.state.FUNCTIONS[filter_id]
-            else:
-                function_module, function_type = load_function_module_by_id(filter_id)
-                webui_app.state.FUNCTIONS[filter_id] = function_module
+    if "info" in model and "meta" in model["info"]:
+        for filter_id in model["info"]["meta"].get("filterIds", []):
+            filter = Functions.get_function_by_id(filter_id)
+            if filter:
+                if filter_id in webui_app.state.FUNCTIONS:
+                    function_module = webui_app.state.FUNCTIONS[filter_id]
+                else:
+                    function_module, function_type = load_function_module_by_id(
+                        filter_id
+                    )
+                    webui_app.state.FUNCTIONS[filter_id] = function_module
 
 
-            try:
-                if hasattr(function_module, "outlet"):
-                    data = function_module.outlet(
-                        data,
-                        {
-                            "id": user.id,
-                            "email": user.email,
-                            "name": user.name,
-                            "role": user.role,
-                        },
+                try:
+                    if hasattr(function_module, "outlet"):
+                        data = function_module.outlet(
+                            data,
+                            {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        )
+                except Exception as e:
+                    print(f"Error: {e}")
+                    return JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content={"detail": str(e)},
                     )
                     )
-            except Exception as e:
-                print(f"Error: {e}")
-                return JSONResponse(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    content={"detail": str(e)},
-                )
 
 
     return data
     return data
 
 

+ 19 - 0
backend/utils/misc.py

@@ -4,6 +4,8 @@ import json
 import re
 import re
 from datetime import timedelta
 from datetime import timedelta
 from typing import Optional, List, Tuple
 from typing import Optional, List, Tuple
+import uuid
+import time
 
 
 
 
 def get_last_user_message(messages: List[dict]) -> str:
 def get_last_user_message(messages: List[dict]) -> str:
@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
     return messages
 
 
 
 
+def stream_message_template(model: str, message: str):
+    return {
+        "id": f"{model}-{str(uuid.uuid4())}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": message},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+
+
 def get_gravatar_url(email):
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # Trim leading and trailing whitespace from
     # an email address and force all characters
     # an email address and force all characters