Prechádzať zdrojové kódy

Merge pull request #3321 from open-webui/functions

feat: functions
Timothy Jaeryang Baek 10 mesiacov pred
rodič
commit
09a81eb225

+ 4 - 13
backend/apps/ollama/main.py

@@ -53,7 +53,7 @@ from config import (
     UPLOAD_DIR,
     AppConfig,
 )
-from utils.misc import calculate_sha256
+from utils.misc import calculate_sha256, add_or_update_system_message
 
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -834,18 +834,9 @@ async def generate_chat_completion(
             )
 
             if payload.get("messages"):
-                for message in payload["messages"]:
-                    if message.get("role") == "system":
-                        message["content"] = system + message["content"]
-                        break
-                else:
-                    payload["messages"].insert(
-                        0,
-                        {
-                            "role": "system",
-                            "content": system,
-                        },
-                    )
+                payload["messages"] = add_or_update_system_message(
+                    system, payload["messages"]
+                )
 
     if url_idx == None:
         if ":" not in payload["model"]:

+ 6 - 1
backend/apps/openai/main.py

@@ -432,7 +432,12 @@ async def generate_chat_completion(
     idx = model["urlIdx"]
 
     if "pipeline" in model and model.get("pipeline"):
-        payload["user"] = {"name": user.name, "id": user.id}
+        payload["user"] = {
+            "name": user.name,
+            "id": user.id,
+            "email": user.email,
+            "role": user.role,
+        }
 
     # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
     # This is a workaround until OpenAI fixes the issue with this model

+ 61 - 0
backend/apps/webui/internal/migrations/015_add_functions.py

@@ -0,0 +1,61 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    @migrator.create_model
+    class Function(pw.Model):
+        id = pw.TextField(unique=True)
+        user_id = pw.TextField()
+
+        name = pw.TextField()
+        type = pw.TextField()
+
+        content = pw.TextField()
+        meta = pw.TextField()
+
+        created_at = pw.BigIntegerField(null=False)
+        updated_at = pw.BigIntegerField(null=False)
+
+        class Meta:
+            table_name = "function"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_model("function")

+ 66 - 4
backend/apps/webui/main.py

@@ -13,7 +13,11 @@ from apps.webui.routers import (
     memories,
     utils,
     files,
+    functions,
 )
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
+
 from config import (
     WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
@@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
 
 app.state.MODELS = {}
 app.state.TOOLS = {}
-
+app.state.FUNCTIONS = {}
 
 app.add_middleware(
     CORSMiddleware,
@@ -70,19 +74,22 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
+
+app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
-app.include_router(tools.router, prefix="/tools", tags=["tools"])
 app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
+
 app.include_router(memories.router, prefix="/memories", tags=["memories"])
+app.include_router(files.router, prefix="/files", tags=["files"])
+app.include_router(tools.router, prefix="/tools", tags=["tools"])
+app.include_router(functions.router, prefix="/functions", tags=["functions"])
 
-app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
-app.include_router(files.router, prefix="/files", tags=["files"])
 
 
 @app.get("/")
@@ -93,3 +100,58 @@ async def get_status():
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
     }
+
+
+async def get_pipe_models():
+    pipes = Functions.get_functions_by_type("pipe")
+    pipe_models = []
+
+    for pipe in pipes:
+        # Check if function is already loaded
+        if pipe.id not in app.state.FUNCTIONS:
+            function_module, function_type = load_function_module_by_id(pipe.id)
+            app.state.FUNCTIONS[pipe.id] = function_module
+        else:
+            function_module = app.state.FUNCTIONS[pipe.id]
+
+        # Check if function is a manifold
+        if hasattr(function_module, "type"):
+            if function_module.type == "manifold":
+                manifold_pipes = []
+
+                # Check if pipes is a function or a list
+                if callable(function_module.pipes):
+                    manifold_pipes = function_module.pipes()
+                else:
+                    manifold_pipes = function_module.pipes
+
+                for p in manifold_pipes:
+                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+                    manifold_pipe_name = p["name"]
+
+                    if hasattr(function_module, "name"):
+                        manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
+
+                    pipe_models.append(
+                        {
+                            "id": manifold_pipe_id,
+                            "name": manifold_pipe_name,
+                            "object": "model",
+                            "created": pipe.created_at,
+                            "owned_by": "openai",
+                            "pipe": {"type": pipe.type},
+                        }
+                    )
+        else:
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": {"type": "pipe"},
+                }
+            )
+
+    return pipe_models

+ 5 - 4
backend/apps/webui/models/functions.py

@@ -55,6 +55,7 @@ class FunctionModel(BaseModel):
 class FunctionResponse(BaseModel):
     id: str
     user_id: str
+    type: str
     name: str
     meta: FunctionMeta
     updated_at: int  # timestamp in epoch
@@ -64,23 +65,23 @@ class FunctionResponse(BaseModel):
 class FunctionForm(BaseModel):
     id: str
     name: str
-    type: str
     content: str
     meta: FunctionMeta
 
 
-class ToolsTable:
+class FunctionsTable:
     def __init__(self, db):
         self.db = db
         self.db.create_tables([Function])
 
     def insert_new_function(
-        self, user_id: str, form_data: FunctionForm
+        self, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
         function = FunctionModel(
             **{
                 **form_data.model_dump(),
                 "user_id": user_id,
+                "type": type,
                 "updated_at": int(time.time()),
                 "created_at": int(time.time()),
             }
@@ -137,4 +138,4 @@ class ToolsTable:
             return False
 
 
-Tools = ToolsTable(DB)
+Functions = FunctionsTable(DB)

+ 180 - 0
backend/apps/webui/routers/functions.py

@@ -0,0 +1,180 @@
+from fastapi import Depends, FastAPI, HTTPException, status, Request
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.functions import (
+    Functions,
+    FunctionForm,
+    FunctionModel,
+    FunctionResponse,
+)
+from apps.webui.utils import load_function_module_by_id
+from utils.utils import get_verified_user, get_admin_user
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+from pathlib import Path
+
+from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
+
+
+router = APIRouter()
+
+############################
+# GetFunctions
+############################
+
+
+@router.get("/", response_model=List[FunctionResponse])
+async def get_functions(user=Depends(get_verified_user)):
+    return Functions.get_functions()
+
+
+############################
+# ExportFunctions
+############################
+
+
+@router.get("/export", response_model=List[FunctionModel])
+async def get_functions(user=Depends(get_admin_user)):
+    return Functions.get_functions()
+
+
+############################
+# CreateNewFunction
+############################
+
+
+@router.post("/create", response_model=Optional[FunctionResponse])
+async def create_new_function(
+    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+    if not form_data.id.isidentifier():
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Only alphanumeric characters and underscores are allowed in the id",
+        )
+
+    form_data.id = form_data.id.lower()
+
+    function = Functions.get_function_by_id(form_data.id)
+    if function == None:
+        function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
+        try:
+            with open(function_path, "w") as function_file:
+                function_file.write(form_data.content)
+
+            function_module, function_type = load_function_module_by_id(form_data.id)
+
+            FUNCTIONS = request.app.state.FUNCTIONS
+            FUNCTIONS[form_data.id] = function_module
+
+            function = Functions.insert_new_function(user.id, function_type, form_data)
+
+            function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
+            function_cache_dir.mkdir(parents=True, exist_ok=True)
+
+            if function:
+                return function
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
+                )
+        except Exception as e:
+            print(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ID_TAKEN,
+        )
+
+
+############################
+# GetFunctionById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[FunctionModel])
+async def get_function_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
+
+    if function:
+        return function
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateFunctionById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
+async def update_toolkit_by_id(
+    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+    function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+
+    try:
+        with open(function_path, "w") as function_file:
+            function_file.write(form_data.content)
+
+        function_module, function_type = load_function_module_by_id(id)
+
+        FUNCTIONS = request.app.state.FUNCTIONS
+        FUNCTIONS[id] = function_module
+
+        updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
+        print(updated)
+
+        function = Functions.update_function_by_id(id, updated)
+
+        if function:
+            return function
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+            )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# DeleteFunctionById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_function_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    result = Functions.delete_function_by_id(id)
+
+    if result:
+        FUNCTIONS = request.app.state.FUNCTIONS
+        if id in FUNCTIONS:
+            del FUNCTIONS[id]
+
+        # delete the function file
+        function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+        os.remove(function_path)
+
+    return result

+ 23 - 1
backend/apps/webui/utils.py

@@ -1,7 +1,7 @@
 from importlib import util
 import os
 
-from config import TOOLS_DIR
+from config import TOOLS_DIR, FUNCTIONS_DIR
 
 
 def load_toolkit_module_by_id(toolkit_id):
@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
         # Move the file to the error folder
         os.rename(toolkit_path, f"{toolkit_path}.error")
         raise e
+
+
+def load_function_module_by_id(function_id):
+    function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
+
+    spec = util.spec_from_file_location(function_id, function_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Pipe"):
+            return module.Pipe(), "pipe"
+        elif hasattr(module, "Filter"):
+            return module.Filter(), "filter"
+        else:
+            raise Exception("No Function class found")
+    except Exception as e:
+        print(f"Error loading module: {function_id}")
+        # Move the file to the error folder
+        os.rename(function_path, f"{function_path}.error")
+        raise e

+ 8 - 0
backend/config.py

@@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
 Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
 
 
+####################################
+# Functions DIR
+####################################
+
+FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
+Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # LITELLM_CONFIG
 ####################################

+ 342 - 124
backend/main.py

@@ -15,6 +15,7 @@ import uuid
 import inspect
 import asyncio
 
+from fastapi.concurrency import run_in_threadpool
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import JSONResponse
@@ -42,15 +43,17 @@ from apps.openai.main import (
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app
+from apps.webui.main import app as webui_app, get_pipe_models
 
 
 from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
 
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.tools import Tools
-from apps.webui.utils import load_toolkit_module_by_id
+from apps.webui.models.functions import Functions
+
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
 
 from utils.utils import (
@@ -64,7 +67,11 @@ from utils.task import (
     search_query_generation_template,
     tools_function_calling_generation_template,
 )
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)
 
 from apps.rag.utils import get_rag_context, rag_template
 
@@ -170,6 +177,13 @@ app.state.MODELS = {}
 origins = ["*"]
 
 
+##################################
+#
+# ChatCompletion Middleware
+#
+##################################
+
+
 async def get_function_call_response(
     messages, files, tool_id, template, task_model_id, user
 ):
@@ -309,41 +323,72 @@ async def get_function_call_response(
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        return_citations = False
+        data_items = []
 
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
+        if request.method == "POST" and any(
+            endpoint in request.url.path
+            for endpoint in ["/ollama/api/chat", "/chat/completions"]
         ):
             log.debug(f"request.url.path: {request.url.path}")
 
             # Read the original request body
             body = await request.body()
-            # Decode body to string
             body_str = body.decode("utf-8")
-            # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
             user = get_current_user(
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
             )
+            # Flag to skip RAG completions if file_handler is present in tools/functions
+            skip_files = False
 
-            # Remove the citations from the body
-            return_citations = data.get("citations", False)
-            if "citations" in data:
-                del data["citations"]
-
-            # Set the task model
-            task_model_id = data["model"]
-            if task_model_id not in app.state.MODELS:
+            model_id = data["model"]
+            if model_id not in app.state.MODELS:
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     detail="Model not found",
                 )
+            model = app.state.MODELS[model_id]
+
+            # Check if the model has any filters
+            if "info" in model and "meta" in model["info"]:
+                for filter_id in model["info"]["meta"].get("filterIds", []):
+                    filter = Functions.get_function_by_id(filter_id)
+                    if filter:
+                        if filter_id in webui_app.state.FUNCTIONS:
+                            function_module = webui_app.state.FUNCTIONS[filter_id]
+                        else:
+                            function_module, function_type = load_function_module_by_id(
+                                filter_id
+                            )
+                            webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                        # Check if the function has a file_handler variable
+                        if getattr(function_module, "file_handler"):
+                            skip_files = True
 
-            # Check if the user has a custom task model
-            # If the user has a custom task model, use that model
+                        try:
+                            if hasattr(function_module, "inlet"):
+                                data = function_module.inlet(
+                                    data,
+                                    {
+                                        "id": user.id,
+                                        "email": user.email,
+                                        "name": user.name,
+                                        "role": user.role,
+                                    },
+                                )
+                        except Exception as e:
+                            print(f"Error: {e}")
+                            return JSONResponse(
+                                status_code=status.HTTP_400_BAD_REQUEST,
+                                content={"detail": str(e)},
+                            )
+
+            # Set the task model
+            task_model_id = data["model"]
+            # Check if the user has a custom task model and use that model
             if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
                 if (
                     app.state.config.TASK_MODEL
@@ -361,8 +406,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             context = ""
 
             # If tool_ids field is present, call the functions
-
-            skip_files = False
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
@@ -408,18 +451,22 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                         context += ("\n" if context != "" else "") + rag_context
 
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
-                else:
-                    return_citations = False
+
+                    if citations and data.get("citations"):
+                        data_items.append({"citations": citations})
 
                 del data["files"]
 
+            if data.get("citations"):
+                del data["citations"]
+
             if context != "":
                 system_prompt = rag_template(
                     rag_app.state.config.RAG_TEMPLATE, context, prompt
                 )
                 print(system_prompt)
                 data["messages"] = add_or_update_system_message(
-                    f"\n{system_prompt}", data["messages"]
+                    system_prompt, data["messages"]
                 )
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -435,40 +482,51 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ],
             ]
 
-        response = await call_next(request)
-
-        if return_citations:
-            # Inject the citations into the response
+            response = await call_next(request)
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
                 if "text/event-stream" in content_type:
                     return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        self.openai_stream_wrapper(response.body_iterator, data_items),
                     )
                 if "application/x-ndjson" in content_type:
                     return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
+            else:
+                return response
 
+        # If it's not a chat completion request, just pass it through
+        response = await call_next(request)
         return response
 
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
 
-    async def openai_stream_wrapper(self, original_generator, citations):
-        yield f"data: {json.dumps({'citations': citations})}\n\n"
+    async def openai_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"data: {json.dumps(item)}\n\n"
+
         async for data in original_generator:
             yield data
 
-    async def ollama_stream_wrapper(self, original_generator, citations):
-        yield f"{json.dumps({'citations': citations})}\n"
+    async def ollama_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"{json.dumps(item)}\n"
+
         async for data in original_generator:
             yield data
 
 
 app.add_middleware(ChatCompletionMiddleware)
 
+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
 
 def filter_pipeline(payload, user):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
@@ -628,7 +686,6 @@ async def update_embedding_function(request: Request, call_next):
 
 app.mount("/ws", socket_app)
 
-
 app.mount("/ollama", ollama_app)
 app.mount("/openai", openai_app)
 
@@ -642,17 +699,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 async def get_all_models():
+    pipe_models = []
     openai_models = []
     ollama_models = []
 
+    pipe_models = await get_pipe_models()
+
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
-
         openai_models = openai_models["data"]
 
     if app.state.config.ENABLE_OLLAMA_API:
         ollama_models = await get_ollama_models()
-
         ollama_models = [
             {
                 "id": model["model"],
@@ -665,9 +723,9 @@ async def get_all_models():
             for model in ollama_models["models"]
         ]
 
-    models = openai_models + ollama_models
-    custom_models = Models.get_all_models()
+    models = pipe_models + openai_models + ollama_models
 
+    custom_models = Models.get_all_models()
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
             for model in models:
@@ -730,6 +788,234 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
 
 
+@app.post("/api/chat/completions")
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    model = app.state.MODELS[model_id]
+    print(model)
+
+    pipe = model.get("pipe")
+    if pipe:
+        form_data["user"] = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        def job():
+            pipe_id = form_data["model"]
+            if "." in pipe_id:
+                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+            print(pipe_id)
+
+            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            if form_data["stream"]:
+
+                def stream_content():
+                    res = pipe(body=form_data)
+
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data["model"], res)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    if isinstance(res, Iterator):
+                        for line in res:
+                            if isinstance(line, BaseModel):
+                                line = line.model_dump_json()
+                                line = f"data: {line}"
+                            try:
+                                line = line.decode("utf-8")
+                            except:
+                                pass
+
+                            if line.startswith("data:"):
+                                yield f"{line}\n\n"
+                            else:
+                                line = stream_message_template(form_data["model"], line)
+                                yield f"data: {json.dumps(line)}\n\n"
+
+                    if isinstance(res, str) or isinstance(res, Generator):
+                        finish_message = {
+                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": form_data["model"],
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {},
+                                    "logprobs": None,
+                                    "finish_reason": "stop",
+                                }
+                            ],
+                        }
+
+                        yield f"data: {json.dumps(finish_message)}\n\n"
+                        yield f"data: [DONE]"
+
+                return StreamingResponse(
+                    stream_content(), media_type="text/event-stream"
+                )
+            else:
+                res = pipe(body=form_data)
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            message = f"{message}{stream}"
+
+                    return {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await run_in_threadpool(job)
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(form_data, user=user)
+    else:
+        return await generate_openai_chat_completion(form_data, user=user)
+
+
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    data = form_data
+    model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    # Check if the model has any filters
+    if "info" in model and "meta" in model["info"]:
+        for filter_id in model["info"]["meta"].get("filterIds", []):
+            filter = Functions.get_function_by_id(filter_id)
+            if filter:
+                if filter_id in webui_app.state.FUNCTIONS:
+                    function_module = webui_app.state.FUNCTIONS[filter_id]
+                else:
+                    function_module, function_type = load_function_module_by_id(
+                        filter_id
+                    )
+                    webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                try:
+                    if hasattr(function_module, "outlet"):
+                        data = function_module.outlet(
+                            data,
+                            {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        )
+                except Exception as e:
+                    print(f"Error: {e}")
+                    return JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content={"detail": str(e)},
+                    )
+
+    return data
+
+
+##################################
+#
+# Task Endpoints
+#
+##################################
+
+
+# TODO: Refactor task API endpoints below into a separate file
+
+
 @app.get("/api/task/config")
 async def get_task_config(user=Depends(get_verified_user)):
     return {
@@ -1015,92 +1301,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
         )
 
 
-@app.post("/api/chat/completions")
-async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
+##################################
+#
+# Pipelines Endpoints
+#
+##################################
 
-    model = app.state.MODELS[model_id]
-    print(model)
 
-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(form_data, user=user)
-    else:
-        return await generate_openai_chat_completion(form_data, user=user)
-
-
-@app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
-    data = form_data
-    model_id = data["model"]
-
-    filters = [
-        model
-        for model in app.state.MODELS.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
-    print(model_id)
-
-    if model_id in app.state.MODELS:
-        model = app.state.MODELS[model_id]
-        if "pipeline" in model:
-            sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {"id": user.id, "name": user.name, "role": user.role},
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except:
-                    pass
-
-            else:
-                pass
-
-    return data
+# TODO: Refactor pipelines API endpoints below into a separate file
 
 
 @app.get("/api/pipelines/list")
@@ -1423,6 +1631,13 @@ async def update_pipeline_valves(
         )
 
 
+##################################
+#
+# Config Endpoints
+#
+##################################
+
+
 @app.get("/api/config")
 async def get_app_config():
     # Checking and Handling the Absence of 'ui' in CONFIG_DATA
@@ -1486,6 +1701,9 @@ async def update_model_filter_config(
     }
 
 
+# TODO: webhook endpoint should be under config endpoints
+
+
 @app.get("/api/webhook")
 async def get_webhook_url(user=Depends(get_admin_user)):
     return {

+ 19 - 0
backend/utils/misc.py

@@ -4,6 +4,8 @@ import json
 import re
 from datetime import timedelta
 from typing import Optional, List, Tuple
+import uuid
+import time
 
 
 def get_last_user_message(messages: List[dict]) -> str:
@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
 
 
+def stream_message_template(model: str, message: str):
+    return {
+        "id": f"{model}-{str(uuid.uuid4())}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": message},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+
+
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # an email address and force all characters

+ 193 - 0
src/lib/apis/functions/index.ts

@@ -0,0 +1,193 @@
+import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+export const createNewFunction = async (token: string, func: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...func
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFunctions = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const exportFunctions = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFunctionById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateFunctionById = async (token: string, id: string, func: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...func
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const deleteFunctionById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, {
+		method: 'DELETE',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 14 - 4
src/lib/components/chat/Chat.svelte

@@ -278,7 +278,9 @@
 			})),
 			chat_id: $chatId
 		}).catch((error) => {
-			console.error(error);
+			toast.error(error);
+			messages.at(-1).error = { content: error };
+
 			return null;
 		});
 
@@ -323,6 +325,13 @@
 		} else if (messages.length != 0 && messages.at(-1).done != true) {
 			// Response not done
 			console.log('wait');
+		} else if (messages.length != 0 && messages.at(-1).error) {
+			// Error in response
+			toast.error(
+				$i18n.t(
+					`Oops! There was an error in the previous response. Please try again or contact admin.`
+				)
+			);
 		} else if (
 			files.length > 0 &&
 			files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0
@@ -630,7 +639,7 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0,
+			citations: files.length > 0 ? true : undefined,
 			chat_id: $chatId
 		});
 
@@ -928,10 +937,11 @@
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
-					citations: files.length > 0,
+					citations: files.length > 0 ? true : undefined,
+
 					chat_id: $chatId
 				},
-				`${OPENAI_API_BASE_URL}`
+				`${WEBUI_BASE_URL}/api`
 			);
 
 			// Wait until history/message have been updated

+ 73 - 54
src/lib/components/workspace/Functions.svelte

@@ -3,25 +3,27 @@
 	import fileSaver from 'file-saver';
 	const { saveAs } = fileSaver;
 
+	import { WEBUI_NAME, functions, models } from '$lib/stores';
 	import { onMount, getContext } from 'svelte';
-	import { WEBUI_NAME, prompts, tools } from '$lib/stores';
 	import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
 
 	import { goto } from '$app/navigation';
 	import {
-		createNewTool,
-		deleteToolById,
-		exportTools,
-		getToolById,
-		getTools
-	} from '$lib/apis/tools';
+		createNewFunction,
+		deleteFunctionById,
+		exportFunctions,
+		getFunctionById,
+		getFunctions
+	} from '$lib/apis/functions';
+
 	import ArrowDownTray from '../icons/ArrowDownTray.svelte';
 	import Tooltip from '../common/Tooltip.svelte';
 	import ConfirmDialog from '../common/ConfirmDialog.svelte';
+	import { getModels } from '$lib/apis';
 
 	const i18n = getContext('i18n');
 
-	let toolsImportInputElement: HTMLInputElement;
+	let functionsImportInputElement: HTMLInputElement;
 	let importFiles;
 
 	let showConfirm = false;
@@ -64,7 +66,7 @@
 	<div>
 		<a
 			class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1"
-			href="/workspace/tools/create"
+			href="/workspace/functions/create"
 		>
 			<svg
 				xmlns="http://www.w3.org/2000/svg"
@@ -82,30 +84,40 @@
 <hr class=" dark:border-gray-850 my-2.5" />
 
 <div class="my-3 mb-5">
-	{#each $tools.filter((t) => query === '' || t.name
+	{#each $functions.filter((f) => query === '' || f.name
 				.toLowerCase()
-				.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool}
+				.includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
 		<button
 			class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
 			type="button"
 			on:click={() => {
-				goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`);
+				goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
 			}}
 		>
 			<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
 				<a
-					href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
+					href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
 					class="flex items-center text-left"
 				>
-					<div class=" flex-1 self-center pl-5">
+					<div class=" flex-1 self-center pl-1">
 						<div class=" font-semibold flex items-center gap-1.5">
+							<div
+								class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
+							>
+								{func.type}
+							</div>
+
 							<div>
-								{tool.name}
+								{func.name}
 							</div>
-							<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
 						</div>
-						<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
-							{tool.meta.description}
+
+						<div class="flex gap-1.5 px-1">
+							<div class=" text-gray-500 text-xs font-medium">{func.id}</div>
+
+							<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
+								{func.meta.description}
+							</div>
 						</div>
 					</div>
 				</a>
@@ -115,7 +127,7 @@
 					<a
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						type="button"
-						href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
+						href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
 					>
 						<svg
 							xmlns="http://www.w3.org/2000/svg"
@@ -141,18 +153,20 @@
 						on:click={async (e) => {
 							e.stopPropagation();
 
-							const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
-								toast.error(error);
-								return null;
-							});
-
-							if (_tool) {
-								sessionStorage.tool = JSON.stringify({
-									..._tool,
-									id: `${_tool.id}_clone`,
-									name: `${_tool.name} (Clone)`
+							const _function = await getFunctionById(localStorage.token, func.id).catch(
+								(error) => {
+									toast.error(error);
+									return null;
+								}
+							);
+
+							if (_function) {
+								sessionStorage.function = JSON.stringify({
+									..._function,
+									id: `${_function.id}_clone`,
+									name: `${_function.name} (Clone)`
 								});
-								goto('/workspace/tools/create');
+								goto('/workspace/functions/create');
 							}
 						}}
 					>
@@ -180,16 +194,18 @@
 						on:click={async (e) => {
 							e.stopPropagation();
 
-							const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
-								toast.error(error);
-								return null;
-							});
+							const _function = await getFunctionById(localStorage.token, func.id).catch(
+								(error) => {
+									toast.error(error);
+									return null;
+								}
+							);
 
-							if (_tool) {
-								let blob = new Blob([JSON.stringify([_tool])], {
+							if (_function) {
+								let blob = new Blob([JSON.stringify([_function])], {
 									type: 'application/json'
 								});
-								saveAs(blob, `tool-${_tool.id}-export-${Date.now()}.json`);
+								saveAs(blob, `function-${_function.id}-export-${Date.now()}.json`);
 							}
 						}}
 					>
@@ -204,14 +220,16 @@
 						on:click={async (e) => {
 							e.stopPropagation();
 
-							const res = await deleteToolById(localStorage.token, tool.id).catch((error) => {
+							const res = await deleteFunctionById(localStorage.token, func.id).catch((error) => {
 								toast.error(error);
 								return null;
 							});
 
 							if (res) {
-								toast.success('Tool deleted successfully');
-								tools.set(await getTools(localStorage.token));
+								toast.success('Function deleted successfully');
+
+								functions.set(await getFunctions(localStorage.token));
+								models.set(await getModels(localStorage.token));
 							}
 						}}
 					>
@@ -246,7 +264,7 @@
 	<div class="flex space-x-2">
 		<input
 			id="documents-import-input"
-			bind:this={toolsImportInputElement}
+			bind:this={functionsImportInputElement}
 			bind:files={importFiles}
 			type="file"
 			accept=".json"
@@ -260,7 +278,7 @@
 		<button
 			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 			on:click={() => {
-				toolsImportInputElement.click();
+				functionsImportInputElement.click();
 			}}
 		>
 			<div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div>
@@ -284,16 +302,16 @@
 		<button
 			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 			on:click={async () => {
-				const _tools = await exportTools(localStorage.token).catch((error) => {
+				const _functions = await exportFunctions(localStorage.token).catch((error) => {
 					toast.error(error);
 					return null;
 				});
 
-				if (_tools) {
-					let blob = new Blob([JSON.stringify(_tools)], {
+				if (_functions) {
+					let blob = new Blob([JSON.stringify(_functions)], {
 						type: 'application/json'
 					});
-					saveAs(blob, `tools-export-${Date.now()}.json`);
+					saveAs(blob, `functions-export-${Date.now()}.json`);
 				}
 			}}
 		>
@@ -322,18 +340,19 @@
 	on:confirm={() => {
 		const reader = new FileReader();
 		reader.onload = async (event) => {
-			const _tools = JSON.parse(event.target.result);
-			console.log(_tools);
+			const _functions = JSON.parse(event.target.result);
+			console.log(_functions);
 
-			for (const tool of _tools) {
-				const res = await createNewTool(localStorage.token, tool).catch((error) => {
+			for (const func of _functions) {
+				const res = await createNewFunction(localStorage.token, func).catch((error) => {
 					toast.error(error);
 					return null;
 				});
 			}
 
-			toast.success('Tool imported successfully');
-			tools.set(await getTools(localStorage.token));
+			toast.success('Functions imported successfully');
+			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 		};
 
 		reader.readAsText(importFiles[0]);
@@ -344,8 +363,8 @@
 			<div>Please carefully review the following warnings:</div>
 
 			<ul class=" mt-1 list-disc pl-4 text-xs">
-				<li>Tools have a function calling system that allows arbitrary code execution.</li>
-				<li>Do not install tools from sources you do not fully trust.</li>
+				<li>Functions allow arbitrary code execution.</li>
+				<li>Do not install functions from sources you do not fully trust.</li>
 			</ul>
 		</div>
 

+ 235 - 0
src/lib/components/workspace/Functions/FunctionEditor.svelte

@@ -0,0 +1,235 @@
+<script>
+	import { getContext, createEventDispatcher, onMount } from 'svelte';
+	import { goto } from '$app/navigation';
+
+	const dispatch = createEventDispatcher();
+	const i18n = getContext('i18n');
+
+	import CodeEditor from '$lib/components/common/CodeEditor.svelte';
+	import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
+
+	let formElement = null;
+	let loading = false;
+	let showConfirm = false;
+
+	export let edit = false;
+	export let clone = false;
+
+	export let id = '';
+	export let name = '';
+	export let meta = {
+		description: ''
+	};
+	export let content = '';
+
+	$: if (name && !edit && !clone) {
+		id = name.replace(/\s+/g, '_').toLowerCase();
+	}
+
+	let codeEditor;
+	let boilerplate = `from pydantic import BaseModel
+from typing import Optional
+
+
+class Filter:
+    class Valves(BaseModel):
+        max_turns: int = 4
+        pass
+
+    def __init__(self):
+        # Indicates custom file handling logic. This flag helps disengage default routines in favor of custom
+        # implementations, informing the WebUI to defer file-related operations to designated methods within this class.
+        # Alternatively, you can remove the files directly from the body in from the inlet hook
+        self.file_handler = True
+
+        # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
+        # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
+        self.valves = self.Valves(**{"max_turns": 2})
+        pass
+
+    def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
+        # Modify the request body or validate it before processing by the chat completion API.
+        # This function is the pre-processor for the API where various checks on the input can be performed.
+        # It can also modify the request before sending it to the API.
+        print(f"inlet:{__name__}")
+        print(f"inlet:body:{body}")
+        print(f"inlet:user:{user}")
+
+        if user.get("role", "admin") in ["user", "admin"]:
+            messages = body.get("messages", [])
+            if len(messages) > self.valves.max_turns:
+                raise Exception(
+                    f"Conversation turn limit exceeded. Max turns: {self.valves.max_turns}"
+                )
+
+        return body
+
+    def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
+        # Modify or analyze the response body after processing by the API.
+        # This function is the post-processor for the API, which can be used to modify the response
+        # or perform additional checks and analytics.
+        print(f"outlet:{__name__}")
+        print(f"outlet:body:{body}")
+        print(f"outlet:user:{user}")
+
+        messages = [
+            {
+                **message,
+                "content": f"{message['content']} - @@Modified from Filter Outlet",
+            }
+            for message in body.get("messages", [])
+        ]
+
+        return {"messages": messages}
+
+`;
+
+	const saveHandler = async () => {
+		loading = true;
+		dispatch('save', {
+			id,
+			name,
+			meta,
+			content
+		});
+	};
+
+	const submitHandler = async () => {
+		if (codeEditor) {
+			const res = await codeEditor.formatPythonCodeHandler();
+
+			if (res) {
+				console.log('Code formatted successfully');
+				saveHandler();
+			}
+		}
+	};
+</script>
+
+<div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
+	<div class="mx-auto w-full md:px-0 h-full">
+		<form
+			bind:this={formElement}
+			class=" flex flex-col max-h-[100dvh] h-full"
+			on:submit|preventDefault={() => {
+				if (edit) {
+					submitHandler();
+				} else {
+					showConfirm = true;
+				}
+			}}
+		>
+			<div class="mb-2.5">
+				<button
+					class="flex space-x-1"
+					on:click={() => {
+						goto('/workspace/functions');
+					}}
+					type="button"
+				>
+					<div class=" self-center">
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							viewBox="0 0 20 20"
+							fill="currentColor"
+							class="w-4 h-4"
+						>
+							<path
+								fill-rule="evenodd"
+								d="M17 10a.75.75 0 01-.75.75H5.612l4.158 3.96a.75.75 0 11-1.04 1.08l-5.5-5.25a.75.75 0 010-1.08l5.5-5.25a.75.75 0 111.04 1.08L5.612 9.25H16.25A.75.75 0 0117 10z"
+								clip-rule="evenodd"
+							/>
+						</svg>
+					</div>
+					<div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div>
+				</button>
+			</div>
+
+			<div class="flex flex-col flex-1 overflow-auto h-0 rounded-lg">
+				<div class="w-full mb-2 flex flex-col gap-1.5">
+					<div class="flex gap-2 w-full">
+						<input
+							class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+							type="text"
+							placeholder="Function Name (e.g. My Filter)"
+							bind:value={name}
+							required
+						/>
+
+						<input
+							class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+							type="text"
+							placeholder="Function ID (e.g. my_filter)"
+							bind:value={id}
+							required
+							disabled={edit}
+						/>
+					</div>
+					<input
+						class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+						type="text"
+						placeholder="Function Description (e.g. A filter to remove profanity from text)"
+						bind:value={meta.description}
+						required
+					/>
+				</div>
+
+				<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
+					<CodeEditor
+						bind:value={content}
+						bind:this={codeEditor}
+						{boilerplate}
+						on:save={() => {
+							if (formElement) {
+								formElement.requestSubmit();
+							}
+						}}
+					/>
+				</div>
+
+				<div class="pb-3 flex justify-between">
+					<div class="flex-1 pr-3">
+						<div class="text-xs text-gray-500 line-clamp-2">
+							<span class=" font-semibold dark:text-gray-200">Warning:</span> Functions allow
+							arbitrary code execution <br />—
+							<span class=" font-medium dark:text-gray-400"
+								>don't install random functions from sources you don't trust.</span
+							>
+						</div>
+					</div>
+
+					<button
+						class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
+						type="submit"
+					>
+						{$i18n.t('Save')}
+					</button>
+				</div>
+			</div>
+		</form>
+	</div>
+</div>
+
+<ConfirmDialog
+	bind:show={showConfirm}
+	on:confirm={() => {
+		submitHandler();
+	}}
+>
+	<div class="text-sm text-gray-500">
+		<div class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-lg px-4 py-3">
+			<div>Please carefully review the following warnings:</div>
+
+			<ul class=" mt-1 list-disc pl-4 text-xs">
+				<li>Functions allow arbitrary code execution.</li>
+				<li>Do not install functions from sources you do not fully trust.</li>
+			</ul>
+		</div>
+
+		<div class="my-3">
+			I acknowledge that I have read and I understand the implications of my action. I am aware of
+			the risks associated with executing arbitrary code and I have verified the trustworthiness of
+			the source.
+		</div>
+	</div>
+</ConfirmDialog>

+ 60 - 0
src/lib/components/workspace/Models/FiltersSelector.svelte

@@ -0,0 +1,60 @@
+<script lang="ts">
+	import { getContext, onMount } from 'svelte';
+	import Checkbox from '$lib/components/common/Checkbox.svelte';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
+
+	const i18n = getContext('i18n');
+
+	export let filters = [];
+	export let selectedFilterIds = [];
+
+	let _filters = {};
+
+	onMount(() => {
+		_filters = filters.reduce((acc, filter) => {
+			acc[filter.id] = {
+				...filter,
+				selected: selectedFilterIds.includes(filter.id)
+			};
+
+			return acc;
+		}, {});
+	});
+</script>
+
+<div>
+	<div class="flex w-full justify-between mb-1">
+		<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
+	</div>
+
+	<div class=" text-xs dark:text-gray-500">
+		{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
+	</div>
+
+	<!-- TODO: Filer order matters -->
+	<div class="flex flex-col">
+		{#if filters.length > 0}
+			<div class=" flex items-center mt-2 flex-wrap">
+				{#each Object.keys(_filters) as filter, filterIdx}
+					<div class=" flex items-center gap-2 mr-3">
+						<div class="self-center flex items-center">
+							<Checkbox
+								state={_filters[filter].selected ? 'checked' : 'unchecked'}
+								on:change={(e) => {
+									_filters[filter].selected = e.detail === 'checked';
+									selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected);
+								}}
+							/>
+						</div>
+
+						<div class=" py-0.5 text-sm w-full capitalize font-medium">
+							<Tooltip content={_filters[filter].meta.description}>
+								{_filters[filter].name}
+							</Tooltip>
+						</div>
+					</div>
+				{/each}
+			</div>
+		{/if}
+	</div>
+</div>

+ 2 - 0
src/lib/stores/index.ts

@@ -27,7 +27,9 @@ export const tags = writable([]);
 export const models: Writable<Model[]> = writable([]);
 export const prompts: Writable<Prompt[]> = writable([]);
 export const documents: Writable<Document[]> = writable([]);
+
 export const tools = writable([]);
+export const functions = writable([]);
 
 export const banners: Writable<Banner[]> = writable([]);
 

+ 6 - 1
src/routes/(app)/workspace/+layout.svelte

@@ -1,11 +1,16 @@
 <script lang="ts">
 	import { onMount, getContext } from 'svelte';
 
-	import { WEBUI_NAME, showSidebar } from '$lib/stores';
+	import { WEBUI_NAME, showSidebar, functions } from '$lib/stores';
 	import MenuLines from '$lib/components/icons/MenuLines.svelte';
 	import { page } from '$app/stores';
+	import { getFunctions } from '$lib/apis/functions';
 
 	const i18n = getContext('i18n');
+
+	onMount(async () => {
+		functions.set(await getFunctions(localStorage.token));
+	});
 </script>
 
 <svelte:head>

+ 22 - 19
src/routes/(app)/workspace/functions/create/+page.svelte

@@ -1,18 +1,20 @@
 <script>
-	import { goto } from '$app/navigation';
-	import { createNewTool, getTools } from '$lib/apis/tools';
-	import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
-	import { tools } from '$lib/stores';
-	import { onMount } from 'svelte';
 	import { toast } from 'svelte-sonner';
+	import { onMount } from 'svelte';
+	import { goto } from '$app/navigation';
+
+	import { functions, models } from '$lib/stores';
+	import { createNewFunction, getFunctions } from '$lib/apis/functions';
+	import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
+	import { getModels } from '$lib/apis';
 
 	let mounted = false;
 	let clone = false;
-	let tool = null;
+	let func = null;
 
 	const saveHandler = async (data) => {
 		console.log(data);
-		const res = await createNewTool(localStorage.token, {
+		const res = await createNewFunction(localStorage.token, {
 			id: data.id,
 			name: data.name,
 			meta: data.meta,
@@ -23,19 +25,20 @@
 		});
 
 		if (res) {
-			toast.success('Tool created successfully');
-			tools.set(await getTools(localStorage.token));
+			toast.success('Function created successfully');
+			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 
-			await goto('/workspace/tools');
+			await goto('/workspace/functions');
 		}
 	};
 
 	onMount(() => {
-		if (sessionStorage.tool) {
-			tool = JSON.parse(sessionStorage.tool);
-			sessionStorage.removeItem('tool');
+		if (sessionStorage.function) {
+			func = JSON.parse(sessionStorage.function);
+			sessionStorage.removeItem('function');
 
-			console.log(tool);
+			console.log(func);
 			clone = true;
 		}
 
@@ -44,11 +47,11 @@
 </script>
 
 {#if mounted}
-	<ToolkitEditor
-		id={tool?.id ?? ''}
-		name={tool?.name ?? ''}
-		meta={tool?.meta ?? { description: '' }}
-		content={tool?.content ?? ''}
+	<FunctionEditor
+		id={func?.id ?? ''}
+		name={func?.name ?? ''}
+		meta={func?.meta ?? { description: '' }}
+		content={func?.content ?? ''}
 		{clone}
 		on:save={(e) => {
 			saveHandler(e.detail);

+ 22 - 20
src/routes/(app)/workspace/functions/edit/+page.svelte

@@ -1,18 +1,21 @@
 <script>
+	import { toast } from 'svelte-sonner';
+	import { onMount } from 'svelte';
+
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
-	import { getToolById, getTools, updateToolById } from '$lib/apis/tools';
+	import { functions, models } from '$lib/stores';
+	import { updateFunctionById, getFunctions, getFunctionById } from '$lib/apis/functions';
+
+	import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
-	import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
-	import { tools } from '$lib/stores';
-	import { onMount } from 'svelte';
-	import { toast } from 'svelte-sonner';
+	import { getModels } from '$lib/apis';
 
-	let tool = null;
+	let func = null;
 
 	const saveHandler = async (data) => {
 		console.log(data);
-		const res = await updateToolById(localStorage.token, tool.id, {
+		const res = await updateFunctionById(localStorage.token, func.id, {
 			id: data.id,
 			name: data.name,
 			meta: data.meta,
@@ -23,10 +26,9 @@
 		});
 
 		if (res) {
-			toast.success('Tool updated successfully');
-			tools.set(await getTools(localStorage.token));
-
-			// await goto('/workspace/tools');
+			toast.success('Function updated successfully');
+			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 		}
 	};
 
@@ -35,24 +37,24 @@
 		const id = $page.url.searchParams.get('id');
 
 		if (id) {
-			tool = await getToolById(localStorage.token, id).catch((error) => {
+			func = await getFunctionById(localStorage.token, id).catch((error) => {
 				toast.error(error);
-				goto('/workspace/tools');
+				goto('/workspace/functions');
 				return null;
 			});
 
-			console.log(tool);
+			console.log(func);
 		}
 	});
 </script>
 
-{#if tool}
-	<ToolkitEditor
+{#if func}
+	<FunctionEditor
 		edit={true}
-		id={tool.id}
-		name={tool.name}
-		meta={tool.meta}
-		content={tool.content}
+		id={func.id}
+		name={func.name}
+		meta={func.meta}
+		content={func.content}
 		on:save={(e) => {
 			saveHandler(e.detail);
 		}}

+ 22 - 1
src/routes/(app)/workspace/models/edit/+page.svelte

@@ -5,7 +5,7 @@
 
 	import { onMount, getContext } from 'svelte';
 	import { page } from '$app/stores';
-	import { settings, user, config, models, tools } from '$lib/stores';
+	import { settings, user, config, models, tools, functions } from '$lib/stores';
 	import { splitStream } from '$lib/utils';
 
 	import { getModelInfos, updateModelById } from '$lib/apis/models';
@@ -16,6 +16,7 @@
 	import Tags from '$lib/components/common/Tags.svelte';
 	import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
 	import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte';
+	import FiltersSelector from '$lib/components/workspace/Models/FiltersSelector.svelte';
 
 	const i18n = getContext('i18n');
 
@@ -62,6 +63,7 @@
 
 	let knowledge = [];
 	let toolIds = [];
+	let filterIds = [];
 
 	const updateHandler = async () => {
 		loading = true;
@@ -86,6 +88,14 @@
 			}
 		}
 
+		if (filterIds.length > 0) {
+			info.meta.filterIds = filterIds;
+		} else {
+			if (info.meta.filterIds) {
+				delete info.meta.filterIds;
+			}
+		}
+
 		info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
 		Object.keys(info.params).forEach((key) => {
 			if (info.params[key] === '' || info.params[key] === null) {
@@ -147,6 +157,10 @@
 					toolIds = [...model?.info?.meta?.toolIds];
 				}
 
+				if (model?.info?.meta?.filterIds) {
+					filterIds = [...model?.info?.meta?.filterIds];
+				}
+
 				if (model?.owned_by === 'openai') {
 					capabilities.usage = false;
 				}
@@ -534,6 +548,13 @@
 				<ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} />
 			</div>
 
+			<div class="my-2">
+				<FiltersSelector
+					bind:selectedFilterIds={filterIds}
+					filters={$functions.filter((func) => func.type === 'filter')}
+				/>
+			</div>
+
 			<div class="my-2">
 				<div class="flex w-full justify-between mb-1">
 					<div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div>