Przeglądaj źródła

Merge pull request #3321 from open-webui/functions

feat: functions
Timothy Jaeryang Baek 10 miesięcy temu
rodzic
commit
09a81eb225

+ 4 - 13
backend/apps/ollama/main.py

@@ -53,7 +53,7 @@ from config import (
     UPLOAD_DIR,
     UPLOAD_DIR,
     AppConfig,
     AppConfig,
 )
 )
-from utils.misc import calculate_sha256
+from utils.misc import calculate_sha256, add_or_update_system_message
 
 
 log = logging.getLogger(__name__)
 log = logging.getLogger(__name__)
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
 log.setLevel(SRC_LOG_LEVELS["OLLAMA"])
@@ -834,18 +834,9 @@ async def generate_chat_completion(
             )
             )
 
 
             if payload.get("messages"):
             if payload.get("messages"):
-                for message in payload["messages"]:
-                    if message.get("role") == "system":
-                        message["content"] = system + message["content"]
-                        break
-                else:
-                    payload["messages"].insert(
-                        0,
-                        {
-                            "role": "system",
-                            "content": system,
-                        },
-                    )
+                payload["messages"] = add_or_update_system_message(
+                    system, payload["messages"]
+                )
 
 
     if url_idx == None:
     if url_idx == None:
         if ":" not in payload["model"]:
         if ":" not in payload["model"]:

+ 6 - 1
backend/apps/openai/main.py

@@ -432,7 +432,12 @@ async def generate_chat_completion(
     idx = model["urlIdx"]
     idx = model["urlIdx"]
 
 
     if "pipeline" in model and model.get("pipeline"):
     if "pipeline" in model and model.get("pipeline"):
-        payload["user"] = {"name": user.name, "id": user.id}
+        payload["user"] = {
+            "name": user.name,
+            "id": user.id,
+            "email": user.email,
+            "role": user.role,
+        }
 
 
     # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
     # Check if the model is "gpt-4-vision-preview" and set "max_tokens" to 4000
     # This is a workaround until OpenAI fixes the issue with this model
     # This is a workaround until OpenAI fixes the issue with this model

+ 61 - 0
backend/apps/webui/internal/migrations/015_add_functions.py

@@ -0,0 +1,61 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    @migrator.create_model
+    class Function(pw.Model):
+        id = pw.TextField(unique=True)
+        user_id = pw.TextField()
+
+        name = pw.TextField()
+        type = pw.TextField()
+
+        content = pw.TextField()
+        meta = pw.TextField()
+
+        created_at = pw.BigIntegerField(null=False)
+        updated_at = pw.BigIntegerField(null=False)
+
+        class Meta:
+            table_name = "function"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_model("function")

+ 66 - 4
backend/apps/webui/main.py

@@ -13,7 +13,11 @@ from apps.webui.routers import (
     memories,
     memories,
     utils,
     utils,
     files,
     files,
+    functions,
 )
 )
+from apps.webui.models.functions import Functions
+from apps.webui.utils import load_function_module_by_id
+
 from config import (
 from config import (
     WEBUI_BUILD_HASH,
     WEBUI_BUILD_HASH,
     SHOW_ADMIN_DETAILS,
     SHOW_ADMIN_DETAILS,
@@ -60,7 +64,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
 
 
 app.state.MODELS = {}
 app.state.MODELS = {}
 app.state.TOOLS = {}
 app.state.TOOLS = {}
-
+app.state.FUNCTIONS = {}
 
 
 app.add_middleware(
 app.add_middleware(
     CORSMiddleware,
     CORSMiddleware,
@@ -70,19 +74,22 @@ app.add_middleware(
     allow_headers=["*"],
     allow_headers=["*"],
 )
 )
 
 
+
+app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 
 
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
-app.include_router(tools.router, prefix="/tools", tags=["tools"])
 app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
+
 app.include_router(memories.router, prefix="/memories", tags=["memories"])
 app.include_router(memories.router, prefix="/memories", tags=["memories"])
+app.include_router(files.router, prefix="/files", tags=["files"])
+app.include_router(tools.router, prefix="/tools", tags=["tools"])
+app.include_router(functions.router, prefix="/functions", tags=["functions"])
 
 
-app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
-app.include_router(files.router, prefix="/files", tags=["files"])
 
 
 
 
 @app.get("/")
 @app.get("/")
@@ -93,3 +100,58 @@ async def get_status():
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_models": app.state.config.DEFAULT_MODELS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
         "default_prompt_suggestions": app.state.config.DEFAULT_PROMPT_SUGGESTIONS,
     }
     }
+
+
+async def get_pipe_models():
+    pipes = Functions.get_functions_by_type("pipe")
+    pipe_models = []
+
+    for pipe in pipes:
+        # Check if function is already loaded
+        if pipe.id not in app.state.FUNCTIONS:
+            function_module, function_type = load_function_module_by_id(pipe.id)
+            app.state.FUNCTIONS[pipe.id] = function_module
+        else:
+            function_module = app.state.FUNCTIONS[pipe.id]
+
+        # Check if function is a manifold
+        if hasattr(function_module, "type"):
+            if function_module.type == "manifold":
+                manifold_pipes = []
+
+                # Check if pipes is a function or a list
+                if callable(function_module.pipes):
+                    manifold_pipes = function_module.pipes()
+                else:
+                    manifold_pipes = function_module.pipes
+
+                for p in manifold_pipes:
+                    manifold_pipe_id = f'{pipe.id}.{p["id"]}'
+                    manifold_pipe_name = p["name"]
+
+                    if hasattr(function_module, "name"):
+                        manifold_pipe_name = f"{pipe.name}{manifold_pipe_name}"
+
+                    pipe_models.append(
+                        {
+                            "id": manifold_pipe_id,
+                            "name": manifold_pipe_name,
+                            "object": "model",
+                            "created": pipe.created_at,
+                            "owned_by": "openai",
+                            "pipe": {"type": pipe.type},
+                        }
+                    )
+        else:
+            pipe_models.append(
+                {
+                    "id": pipe.id,
+                    "name": pipe.name,
+                    "object": "model",
+                    "created": pipe.created_at,
+                    "owned_by": "openai",
+                    "pipe": {"type": "pipe"},
+                }
+            )
+
+    return pipe_models

+ 5 - 4
backend/apps/webui/models/functions.py

@@ -55,6 +55,7 @@ class FunctionModel(BaseModel):
 class FunctionResponse(BaseModel):
 class FunctionResponse(BaseModel):
     id: str
     id: str
     user_id: str
     user_id: str
+    type: str
     name: str
     name: str
     meta: FunctionMeta
     meta: FunctionMeta
     updated_at: int  # timestamp in epoch
     updated_at: int  # timestamp in epoch
@@ -64,23 +65,23 @@ class FunctionResponse(BaseModel):
 class FunctionForm(BaseModel):
 class FunctionForm(BaseModel):
     id: str
     id: str
     name: str
     name: str
-    type: str
     content: str
     content: str
     meta: FunctionMeta
     meta: FunctionMeta
 
 
 
 
-class ToolsTable:
+class FunctionsTable:
     def __init__(self, db):
     def __init__(self, db):
         self.db = db
         self.db = db
         self.db.create_tables([Function])
         self.db.create_tables([Function])
 
 
     def insert_new_function(
     def insert_new_function(
-        self, user_id: str, form_data: FunctionForm
+        self, user_id: str, type: str, form_data: FunctionForm
     ) -> Optional[FunctionModel]:
     ) -> Optional[FunctionModel]:
         function = FunctionModel(
         function = FunctionModel(
             **{
             **{
                 **form_data.model_dump(),
                 **form_data.model_dump(),
                 "user_id": user_id,
                 "user_id": user_id,
+                "type": type,
                 "updated_at": int(time.time()),
                 "updated_at": int(time.time()),
                 "created_at": int(time.time()),
                 "created_at": int(time.time()),
             }
             }
@@ -137,4 +138,4 @@ class ToolsTable:
             return False
             return False
 
 
 
 
-Tools = ToolsTable(DB)
+Functions = FunctionsTable(DB)

+ 180 - 0
backend/apps/webui/routers/functions.py

@@ -0,0 +1,180 @@
+from fastapi import Depends, FastAPI, HTTPException, status, Request
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.functions import (
+    Functions,
+    FunctionForm,
+    FunctionModel,
+    FunctionResponse,
+)
+from apps.webui.utils import load_function_module_by_id
+from utils.utils import get_verified_user, get_admin_user
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+from pathlib import Path
+
+from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
+
+
+router = APIRouter()
+
+############################
+# GetFunctions
+############################
+
+
+@router.get("/", response_model=List[FunctionResponse])
+async def get_functions(user=Depends(get_verified_user)):
+    return Functions.get_functions()
+
+
+############################
+# ExportFunctions
+############################
+
+
+@router.get("/export", response_model=List[FunctionModel])
+async def get_functions(user=Depends(get_admin_user)):
+    return Functions.get_functions()
+
+
+############################
+# CreateNewFunction
+############################
+
+
+@router.post("/create", response_model=Optional[FunctionResponse])
+async def create_new_function(
+    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+    if not form_data.id.isidentifier():
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Only alphanumeric characters and underscores are allowed in the id",
+        )
+
+    form_data.id = form_data.id.lower()
+
+    function = Functions.get_function_by_id(form_data.id)
+    if function == None:
+        function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
+        try:
+            with open(function_path, "w") as function_file:
+                function_file.write(form_data.content)
+
+            function_module, function_type = load_function_module_by_id(form_data.id)
+
+            FUNCTIONS = request.app.state.FUNCTIONS
+            FUNCTIONS[form_data.id] = function_module
+
+            function = Functions.insert_new_function(user.id, function_type, form_data)
+
+            function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
+            function_cache_dir.mkdir(parents=True, exist_ok=True)
+
+            if function:
+                return function
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
+                )
+        except Exception as e:
+            print(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ID_TAKEN,
+        )
+
+
+############################
+# GetFunctionById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[FunctionModel])
+async def get_function_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
+
+    if function:
+        return function
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateFunctionById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
+async def update_toolkit_by_id(
+    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+    function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+
+    try:
+        with open(function_path, "w") as function_file:
+            function_file.write(form_data.content)
+
+        function_module, function_type = load_function_module_by_id(id)
+
+        FUNCTIONS = request.app.state.FUNCTIONS
+        FUNCTIONS[id] = function_module
+
+        updated = {**form_data.model_dump(exclude={"id"}), "type": function_type}
+        print(updated)
+
+        function = Functions.update_function_by_id(id, updated)
+
+        if function:
+            return function
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+            )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# DeleteFunctionById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_function_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    result = Functions.delete_function_by_id(id)
+
+    if result:
+        FUNCTIONS = request.app.state.FUNCTIONS
+        if id in FUNCTIONS:
+            del FUNCTIONS[id]
+
+        # delete the function file
+        function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+        os.remove(function_path)
+
+    return result

+ 23 - 1
backend/apps/webui/utils.py

@@ -1,7 +1,7 @@
 from importlib import util
 from importlib import util
 import os
 import os
 
 
-from config import TOOLS_DIR
+from config import TOOLS_DIR, FUNCTIONS_DIR
 
 
 
 
 def load_toolkit_module_by_id(toolkit_id):
 def load_toolkit_module_by_id(toolkit_id):
@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
         # Move the file to the error folder
         # Move the file to the error folder
         os.rename(toolkit_path, f"{toolkit_path}.error")
         os.rename(toolkit_path, f"{toolkit_path}.error")
         raise e
         raise e
+
+
+def load_function_module_by_id(function_id):
+    function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
+
+    spec = util.spec_from_file_location(function_id, function_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Pipe"):
+            return module.Pipe(), "pipe"
+        elif hasattr(module, "Filter"):
+            return module.Filter(), "filter"
+        else:
+            raise Exception("No Function class found")
+    except Exception as e:
+        print(f"Error loading module: {function_id}")
+        # Move the file to the error folder
+        os.rename(function_path, f"{function_path}.error")
+        raise e

+ 8 - 0
backend/config.py

@@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
 Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
 Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
 
 
 
 
+####################################
+# Functions DIR
+####################################
+
+FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
+Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 ####################################
 # LITELLM_CONFIG
 # LITELLM_CONFIG
 ####################################
 ####################################

+ 342 - 124
backend/main.py

@@ -15,6 +15,7 @@ import uuid
 import inspect
 import inspect
 import asyncio
 import asyncio
 
 
+from fastapi.concurrency import run_in_threadpool
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi import FastAPI, Request, Depends, status, UploadFile, File, Form
 from fastapi.staticfiles import StaticFiles
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import JSONResponse
 from fastapi.responses import JSONResponse
@@ -42,15 +43,17 @@ from apps.openai.main import (
 from apps.audio.main import app as audio_app
 from apps.audio.main import app as audio_app
 from apps.images.main import app as images_app
 from apps.images.main import app as images_app
 from apps.rag.main import app as rag_app
 from apps.rag.main import app as rag_app
-from apps.webui.main import app as webui_app
+from apps.webui.main import app as webui_app, get_pipe_models
 
 
 
 
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import List, Optional
+from typing import List, Optional, Iterator, Generator, Union
 
 
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.models import Models, ModelModel
 from apps.webui.models.tools import Tools
 from apps.webui.models.tools import Tools
-from apps.webui.utils import load_toolkit_module_by_id
+from apps.webui.models.functions import Functions
+
+from apps.webui.utils import load_toolkit_module_by_id, load_function_module_by_id
 
 
 
 
 from utils.utils import (
 from utils.utils import (
@@ -64,7 +67,11 @@ from utils.task import (
     search_query_generation_template,
     search_query_generation_template,
     tools_function_calling_generation_template,
     tools_function_calling_generation_template,
 )
 )
-from utils.misc import get_last_user_message, add_or_update_system_message
+from utils.misc import (
+    get_last_user_message,
+    add_or_update_system_message,
+    stream_message_template,
+)
 
 
 from apps.rag.utils import get_rag_context, rag_template
 from apps.rag.utils import get_rag_context, rag_template
 
 
@@ -170,6 +177,13 @@ app.state.MODELS = {}
 origins = ["*"]
 origins = ["*"]
 
 
 
 
+##################################
+#
+# ChatCompletion Middleware
+#
+##################################
+
+
 async def get_function_call_response(
 async def get_function_call_response(
     messages, files, tool_id, template, task_model_id, user
     messages, files, tool_id, template, task_model_id, user
 ):
 ):
@@ -309,41 +323,72 @@ async def get_function_call_response(
 
 
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
 class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
     async def dispatch(self, request: Request, call_next):
-        return_citations = False
+        data_items = []
 
 
-        if request.method == "POST" and (
-            "/ollama/api/chat" in request.url.path
-            or "/chat/completions" in request.url.path
+        if request.method == "POST" and any(
+            endpoint in request.url.path
+            for endpoint in ["/ollama/api/chat", "/chat/completions"]
         ):
         ):
             log.debug(f"request.url.path: {request.url.path}")
             log.debug(f"request.url.path: {request.url.path}")
 
 
             # Read the original request body
             # Read the original request body
             body = await request.body()
             body = await request.body()
-            # Decode body to string
             body_str = body.decode("utf-8")
             body_str = body.decode("utf-8")
-            # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
             data = json.loads(body_str) if body_str else {}
 
 
             user = get_current_user(
             user = get_current_user(
                 request,
                 request,
                 get_http_authorization_cred(request.headers.get("Authorization")),
                 get_http_authorization_cred(request.headers.get("Authorization")),
             )
             )
+            # Flag to skip RAG completions if file_handler is present in tools/functions
+            skip_files = False
 
 
-            # Remove the citations from the body
-            return_citations = data.get("citations", False)
-            if "citations" in data:
-                del data["citations"]
-
-            # Set the task model
-            task_model_id = data["model"]
-            if task_model_id not in app.state.MODELS:
+            model_id = data["model"]
+            if model_id not in app.state.MODELS:
                 raise HTTPException(
                 raise HTTPException(
                     status_code=status.HTTP_404_NOT_FOUND,
                     status_code=status.HTTP_404_NOT_FOUND,
                     detail="Model not found",
                     detail="Model not found",
                 )
                 )
+            model = app.state.MODELS[model_id]
+
+            # Check if the model has any filters
+            if "info" in model and "meta" in model["info"]:
+                for filter_id in model["info"]["meta"].get("filterIds", []):
+                    filter = Functions.get_function_by_id(filter_id)
+                    if filter:
+                        if filter_id in webui_app.state.FUNCTIONS:
+                            function_module = webui_app.state.FUNCTIONS[filter_id]
+                        else:
+                            function_module, function_type = load_function_module_by_id(
+                                filter_id
+                            )
+                            webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                        # Check if the function has a file_handler variable
+                        if getattr(function_module, "file_handler"):
+                            skip_files = True
 
 
-            # Check if the user has a custom task model
-            # If the user has a custom task model, use that model
+                        try:
+                            if hasattr(function_module, "inlet"):
+                                data = function_module.inlet(
+                                    data,
+                                    {
+                                        "id": user.id,
+                                        "email": user.email,
+                                        "name": user.name,
+                                        "role": user.role,
+                                    },
+                                )
+                        except Exception as e:
+                            print(f"Error: {e}")
+                            return JSONResponse(
+                                status_code=status.HTTP_400_BAD_REQUEST,
+                                content={"detail": str(e)},
+                            )
+
+            # Set the task model
+            task_model_id = data["model"]
+            # Check if the user has a custom task model and use that model
             if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
             if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
                 if (
                 if (
                     app.state.config.TASK_MODEL
                     app.state.config.TASK_MODEL
@@ -361,8 +406,6 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
             context = ""
             context = ""
 
 
             # If tool_ids field is present, call the functions
             # If tool_ids field is present, call the functions
-
-            skip_files = False
             if "tool_ids" in data:
             if "tool_ids" in data:
                 print(data["tool_ids"])
                 print(data["tool_ids"])
                 for tool_id in data["tool_ids"]:
                 for tool_id in data["tool_ids"]:
@@ -408,18 +451,22 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                         context += ("\n" if context != "" else "") + rag_context
                         context += ("\n" if context != "" else "") + rag_context
 
 
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
                     log.debug(f"rag_context: {rag_context}, citations: {citations}")
-                else:
-                    return_citations = False
+
+                    if citations and data.get("citations"):
+                        data_items.append({"citations": citations})
 
 
                 del data["files"]
                 del data["files"]
 
 
+            if data.get("citations"):
+                del data["citations"]
+
             if context != "":
             if context != "":
                 system_prompt = rag_template(
                 system_prompt = rag_template(
                     rag_app.state.config.RAG_TEMPLATE, context, prompt
                     rag_app.state.config.RAG_TEMPLATE, context, prompt
                 )
                 )
                 print(system_prompt)
                 print(system_prompt)
                 data["messages"] = add_or_update_system_message(
                 data["messages"] = add_or_update_system_message(
-                    f"\n{system_prompt}", data["messages"]
+                    system_prompt, data["messages"]
                 )
                 )
 
 
             modified_body_bytes = json.dumps(data).encode("utf-8")
             modified_body_bytes = json.dumps(data).encode("utf-8")
@@ -435,40 +482,51 @@ class ChatCompletionMiddleware(BaseHTTPMiddleware):
                 ],
                 ],
             ]
             ]
 
 
-        response = await call_next(request)
-
-        if return_citations:
-            # Inject the citations into the response
+            response = await call_next(request)
             if isinstance(response, StreamingResponse):
             if isinstance(response, StreamingResponse):
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 # If it's a streaming response, inject it as SSE event or NDJSON line
                 content_type = response.headers.get("Content-Type")
                 content_type = response.headers.get("Content-Type")
                 if "text/event-stream" in content_type:
                 if "text/event-stream" in content_type:
                     return StreamingResponse(
                     return StreamingResponse(
-                        self.openai_stream_wrapper(response.body_iterator, citations),
+                        self.openai_stream_wrapper(response.body_iterator, data_items),
                     )
                     )
                 if "application/x-ndjson" in content_type:
                 if "application/x-ndjson" in content_type:
                     return StreamingResponse(
                     return StreamingResponse(
-                        self.ollama_stream_wrapper(response.body_iterator, citations),
+                        self.ollama_stream_wrapper(response.body_iterator, data_items),
                     )
                     )
+            else:
+                return response
 
 
+        # If it's not a chat completion request, just pass it through
+        response = await call_next(request)
         return response
         return response
 
 
     async def _receive(self, body: bytes):
     async def _receive(self, body: bytes):
         return {"type": "http.request", "body": body, "more_body": False}
         return {"type": "http.request", "body": body, "more_body": False}
 
 
-    async def openai_stream_wrapper(self, original_generator, citations):
-        yield f"data: {json.dumps({'citations': citations})}\n\n"
+    async def openai_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"data: {json.dumps(item)}\n\n"
+
         async for data in original_generator:
         async for data in original_generator:
             yield data
             yield data
 
 
-    async def ollama_stream_wrapper(self, original_generator, citations):
-        yield f"{json.dumps({'citations': citations})}\n"
+    async def ollama_stream_wrapper(self, original_generator, data_items):
+        for item in data_items:
+            yield f"{json.dumps(item)}\n"
+
         async for data in original_generator:
         async for data in original_generator:
             yield data
             yield data
 
 
 
 
 app.add_middleware(ChatCompletionMiddleware)
 app.add_middleware(ChatCompletionMiddleware)
 
 
+##################################
+#
+# Pipeline Middleware
+#
+##################################
+
 
 
 def filter_pipeline(payload, user):
 def filter_pipeline(payload, user):
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
     user = {"id": user.id, "email": user.email, "name": user.name, "role": user.role}
@@ -628,7 +686,6 @@ async def update_embedding_function(request: Request, call_next):
 
 
 app.mount("/ws", socket_app)
 app.mount("/ws", socket_app)
 
 
-
 app.mount("/ollama", ollama_app)
 app.mount("/ollama", ollama_app)
 app.mount("/openai", openai_app)
 app.mount("/openai", openai_app)
 
 
@@ -642,17 +699,18 @@ webui_app.state.EMBEDDING_FUNCTION = rag_app.state.EMBEDDING_FUNCTION
 
 
 
 
 async def get_all_models():
 async def get_all_models():
+    pipe_models = []
     openai_models = []
     openai_models = []
     ollama_models = []
     ollama_models = []
 
 
+    pipe_models = await get_pipe_models()
+
     if app.state.config.ENABLE_OPENAI_API:
     if app.state.config.ENABLE_OPENAI_API:
         openai_models = await get_openai_models()
         openai_models = await get_openai_models()
-
         openai_models = openai_models["data"]
         openai_models = openai_models["data"]
 
 
     if app.state.config.ENABLE_OLLAMA_API:
     if app.state.config.ENABLE_OLLAMA_API:
         ollama_models = await get_ollama_models()
         ollama_models = await get_ollama_models()
-
         ollama_models = [
         ollama_models = [
             {
             {
                 "id": model["model"],
                 "id": model["model"],
@@ -665,9 +723,9 @@ async def get_all_models():
             for model in ollama_models["models"]
             for model in ollama_models["models"]
         ]
         ]
 
 
-    models = openai_models + ollama_models
-    custom_models = Models.get_all_models()
+    models = pipe_models + openai_models + ollama_models
 
 
+    custom_models = Models.get_all_models()
     for custom_model in custom_models:
     for custom_model in custom_models:
         if custom_model.base_model_id == None:
         if custom_model.base_model_id == None:
             for model in models:
             for model in models:
@@ -730,6 +788,234 @@ async def get_models(user=Depends(get_verified_user)):
     return {"data": models}
     return {"data": models}
 
 
 
 
+@app.post("/api/chat/completions")
+async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    model = app.state.MODELS[model_id]
+    print(model)
+
+    pipe = model.get("pipe")
+    if pipe:
+        form_data["user"] = {
+            "id": user.id,
+            "email": user.email,
+            "name": user.name,
+            "role": user.role,
+        }
+
+        def job():
+            pipe_id = form_data["model"]
+            if "." in pipe_id:
+                pipe_id, sub_pipe_id = pipe_id.split(".", 1)
+            print(pipe_id)
+
+            pipe = webui_app.state.FUNCTIONS[pipe_id].pipe
+            if form_data["stream"]:
+
+                def stream_content():
+                    res = pipe(body=form_data)
+
+                    if isinstance(res, str):
+                        message = stream_message_template(form_data["model"], res)
+                        yield f"data: {json.dumps(message)}\n\n"
+
+                    if isinstance(res, Iterator):
+                        for line in res:
+                            if isinstance(line, BaseModel):
+                                line = line.model_dump_json()
+                                line = f"data: {line}"
+                            try:
+                                line = line.decode("utf-8")
+                            except:
+                                pass
+
+                            if line.startswith("data:"):
+                                yield f"{line}\n\n"
+                            else:
+                                line = stream_message_template(form_data["model"], line)
+                                yield f"data: {json.dumps(line)}\n\n"
+
+                    if isinstance(res, str) or isinstance(res, Generator):
+                        finish_message = {
+                            "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                            "object": "chat.completion.chunk",
+                            "created": int(time.time()),
+                            "model": form_data["model"],
+                            "choices": [
+                                {
+                                    "index": 0,
+                                    "delta": {},
+                                    "logprobs": None,
+                                    "finish_reason": "stop",
+                                }
+                            ],
+                        }
+
+                        yield f"data: {json.dumps(finish_message)}\n\n"
+                        yield f"data: [DONE]"
+
+                return StreamingResponse(
+                    stream_content(), media_type="text/event-stream"
+                )
+            else:
+                res = pipe(body=form_data)
+
+                if isinstance(res, dict):
+                    return res
+                elif isinstance(res, BaseModel):
+                    return res.model_dump()
+                else:
+                    message = ""
+                    if isinstance(res, str):
+                        message = res
+                    if isinstance(res, Generator):
+                        for stream in res:
+                            message = f"{message}{stream}"
+
+                    return {
+                        "id": f"{form_data['model']}-{str(uuid.uuid4())}",
+                        "object": "chat.completion",
+                        "created": int(time.time()),
+                        "model": form_data["model"],
+                        "choices": [
+                            {
+                                "index": 0,
+                                "message": {
+                                    "role": "assistant",
+                                    "content": message,
+                                },
+                                "logprobs": None,
+                                "finish_reason": "stop",
+                            }
+                        ],
+                    }
+
+        return await run_in_threadpool(job)
+    if model["owned_by"] == "ollama":
+        return await generate_ollama_chat_completion(form_data, user=user)
+    else:
+        return await generate_openai_chat_completion(form_data, user=user)
+
+
+@app.post("/api/chat/completed")
+async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
+    data = form_data
+    model_id = data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+    model = app.state.MODELS[model_id]
+
+    filters = [
+        model
+        for model in app.state.MODELS.values()
+        if "pipeline" in model
+        and "type" in model["pipeline"]
+        and model["pipeline"]["type"] == "filter"
+        and (
+            model["pipeline"]["pipelines"] == ["*"]
+            or any(
+                model_id == target_model_id
+                for target_model_id in model["pipeline"]["pipelines"]
+            )
+        )
+    ]
+
+    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
+    if "pipeline" in model:
+        sorted_filters = [model] + sorted_filters
+
+    for filter in sorted_filters:
+        r = None
+        try:
+            urlIdx = filter["urlIdx"]
+
+            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
+            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
+
+            if key != "":
+                headers = {"Authorization": f"Bearer {key}"}
+                r = requests.post(
+                    f"{url}/{filter['id']}/filter/outlet",
+                    headers=headers,
+                    json={
+                        "user": {"id": user.id, "name": user.name, "role": user.role},
+                        "body": data,
+                    },
+                )
+
+                r.raise_for_status()
+                data = r.json()
+        except Exception as e:
+            # Handle connection error here
+            print(f"Connection error: {e}")
+
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "detail" in res:
+                        return JSONResponse(
+                            status_code=r.status_code,
+                            content=res,
+                        )
+                except:
+                    pass
+
+            else:
+                pass
+
+    # Check if the model has any filters
+    if "info" in model and "meta" in model["info"]:
+        for filter_id in model["info"]["meta"].get("filterIds", []):
+            filter = Functions.get_function_by_id(filter_id)
+            if filter:
+                if filter_id in webui_app.state.FUNCTIONS:
+                    function_module = webui_app.state.FUNCTIONS[filter_id]
+                else:
+                    function_module, function_type = load_function_module_by_id(
+                        filter_id
+                    )
+                    webui_app.state.FUNCTIONS[filter_id] = function_module
+
+                try:
+                    if hasattr(function_module, "outlet"):
+                        data = function_module.outlet(
+                            data,
+                            {
+                                "id": user.id,
+                                "email": user.email,
+                                "name": user.name,
+                                "role": user.role,
+                            },
+                        )
+                except Exception as e:
+                    print(f"Error: {e}")
+                    return JSONResponse(
+                        status_code=status.HTTP_400_BAD_REQUEST,
+                        content={"detail": str(e)},
+                    )
+
+    return data
+
+
+##################################
+#
+# Task Endpoints
+#
+##################################
+
+
+# TODO: Refactor task API endpoints below into a separate file
+
+
 @app.get("/api/task/config")
 @app.get("/api/task/config")
 async def get_task_config(user=Depends(get_verified_user)):
 async def get_task_config(user=Depends(get_verified_user)):
     return {
     return {
@@ -1015,92 +1301,14 @@ async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_
         )
         )
 
 
 
 
-@app.post("/api/chat/completions")
-async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
-    model_id = form_data["model"]
-    if model_id not in app.state.MODELS:
-        raise HTTPException(
-            status_code=status.HTTP_404_NOT_FOUND,
-            detail="Model not found",
-        )
+##################################
+#
+# Pipelines Endpoints
+#
+##################################
 
 
-    model = app.state.MODELS[model_id]
-    print(model)
 
 
-    if model["owned_by"] == "ollama":
-        return await generate_ollama_chat_completion(form_data, user=user)
-    else:
-        return await generate_openai_chat_completion(form_data, user=user)
-
-
-@app.post("/api/chat/completed")
-async def chat_completed(form_data: dict, user=Depends(get_verified_user)):
-    data = form_data
-    model_id = data["model"]
-
-    filters = [
-        model
-        for model in app.state.MODELS.values()
-        if "pipeline" in model
-        and "type" in model["pipeline"]
-        and model["pipeline"]["type"] == "filter"
-        and (
-            model["pipeline"]["pipelines"] == ["*"]
-            or any(
-                model_id == target_model_id
-                for target_model_id in model["pipeline"]["pipelines"]
-            )
-        )
-    ]
-    sorted_filters = sorted(filters, key=lambda x: x["pipeline"]["priority"])
-
-    print(model_id)
-
-    if model_id in app.state.MODELS:
-        model = app.state.MODELS[model_id]
-        if "pipeline" in model:
-            sorted_filters = [model] + sorted_filters
-
-    for filter in sorted_filters:
-        r = None
-        try:
-            urlIdx = filter["urlIdx"]
-
-            url = openai_app.state.config.OPENAI_API_BASE_URLS[urlIdx]
-            key = openai_app.state.config.OPENAI_API_KEYS[urlIdx]
-
-            if key != "":
-                headers = {"Authorization": f"Bearer {key}"}
-                r = requests.post(
-                    f"{url}/{filter['id']}/filter/outlet",
-                    headers=headers,
-                    json={
-                        "user": {"id": user.id, "name": user.name, "role": user.role},
-                        "body": data,
-                    },
-                )
-
-                r.raise_for_status()
-                data = r.json()
-        except Exception as e:
-            # Handle connection error here
-            print(f"Connection error: {e}")
-
-            if r is not None:
-                try:
-                    res = r.json()
-                    if "detail" in res:
-                        return JSONResponse(
-                            status_code=r.status_code,
-                            content=res,
-                        )
-                except:
-                    pass
-
-            else:
-                pass
-
-    return data
+# TODO: Refactor pipelines API endpoints below into a separate file
 
 
 
 
 @app.get("/api/pipelines/list")
 @app.get("/api/pipelines/list")
@@ -1423,6 +1631,13 @@ async def update_pipeline_valves(
         )
         )
 
 
 
 
+##################################
+#
+# Config Endpoints
+#
+##################################
+
+
 @app.get("/api/config")
 @app.get("/api/config")
 async def get_app_config():
 async def get_app_config():
     # Checking and Handling the Absence of 'ui' in CONFIG_DATA
     # Checking and Handling the Absence of 'ui' in CONFIG_DATA
@@ -1486,6 +1701,9 @@ async def update_model_filter_config(
     }
     }
 
 
 
 
+# TODO: webhook endpoint should be under config endpoints
+
+
 @app.get("/api/webhook")
 @app.get("/api/webhook")
 async def get_webhook_url(user=Depends(get_admin_user)):
 async def get_webhook_url(user=Depends(get_admin_user)):
     return {
     return {

+ 19 - 0
backend/utils/misc.py

@@ -4,6 +4,8 @@ import json
 import re
 import re
 from datetime import timedelta
 from datetime import timedelta
 from typing import Optional, List, Tuple
 from typing import Optional, List, Tuple
+import uuid
+import time
 
 
 
 
 def get_last_user_message(messages: List[dict]) -> str:
 def get_last_user_message(messages: List[dict]) -> str:
@@ -62,6 +64,23 @@ def add_or_update_system_message(content: str, messages: List[dict]):
     return messages
     return messages
 
 
 
 
+def stream_message_template(model: str, message: str):
+    return {
+        "id": f"{model}-{str(uuid.uuid4())}",
+        "object": "chat.completion.chunk",
+        "created": int(time.time()),
+        "model": model,
+        "choices": [
+            {
+                "index": 0,
+                "delta": {"content": message},
+                "logprobs": None,
+                "finish_reason": None,
+            }
+        ],
+    }
+
+
 def get_gravatar_url(email):
 def get_gravatar_url(email):
     # Trim leading and trailing whitespace from
     # Trim leading and trailing whitespace from
     # an email address and force all characters
     # an email address and force all characters

+ 193 - 0
src/lib/apis/functions/index.ts

@@ -0,0 +1,193 @@
+import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+export const createNewFunction = async (token: string, func: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/create`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...func
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFunctions = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const exportFunctions = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/export`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getFunctionById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateFunctionById = async (token: string, id: string, func: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...func
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const deleteFunctionById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/functions/id/${id}/delete`, {
+		method: 'DELETE',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 14 - 4
src/lib/components/chat/Chat.svelte

@@ -278,7 +278,9 @@
 			})),
 			})),
 			chat_id: $chatId
 			chat_id: $chatId
 		}).catch((error) => {
 		}).catch((error) => {
-			console.error(error);
+			toast.error(error);
+			messages.at(-1).error = { content: error };
+
 			return null;
 			return null;
 		});
 		});
 
 
@@ -323,6 +325,13 @@
 		} else if (messages.length != 0 && messages.at(-1).done != true) {
 		} else if (messages.length != 0 && messages.at(-1).done != true) {
 			// Response not done
 			// Response not done
 			console.log('wait');
 			console.log('wait');
+		} else if (messages.length != 0 && messages.at(-1).error) {
+			// Error in response
+			toast.error(
+				$i18n.t(
+					`Oops! There was an error in the previous response. Please try again or contact admin.`
+				)
+			);
 		} else if (
 		} else if (
 			files.length > 0 &&
 			files.length > 0 &&
 			files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0
 			files.filter((file) => file.type !== 'image' && file.status !== 'processed').length > 0
@@ -630,7 +639,7 @@
 			keep_alive: $settings.keepAlive ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			files: files.length > 0 ? files : undefined,
 			files: files.length > 0 ? files : undefined,
-			citations: files.length > 0,
+			citations: files.length > 0 ? true : undefined,
 			chat_id: $chatId
 			chat_id: $chatId
 		});
 		});
 
 
@@ -928,10 +937,11 @@
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					files: files.length > 0 ? files : undefined,
 					files: files.length > 0 ? files : undefined,
-					citations: files.length > 0,
+					citations: files.length > 0 ? true : undefined,
+
 					chat_id: $chatId
 					chat_id: $chatId
 				},
 				},
-				`${OPENAI_API_BASE_URL}`
+				`${WEBUI_BASE_URL}/api`
 			);
 			);
 
 
 			// Wait until history/message have been updated
 			// Wait until history/message have been updated

+ 73 - 54
src/lib/components/workspace/Functions.svelte

@@ -3,25 +3,27 @@
 	import fileSaver from 'file-saver';
 	import fileSaver from 'file-saver';
 	const { saveAs } = fileSaver;
 	const { saveAs } = fileSaver;
 
 
+	import { WEBUI_NAME, functions, models } from '$lib/stores';
 	import { onMount, getContext } from 'svelte';
 	import { onMount, getContext } from 'svelte';
-	import { WEBUI_NAME, prompts, tools } from '$lib/stores';
 	import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
 	import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
 
 
 	import { goto } from '$app/navigation';
 	import { goto } from '$app/navigation';
 	import {
 	import {
-		createNewTool,
-		deleteToolById,
-		exportTools,
-		getToolById,
-		getTools
-	} from '$lib/apis/tools';
+		createNewFunction,
+		deleteFunctionById,
+		exportFunctions,
+		getFunctionById,
+		getFunctions
+	} from '$lib/apis/functions';
+
 	import ArrowDownTray from '../icons/ArrowDownTray.svelte';
 	import ArrowDownTray from '../icons/ArrowDownTray.svelte';
 	import Tooltip from '../common/Tooltip.svelte';
 	import Tooltip from '../common/Tooltip.svelte';
 	import ConfirmDialog from '../common/ConfirmDialog.svelte';
 	import ConfirmDialog from '../common/ConfirmDialog.svelte';
+	import { getModels } from '$lib/apis';
 
 
 	const i18n = getContext('i18n');
 	const i18n = getContext('i18n');
 
 
-	let toolsImportInputElement: HTMLInputElement;
+	let functionsImportInputElement: HTMLInputElement;
 	let importFiles;
 	let importFiles;
 
 
 	let showConfirm = false;
 	let showConfirm = false;
@@ -64,7 +66,7 @@
 	<div>
 	<div>
 		<a
 		<a
 			class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1"
 			class=" px-2 py-2 rounded-xl border border-gray-200 dark:border-gray-600 dark:border-0 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 transition font-medium text-sm flex items-center space-x-1"
-			href="/workspace/tools/create"
+			href="/workspace/functions/create"
 		>
 		>
 			<svg
 			<svg
 				xmlns="http://www.w3.org/2000/svg"
 				xmlns="http://www.w3.org/2000/svg"
@@ -82,30 +84,40 @@
 <hr class=" dark:border-gray-850 my-2.5" />
 <hr class=" dark:border-gray-850 my-2.5" />
 
 
 <div class="my-3 mb-5">
 <div class="my-3 mb-5">
-	{#each $tools.filter((t) => query === '' || t.name
+	{#each $functions.filter((f) => query === '' || f.name
 				.toLowerCase()
 				.toLowerCase()
-				.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool}
+				.includes(query.toLowerCase()) || f.id.toLowerCase().includes(query.toLowerCase())) as func}
 		<button
 		<button
 			class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
 			class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
 			type="button"
 			type="button"
 			on:click={() => {
 			on:click={() => {
-				goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`);
+				goto(`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`);
 			}}
 			}}
 		>
 		>
 			<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
 			<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
 				<a
 				<a
-					href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
+					href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
 					class="flex items-center text-left"
 					class="flex items-center text-left"
 				>
 				>
-					<div class=" flex-1 self-center pl-5">
+					<div class=" flex-1 self-center pl-1">
 						<div class=" font-semibold flex items-center gap-1.5">
 						<div class=" font-semibold flex items-center gap-1.5">
+							<div
+								class=" text-xs font-black px-1 rounded uppercase line-clamp-1 bg-gray-500/20 text-gray-700 dark:text-gray-200"
+							>
+								{func.type}
+							</div>
+
 							<div>
 							<div>
-								{tool.name}
+								{func.name}
 							</div>
 							</div>
-							<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
 						</div>
 						</div>
-						<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
-							{tool.meta.description}
+
+						<div class="flex gap-1.5 px-1">
+							<div class=" text-gray-500 text-xs font-medium">{func.id}</div>
+
+							<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
+								{func.meta.description}
+							</div>
 						</div>
 						</div>
 					</div>
 					</div>
 				</a>
 				</a>
@@ -115,7 +127,7 @@
 					<a
 					<a
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
 						type="button"
 						type="button"
-						href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
+						href={`/workspace/functions/edit?id=${encodeURIComponent(func.id)}`}
 					>
 					>
 						<svg
 						<svg
 							xmlns="http://www.w3.org/2000/svg"
 							xmlns="http://www.w3.org/2000/svg"
@@ -141,18 +153,20 @@
 						on:click={async (e) => {
 						on:click={async (e) => {
 							e.stopPropagation();
 							e.stopPropagation();
 
 
-							const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
-								toast.error(error);
-								return null;
-							});
-
-							if (_tool) {
-								sessionStorage.tool = JSON.stringify({
-									..._tool,
-									id: `${_tool.id}_clone`,
-									name: `${_tool.name} (Clone)`
+							const _function = await getFunctionById(localStorage.token, func.id).catch(
+								(error) => {
+									toast.error(error);
+									return null;
+								}
+							);
+
+							if (_function) {
+								sessionStorage.function = JSON.stringify({
+									..._function,
+									id: `${_function.id}_clone`,
+									name: `${_function.name} (Clone)`
 								});
 								});
-								goto('/workspace/tools/create');
+								goto('/workspace/functions/create');
 							}
 							}
 						}}
 						}}
 					>
 					>
@@ -180,16 +194,18 @@
 						on:click={async (e) => {
 						on:click={async (e) => {
 							e.stopPropagation();
 							e.stopPropagation();
 
 
-							const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
-								toast.error(error);
-								return null;
-							});
+							const _function = await getFunctionById(localStorage.token, func.id).catch(
+								(error) => {
+									toast.error(error);
+									return null;
+								}
+							);
 
 
-							if (_tool) {
-								let blob = new Blob([JSON.stringify([_tool])], {
+							if (_function) {
+								let blob = new Blob([JSON.stringify([_function])], {
 									type: 'application/json'
 									type: 'application/json'
 								});
 								});
-								saveAs(blob, `tool-${_tool.id}-export-${Date.now()}.json`);
+								saveAs(blob, `function-${_function.id}-export-${Date.now()}.json`);
 							}
 							}
 						}}
 						}}
 					>
 					>
@@ -204,14 +220,16 @@
 						on:click={async (e) => {
 						on:click={async (e) => {
 							e.stopPropagation();
 							e.stopPropagation();
 
 
-							const res = await deleteToolById(localStorage.token, tool.id).catch((error) => {
+							const res = await deleteFunctionById(localStorage.token, func.id).catch((error) => {
 								toast.error(error);
 								toast.error(error);
 								return null;
 								return null;
 							});
 							});
 
 
 							if (res) {
 							if (res) {
-								toast.success('Tool deleted successfully');
-								tools.set(await getTools(localStorage.token));
+								toast.success('Function deleted successfully');
+
+								functions.set(await getFunctions(localStorage.token));
+								models.set(await getModels(localStorage.token));
 							}
 							}
 						}}
 						}}
 					>
 					>
@@ -246,7 +264,7 @@
 	<div class="flex space-x-2">
 	<div class="flex space-x-2">
 		<input
 		<input
 			id="documents-import-input"
 			id="documents-import-input"
-			bind:this={toolsImportInputElement}
+			bind:this={functionsImportInputElement}
 			bind:files={importFiles}
 			bind:files={importFiles}
 			type="file"
 			type="file"
 			accept=".json"
 			accept=".json"
@@ -260,7 +278,7 @@
 		<button
 		<button
 			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 			on:click={() => {
 			on:click={() => {
-				toolsImportInputElement.click();
+				functionsImportInputElement.click();
 			}}
 			}}
 		>
 		>
 			<div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div>
 			<div class=" self-center mr-2 font-medium">{$i18n.t('Import Functions')}</div>
@@ -284,16 +302,16 @@
 		<button
 		<button
 			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
 			on:click={async () => {
 			on:click={async () => {
-				const _tools = await exportTools(localStorage.token).catch((error) => {
+				const _functions = await exportFunctions(localStorage.token).catch((error) => {
 					toast.error(error);
 					toast.error(error);
 					return null;
 					return null;
 				});
 				});
 
 
-				if (_tools) {
-					let blob = new Blob([JSON.stringify(_tools)], {
+				if (_functions) {
+					let blob = new Blob([JSON.stringify(_functions)], {
 						type: 'application/json'
 						type: 'application/json'
 					});
 					});
-					saveAs(blob, `tools-export-${Date.now()}.json`);
+					saveAs(blob, `functions-export-${Date.now()}.json`);
 				}
 				}
 			}}
 			}}
 		>
 		>
@@ -322,18 +340,19 @@
 	on:confirm={() => {
 	on:confirm={() => {
 		const reader = new FileReader();
 		const reader = new FileReader();
 		reader.onload = async (event) => {
 		reader.onload = async (event) => {
-			const _tools = JSON.parse(event.target.result);
-			console.log(_tools);
+			const _functions = JSON.parse(event.target.result);
+			console.log(_functions);
 
 
-			for (const tool of _tools) {
-				const res = await createNewTool(localStorage.token, tool).catch((error) => {
+			for (const func of _functions) {
+				const res = await createNewFunction(localStorage.token, func).catch((error) => {
 					toast.error(error);
 					toast.error(error);
 					return null;
 					return null;
 				});
 				});
 			}
 			}
 
 
-			toast.success('Tool imported successfully');
-			tools.set(await getTools(localStorage.token));
+			toast.success('Functions imported successfully');
+			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 		};
 		};
 
 
 		reader.readAsText(importFiles[0]);
 		reader.readAsText(importFiles[0]);
@@ -344,8 +363,8 @@
 			<div>Please carefully review the following warnings:</div>
 			<div>Please carefully review the following warnings:</div>
 
 
 			<ul class=" mt-1 list-disc pl-4 text-xs">
 			<ul class=" mt-1 list-disc pl-4 text-xs">
-				<li>Tools have a function calling system that allows arbitrary code execution.</li>
-				<li>Do not install tools from sources you do not fully trust.</li>
+				<li>Functions allow arbitrary code execution.</li>
+				<li>Do not install functions from sources you do not fully trust.</li>
 			</ul>
 			</ul>
 		</div>
 		</div>
 
 

+ 235 - 0
src/lib/components/workspace/Functions/FunctionEditor.svelte

@@ -0,0 +1,235 @@
+<script>
+	import { getContext, createEventDispatcher, onMount } from 'svelte';
+	import { goto } from '$app/navigation';
+
+	const dispatch = createEventDispatcher();
+	const i18n = getContext('i18n');
+
+	import CodeEditor from '$lib/components/common/CodeEditor.svelte';
+	import ConfirmDialog from '$lib/components/common/ConfirmDialog.svelte';
+
+	let formElement = null;
+	let loading = false;
+	let showConfirm = false;
+
+	export let edit = false;
+	export let clone = false;
+
+	export let id = '';
+	export let name = '';
+	export let meta = {
+		description: ''
+	};
+	export let content = '';
+
+	$: if (name && !edit && !clone) {
+		id = name.replace(/\s+/g, '_').toLowerCase();
+	}
+
+	let codeEditor;
+	let boilerplate = `from pydantic import BaseModel
+from typing import Optional
+
+
+class Filter:
+    class Valves(BaseModel):
+        max_turns: int = 4
+        pass
+
+    def __init__(self):
+        # Indicates custom file handling logic. This flag helps disengage default routines in favor of custom
+        # implementations, informing the WebUI to defer file-related operations to designated methods within this class.
+        # Alternatively, you can remove the files directly from the body in from the inlet hook
+        self.file_handler = True
+
+        # Initialize 'valves' with specific configurations. Using 'Valves' instance helps encapsulate settings,
+        # which ensures settings are managed cohesively and not confused with operational flags like 'file_handler'.
+        self.valves = self.Valves(**{"max_turns": 2})
+        pass
+
+    def inlet(self, body: dict, user: Optional[dict] = None) -> dict:
+        # Modify the request body or validate it before processing by the chat completion API.
+        # This function is the pre-processor for the API where various checks on the input can be performed.
+        # It can also modify the request before sending it to the API.
+        print(f"inlet:{__name__}")
+        print(f"inlet:body:{body}")
+        print(f"inlet:user:{user}")
+
+        if user.get("role", "admin") in ["user", "admin"]:
+            messages = body.get("messages", [])
+            if len(messages) > self.valves.max_turns:
+                raise Exception(
+                    f"Conversation turn limit exceeded. Max turns: {self.valves.max_turns}"
+                )
+
+        return body
+
+    def outlet(self, body: dict, user: Optional[dict] = None) -> dict:
+        # Modify or analyze the response body after processing by the API.
+        # This function is the post-processor for the API, which can be used to modify the response
+        # or perform additional checks and analytics.
+        print(f"outlet:{__name__}")
+        print(f"outlet:body:{body}")
+        print(f"outlet:user:{user}")
+
+        messages = [
+            {
+                **message,
+                "content": f"{message['content']} - @@Modified from Filter Outlet",
+            }
+            for message in body.get("messages", [])
+        ]
+
+        return {"messages": messages}
+
+`;
+
+	const saveHandler = async () => {
+		loading = true;
+		dispatch('save', {
+			id,
+			name,
+			meta,
+			content
+		});
+	};
+
+	const submitHandler = async () => {
+		if (codeEditor) {
+			const res = await codeEditor.formatPythonCodeHandler();
+
+			if (res) {
+				console.log('Code formatted successfully');
+				saveHandler();
+			}
+		}
+	};
+</script>
+
+<div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
+	<div class="mx-auto w-full md:px-0 h-full">
+		<form
+			bind:this={formElement}
+			class=" flex flex-col max-h-[100dvh] h-full"
+			on:submit|preventDefault={() => {
+				if (edit) {
+					submitHandler();
+				} else {
+					showConfirm = true;
+				}
+			}}
+		>
+			<div class="mb-2.5">
+				<button
+					class="flex space-x-1"
+					on:click={() => {
+						goto('/workspace/functions');
+					}}
+					type="button"
+				>
+					<div class=" self-center">
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							viewBox="0 0 20 20"
+							fill="currentColor"
+							class="w-4 h-4"
+						>
+							<path
+								fill-rule="evenodd"
+								d="M17 10a.75.75 0 01-.75.75H5.612l4.158 3.96a.75.75 0 11-1.04 1.08l-5.5-5.25a.75.75 0 010-1.08l5.5-5.25a.75.75 0 111.04 1.08L5.612 9.25H16.25A.75.75 0 0117 10z"
+								clip-rule="evenodd"
+							/>
+						</svg>
+					</div>
+					<div class=" self-center font-medium text-sm">{$i18n.t('Back')}</div>
+				</button>
+			</div>
+
+			<div class="flex flex-col flex-1 overflow-auto h-0 rounded-lg">
+				<div class="w-full mb-2 flex flex-col gap-1.5">
+					<div class="flex gap-2 w-full">
+						<input
+							class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+							type="text"
+							placeholder="Function Name (e.g. My Filter)"
+							bind:value={name}
+							required
+						/>
+
+						<input
+							class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+							type="text"
+							placeholder="Function ID (e.g. my_filter)"
+							bind:value={id}
+							required
+							disabled={edit}
+						/>
+					</div>
+					<input
+						class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+						type="text"
+						placeholder="Function Description (e.g. A filter to remove profanity from text)"
+						bind:value={meta.description}
+						required
+					/>
+				</div>
+
+				<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
+					<CodeEditor
+						bind:value={content}
+						bind:this={codeEditor}
+						{boilerplate}
+						on:save={() => {
+							if (formElement) {
+								formElement.requestSubmit();
+							}
+						}}
+					/>
+				</div>
+
+				<div class="pb-3 flex justify-between">
+					<div class="flex-1 pr-3">
+						<div class="text-xs text-gray-500 line-clamp-2">
+							<span class=" font-semibold dark:text-gray-200">Warning:</span> Functions allow
+							arbitrary code execution <br />—
+							<span class=" font-medium dark:text-gray-400"
+								>don't install random functions from sources you don't trust.</span
+							>
+						</div>
+					</div>
+
+					<button
+						class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
+						type="submit"
+					>
+						{$i18n.t('Save')}
+					</button>
+				</div>
+			</div>
+		</form>
+	</div>
+</div>
+
+<ConfirmDialog
+	bind:show={showConfirm}
+	on:confirm={() => {
+		submitHandler();
+	}}
+>
+	<div class="text-sm text-gray-500">
+		<div class=" bg-yellow-500/20 text-yellow-700 dark:text-yellow-200 rounded-lg px-4 py-3">
+			<div>Please carefully review the following warnings:</div>
+
+			<ul class=" mt-1 list-disc pl-4 text-xs">
+				<li>Functions allow arbitrary code execution.</li>
+				<li>Do not install functions from sources you do not fully trust.</li>
+			</ul>
+		</div>
+
+		<div class="my-3">
+			I acknowledge that I have read and I understand the implications of my action. I am aware of
+			the risks associated with executing arbitrary code and I have verified the trustworthiness of
+			the source.
+		</div>
+	</div>
+</ConfirmDialog>

+ 60 - 0
src/lib/components/workspace/Models/FiltersSelector.svelte

@@ -0,0 +1,60 @@
+<script lang="ts">
+	import { getContext, onMount } from 'svelte';
+	import Checkbox from '$lib/components/common/Checkbox.svelte';
+	import Tooltip from '$lib/components/common/Tooltip.svelte';
+
+	const i18n = getContext('i18n');
+
+	export let filters = [];
+	export let selectedFilterIds = [];
+
+	let _filters = {};
+
+	onMount(() => {
+		_filters = filters.reduce((acc, filter) => {
+			acc[filter.id] = {
+				...filter,
+				selected: selectedFilterIds.includes(filter.id)
+			};
+
+			return acc;
+		}, {});
+	});
+</script>
+
+<div>
+	<div class="flex w-full justify-between mb-1">
+		<div class=" self-center text-sm font-semibold">{$i18n.t('Filters')}</div>
+	</div>
+
+	<div class=" text-xs dark:text-gray-500">
+		{$i18n.t('To select filters here, add them to the "Functions" workspace first.')}
+	</div>
+
+	<!-- TODO: Filer order matters -->
+	<div class="flex flex-col">
+		{#if filters.length > 0}
+			<div class=" flex items-center mt-2 flex-wrap">
+				{#each Object.keys(_filters) as filter, filterIdx}
+					<div class=" flex items-center gap-2 mr-3">
+						<div class="self-center flex items-center">
+							<Checkbox
+								state={_filters[filter].selected ? 'checked' : 'unchecked'}
+								on:change={(e) => {
+									_filters[filter].selected = e.detail === 'checked';
+									selectedFilterIds = Object.keys(_filters).filter((t) => _filters[t].selected);
+								}}
+							/>
+						</div>
+
+						<div class=" py-0.5 text-sm w-full capitalize font-medium">
+							<Tooltip content={_filters[filter].meta.description}>
+								{_filters[filter].name}
+							</Tooltip>
+						</div>
+					</div>
+				{/each}
+			</div>
+		{/if}
+	</div>
+</div>

+ 2 - 0
src/lib/stores/index.ts

@@ -27,7 +27,9 @@ export const tags = writable([]);
 export const models: Writable<Model[]> = writable([]);
 export const models: Writable<Model[]> = writable([]);
 export const prompts: Writable<Prompt[]> = writable([]);
 export const prompts: Writable<Prompt[]> = writable([]);
 export const documents: Writable<Document[]> = writable([]);
 export const documents: Writable<Document[]> = writable([]);
+
 export const tools = writable([]);
 export const tools = writable([]);
+export const functions = writable([]);
 
 
 export const banners: Writable<Banner[]> = writable([]);
 export const banners: Writable<Banner[]> = writable([]);
 
 

+ 6 - 1
src/routes/(app)/workspace/+layout.svelte

@@ -1,11 +1,16 @@
 <script lang="ts">
 <script lang="ts">
 	import { onMount, getContext } from 'svelte';
 	import { onMount, getContext } from 'svelte';
 
 
-	import { WEBUI_NAME, showSidebar } from '$lib/stores';
+	import { WEBUI_NAME, showSidebar, functions } from '$lib/stores';
 	import MenuLines from '$lib/components/icons/MenuLines.svelte';
 	import MenuLines from '$lib/components/icons/MenuLines.svelte';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
+	import { getFunctions } from '$lib/apis/functions';
 
 
 	const i18n = getContext('i18n');
 	const i18n = getContext('i18n');
+
+	onMount(async () => {
+		functions.set(await getFunctions(localStorage.token));
+	});
 </script>
 </script>
 
 
 <svelte:head>
 <svelte:head>

+ 22 - 19
src/routes/(app)/workspace/functions/create/+page.svelte

@@ -1,18 +1,20 @@
 <script>
 <script>
-	import { goto } from '$app/navigation';
-	import { createNewTool, getTools } from '$lib/apis/tools';
-	import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
-	import { tools } from '$lib/stores';
-	import { onMount } from 'svelte';
 	import { toast } from 'svelte-sonner';
 	import { toast } from 'svelte-sonner';
+	import { onMount } from 'svelte';
+	import { goto } from '$app/navigation';
+
+	import { functions, models } from '$lib/stores';
+	import { createNewFunction, getFunctions } from '$lib/apis/functions';
+	import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
+	import { getModels } from '$lib/apis';
 
 
 	let mounted = false;
 	let mounted = false;
 	let clone = false;
 	let clone = false;
-	let tool = null;
+	let func = null;
 
 
 	const saveHandler = async (data) => {
 	const saveHandler = async (data) => {
 		console.log(data);
 		console.log(data);
-		const res = await createNewTool(localStorage.token, {
+		const res = await createNewFunction(localStorage.token, {
 			id: data.id,
 			id: data.id,
 			name: data.name,
 			name: data.name,
 			meta: data.meta,
 			meta: data.meta,
@@ -23,19 +25,20 @@
 		});
 		});
 
 
 		if (res) {
 		if (res) {
-			toast.success('Tool created successfully');
-			tools.set(await getTools(localStorage.token));
+			toast.success('Function created successfully');
+			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 
 
-			await goto('/workspace/tools');
+			await goto('/workspace/functions');
 		}
 		}
 	};
 	};
 
 
 	onMount(() => {
 	onMount(() => {
-		if (sessionStorage.tool) {
-			tool = JSON.parse(sessionStorage.tool);
-			sessionStorage.removeItem('tool');
+		if (sessionStorage.function) {
+			func = JSON.parse(sessionStorage.function);
+			sessionStorage.removeItem('function');
 
 
-			console.log(tool);
+			console.log(func);
 			clone = true;
 			clone = true;
 		}
 		}
 
 
@@ -44,11 +47,11 @@
 </script>
 </script>
 
 
 {#if mounted}
 {#if mounted}
-	<ToolkitEditor
-		id={tool?.id ?? ''}
-		name={tool?.name ?? ''}
-		meta={tool?.meta ?? { description: '' }}
-		content={tool?.content ?? ''}
+	<FunctionEditor
+		id={func?.id ?? ''}
+		name={func?.name ?? ''}
+		meta={func?.meta ?? { description: '' }}
+		content={func?.content ?? ''}
 		{clone}
 		{clone}
 		on:save={(e) => {
 		on:save={(e) => {
 			saveHandler(e.detail);
 			saveHandler(e.detail);

+ 22 - 20
src/routes/(app)/workspace/functions/edit/+page.svelte

@@ -1,18 +1,21 @@
 <script>
 <script>
+	import { toast } from 'svelte-sonner';
+	import { onMount } from 'svelte';
+
 	import { goto } from '$app/navigation';
 	import { goto } from '$app/navigation';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
-	import { getToolById, getTools, updateToolById } from '$lib/apis/tools';
+	import { functions, models } from '$lib/stores';
+	import { updateFunctionById, getFunctions, getFunctionById } from '$lib/apis/functions';
+
+	import FunctionEditor from '$lib/components/workspace/Functions/FunctionEditor.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
 	import Spinner from '$lib/components/common/Spinner.svelte';
-	import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
-	import { tools } from '$lib/stores';
-	import { onMount } from 'svelte';
-	import { toast } from 'svelte-sonner';
+	import { getModels } from '$lib/apis';
 
 
-	let tool = null;
+	let func = null;
 
 
 	const saveHandler = async (data) => {
 	const saveHandler = async (data) => {
 		console.log(data);
 		console.log(data);
-		const res = await updateToolById(localStorage.token, tool.id, {
+		const res = await updateFunctionById(localStorage.token, func.id, {
 			id: data.id,
 			id: data.id,
 			name: data.name,
 			name: data.name,
 			meta: data.meta,
 			meta: data.meta,
@@ -23,10 +26,9 @@
 		});
 		});
 
 
 		if (res) {
 		if (res) {
-			toast.success('Tool updated successfully');
-			tools.set(await getTools(localStorage.token));
-
-			// await goto('/workspace/tools');
+			toast.success('Function updated successfully');
+			functions.set(await getFunctions(localStorage.token));
+			models.set(await getModels(localStorage.token));
 		}
 		}
 	};
 	};
 
 
@@ -35,24 +37,24 @@
 		const id = $page.url.searchParams.get('id');
 		const id = $page.url.searchParams.get('id');
 
 
 		if (id) {
 		if (id) {
-			tool = await getToolById(localStorage.token, id).catch((error) => {
+			func = await getFunctionById(localStorage.token, id).catch((error) => {
 				toast.error(error);
 				toast.error(error);
-				goto('/workspace/tools');
+				goto('/workspace/functions');
 				return null;
 				return null;
 			});
 			});
 
 
-			console.log(tool);
+			console.log(func);
 		}
 		}
 	});
 	});
 </script>
 </script>
 
 
-{#if tool}
-	<ToolkitEditor
+{#if func}
+	<FunctionEditor
 		edit={true}
 		edit={true}
-		id={tool.id}
-		name={tool.name}
-		meta={tool.meta}
-		content={tool.content}
+		id={func.id}
+		name={func.name}
+		meta={func.meta}
+		content={func.content}
 		on:save={(e) => {
 		on:save={(e) => {
 			saveHandler(e.detail);
 			saveHandler(e.detail);
 		}}
 		}}

+ 22 - 1
src/routes/(app)/workspace/models/edit/+page.svelte

@@ -5,7 +5,7 @@
 
 
 	import { onMount, getContext } from 'svelte';
 	import { onMount, getContext } from 'svelte';
 	import { page } from '$app/stores';
 	import { page } from '$app/stores';
-	import { settings, user, config, models, tools } from '$lib/stores';
+	import { settings, user, config, models, tools, functions } from '$lib/stores';
 	import { splitStream } from '$lib/utils';
 	import { splitStream } from '$lib/utils';
 
 
 	import { getModelInfos, updateModelById } from '$lib/apis/models';
 	import { getModelInfos, updateModelById } from '$lib/apis/models';
@@ -16,6 +16,7 @@
 	import Tags from '$lib/components/common/Tags.svelte';
 	import Tags from '$lib/components/common/Tags.svelte';
 	import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
 	import Knowledge from '$lib/components/workspace/Models/Knowledge.svelte';
 	import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte';
 	import ToolsSelector from '$lib/components/workspace/Models/ToolsSelector.svelte';
+	import FiltersSelector from '$lib/components/workspace/Models/FiltersSelector.svelte';
 
 
 	const i18n = getContext('i18n');
 	const i18n = getContext('i18n');
 
 
@@ -62,6 +63,7 @@
 
 
 	let knowledge = [];
 	let knowledge = [];
 	let toolIds = [];
 	let toolIds = [];
+	let filterIds = [];
 
 
 	const updateHandler = async () => {
 	const updateHandler = async () => {
 		loading = true;
 		loading = true;
@@ -86,6 +88,14 @@
 			}
 			}
 		}
 		}
 
 
+		if (filterIds.length > 0) {
+			info.meta.filterIds = filterIds;
+		} else {
+			if (info.meta.filterIds) {
+				delete info.meta.filterIds;
+			}
+		}
+
 		info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
 		info.params.stop = params.stop ? params.stop.split(',').filter((s) => s.trim()) : null;
 		Object.keys(info.params).forEach((key) => {
 		Object.keys(info.params).forEach((key) => {
 			if (info.params[key] === '' || info.params[key] === null) {
 			if (info.params[key] === '' || info.params[key] === null) {
@@ -147,6 +157,10 @@
 					toolIds = [...model?.info?.meta?.toolIds];
 					toolIds = [...model?.info?.meta?.toolIds];
 				}
 				}
 
 
+				if (model?.info?.meta?.filterIds) {
+					filterIds = [...model?.info?.meta?.filterIds];
+				}
+
 				if (model?.owned_by === 'openai') {
 				if (model?.owned_by === 'openai') {
 					capabilities.usage = false;
 					capabilities.usage = false;
 				}
 				}
@@ -534,6 +548,13 @@
 				<ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} />
 				<ToolsSelector bind:selectedToolIds={toolIds} tools={$tools} />
 			</div>
 			</div>
 
 
+			<div class="my-2">
+				<FiltersSelector
+					bind:selectedFilterIds={filterIds}
+					filters={$functions.filter((func) => func.type === 'filter')}
+				/>
+			</div>
+
 			<div class="my-2">
 			<div class="my-2">
 				<div class="flex w-full justify-between mb-1">
 				<div class="flex w-full justify-between mb-1">
 					<div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div>
 					<div class=" self-center text-sm font-semibold">{$i18n.t('Capabilities')}</div>