Переглянути джерело

Merge pull request #3006 from open-webui/tools

feat: tools
Timothy Jaeryang Baek 11 місяців тому
батько
коміт
e6fe3ada57

+ 61 - 0
backend/apps/webui/internal/migrations/012_add_tools.py

@@ -0,0 +1,61 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    @migrator.create_model
+    class Tool(pw.Model):
+        id = pw.TextField(unique=True)
+        user_id = pw.TextField()
+
+        name = pw.TextField()
+        content = pw.TextField()
+        specs = pw.TextField()
+
+        meta = pw.TextField()
+
+        created_at = pw.BigIntegerField(null=False)
+        updated_at = pw.BigIntegerField(null=False)
+
+        class Meta:
+            table_name = "tool"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_model("tool")

+ 5 - 2
backend/apps/webui/main.py

@@ -6,6 +6,7 @@ from apps.webui.routers import (
     users,
     chats,
     documents,
+    tools,
     models,
     prompts,
     configs,
@@ -26,8 +27,8 @@ from config import (
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     JWT_EXPIRES_IN,
     WEBUI_BANNERS,
-    AppConfig,
     ENABLE_COMMUNITY_SHARING,
+    AppConfig,
 )
 
 app = FastAPI()
@@ -38,6 +39,7 @@ app.state.config = AppConfig()
 
 app.state.config.ENABLE_SIGNUP = ENABLE_SIGNUP
 app.state.config.JWT_EXPIRES_IN = JWT_EXPIRES_IN
+app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
 
 
 app.state.config.SHOW_ADMIN_DETAILS = SHOW_ADMIN_DETAILS
@@ -54,7 +56,7 @@ app.state.config.BANNERS = WEBUI_BANNERS
 app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
 
 app.state.MODELS = {}
-app.state.AUTH_TRUSTED_EMAIL_HEADER = WEBUI_AUTH_TRUSTED_EMAIL_HEADER
+app.state.TOOLS = {}
 
 
 app.add_middleware(
@@ -70,6 +72,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
+app.include_router(tools.router, prefix="/tools", tags=["tools"])
 app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 app.include_router(memories.router, prefix="/memories", tags=["memories"])

+ 131 - 0
backend/apps/webui/models/tools.py

@@ -0,0 +1,131 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+import time
+import logging
+from apps.webui.internal.db import DB, JSONField
+
+import json
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# Tools DB Schema
+####################
+
+
+class Tool(Model):
+    id = CharField(unique=True)
+    user_id = CharField()
+    name = TextField()
+    content = TextField()
+    specs = JSONField()
+    meta = JSONField()
+    updated_at = BigIntegerField()
+    created_at = BigIntegerField()
+
+    class Meta:
+        database = DB
+
+
+class ToolMeta(BaseModel):
+    description: Optional[str] = None
+
+
+class ToolModel(BaseModel):
+    id: str
+    user_id: str
+    name: str
+    content: str
+    specs: List[dict]
+    meta: ToolMeta
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class ToolResponse(BaseModel):
+    id: str
+    user_id: str
+    name: str
+    meta: ToolMeta
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+
+class ToolForm(BaseModel):
+    id: str
+    name: str
+    content: str
+    meta: ToolMeta
+
+
+class ToolsTable:
+    def __init__(self, db):
+        self.db = db
+        self.db.create_tables([Tool])
+
+    def insert_new_tool(
+        self, user_id: str, form_data: ToolForm, specs: List[dict]
+    ) -> Optional[ToolModel]:
+        tool = ToolModel(
+            **{
+                **form_data.model_dump(),
+                "specs": specs,
+                "user_id": user_id,
+                "updated_at": int(time.time()),
+                "created_at": int(time.time()),
+            }
+        )
+
+        try:
+            result = Tool.create(**tool.model_dump())
+            if result:
+                return tool
+            else:
+                return None
+        except:
+            return None
+
+    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
+        try:
+            tool = Tool.get(Tool.id == id)
+            return ToolModel(**model_to_dict(tool))
+        except:
+            return None
+
+    def get_tools(self) -> List[ToolModel]:
+        return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
+
+    def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
+        try:
+            query = Tool.update(
+                **updated,
+                updated_at=int(time.time()),
+            ).where(Tool.id == id)
+            query.execute()
+
+            tool = Tool.get(Tool.id == id)
+            return ToolModel(**model_to_dict(tool))
+        except:
+            return None
+
+    def delete_tool_by_id(self, id: str) -> bool:
+        try:
+            query = Tool.delete().where((Tool.id == id))
+            query.execute()  # Remove the rows, return number of rows removed.
+
+            return True
+        except:
+            return False
+
+
+Tools = ToolsTable(DB)

+ 177 - 0
backend/apps/webui/routers/tools.py

@@ -0,0 +1,177 @@
+from fastapi import Depends, FastAPI, HTTPException, status, Request
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
+from apps.webui.utils import load_toolkit_module_by_id
+
+from utils.utils import get_current_user, get_admin_user
+from utils.tools import get_tools_specs
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+
+from config import DATA_DIR
+
+
+TOOLS_DIR = f"{DATA_DIR}/tools"
+os.makedirs(TOOLS_DIR, exist_ok=True)
+
+
+router = APIRouter()
+
+############################
+# GetToolkits
+############################
+
+
+@router.get("/", response_model=List[ToolResponse])
+async def get_toolkits(user=Depends(get_current_user)):
+    toolkits = [toolkit for toolkit in Tools.get_tools()]
+    return toolkits
+
+
+############################
+# ExportToolKits
+############################
+
+
+@router.get("/export", response_model=List[ToolModel])
+async def get_toolkits(user=Depends(get_admin_user)):
+    toolkits = [toolkit for toolkit in Tools.get_tools()]
+    return toolkits
+
+
+############################
+# CreateNewToolKit
+############################
+
+
+@router.post("/create", response_model=Optional[ToolResponse])
+async def create_new_toolkit(
+    request: Request, form_data: ToolForm, user=Depends(get_admin_user)
+):
+    if not form_data.id.isidentifier():
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Only alphanumeric characters and underscores are allowed in the id",
+        )
+
+    form_data.id = form_data.id.lower()
+
+    toolkit = Tools.get_tool_by_id(form_data.id)
+    if toolkit == None:
+        toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
+        try:
+            with open(toolkit_path, "w") as tool_file:
+                tool_file.write(form_data.content)
+
+            toolkit_module = load_toolkit_module_by_id(form_data.id)
+
+            TOOLS = request.app.state.TOOLS
+            TOOLS[form_data.id] = toolkit_module
+
+            specs = get_tools_specs(TOOLS[form_data.id])
+            toolkit = Tools.insert_new_tool(user.id, form_data, specs)
+
+            if toolkit:
+                return toolkit
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.FILE_EXISTS,
+                )
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ID_TAKEN,
+        )
+
+
+############################
+# GetToolkitById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[ToolModel])
+async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
+    toolkit = Tools.get_tool_by_id(id)
+
+    if toolkit:
+        return toolkit
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateToolkitById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[ToolModel])
+async def update_toolkit_by_id(
+    request: Request, id: str, form_data: ToolForm, user=Depends(get_admin_user)
+):
+    toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
+
+    try:
+        with open(toolkit_path, "w") as tool_file:
+            tool_file.write(form_data.content)
+
+        toolkit_module = load_toolkit_module_by_id(id)
+
+        TOOLS = request.app.state.TOOLS
+        TOOLS[id] = toolkit_module
+
+        specs = get_tools_specs(TOOLS[id])
+
+        updated = {
+            **form_data.model_dump(exclude={"id"}),
+            "specs": specs,
+        }
+
+        print(updated)
+        toolkit = Tools.update_tool_by_id(id, updated)
+
+        if toolkit:
+            return toolkit
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
+            )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# DeleteToolkitById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_toolkit_by_id(request: Request, id: str, user=Depends(get_admin_user)):
+    result = Tools.delete_tool_by_id(id)
+
+    if result:
+        TOOLS = request.app.state.TOOLS
+        del TOOLS[id]
+
+    return result

+ 23 - 0
backend/apps/webui/utils.py

@@ -0,0 +1,23 @@
+from importlib import util
+import os
+
+from config import TOOLS_DIR
+
+
+def load_toolkit_module_by_id(toolkit_id):
+    toolkit_path = os.path.join(TOOLS_DIR, f"{toolkit_id}.py")
+    spec = util.spec_from_file_location(toolkit_id, toolkit_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Tools"):
+            return module.Tools()
+        else:
+            raise Exception("No Tools class found")
+    except Exception as e:
+        print(f"Error loading module: {toolkit_id}")
+        # Move the file to the error folder
+        os.rename(toolkit_path, f"{toolkit_path}.error")
+        raise e

+ 19 - 1
backend/config.py

@@ -368,6 +368,14 @@ DOCS_DIR = os.getenv("DOCS_DIR", f"{DATA_DIR}/docs")
 Path(DOCS_DIR).mkdir(parents=True, exist_ok=True)
 
 
+####################################
+# Tools DIR
+####################################
+
+TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
+Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # LITELLM_CONFIG
 ####################################
@@ -669,7 +677,6 @@ Question:
     ),
 )
 
-
 SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
     "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD",
     "task.search.prompt_length_threshold",
@@ -679,6 +686,17 @@ SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = PersistentConfig(
     ),
 )
 
+TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = PersistentConfig(
+    "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
+    "task.tools.prompt_template",
+    os.environ.get(
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE",
+        """Tools: {{TOOLS}}
+If a function tool doesn't match the query, return an empty string. Else, pick a function tool, fill in the parameters from the function tool's schema, and return it in the format { "name": \"functionName\", "parameters": { "key": "value" } }. Only pick a function if the user asks.  Only return the object. Do not return any other text.""",
+    ),
+)
+
+
 ####################################
 # WEBUI_SECRET_KEY
 ####################################

+ 1 - 0
backend/constants.py

@@ -32,6 +32,7 @@ class ERROR_MESSAGES(str, Enum):
     COMMAND_TAKEN = "Uh-oh! This command is already registered. Please choose another command string."
     FILE_EXISTS = "Uh-oh! This file is already registered. Please choose another file."
 
+    ID_TAKEN = "Uh-oh! This id is already registered. Please choose another id string."
     MODEL_ID_TAKEN = "Uh-oh! This model id is already registered. Please choose another model id string."
 
     NAME_TAG_TAKEN = "Uh-oh! This name tag is already registered. Please choose another name tag string."

+ 174 - 15
backend/main.py

@@ -47,15 +47,24 @@ from pydantic import BaseModel
 from typing import List, Optional
 
 from apps.webui.models.models import Models, ModelModel
+from apps.webui.models.tools import Tools
+from apps.webui.utils import load_toolkit_module_by_id
+
+
 from utils.utils import (
     get_admin_user,
     get_verified_user,
     get_current_user,
     get_http_authorization_cred,
 )
-from utils.task import title_generation_template, search_query_generation_template
+from utils.task import (
+    title_generation_template,
+    search_query_generation_template,
+    tools_function_calling_generation_template,
+)
+from utils.misc import get_last_user_message, add_or_update_system_message
 
-from apps.rag.utils import rag_messages
+from apps.rag.utils import rag_messages, rag_template
 
 from config import (
     CONFIG_DATA,
@@ -82,6 +91,7 @@ from config import (
     TITLE_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     AppConfig,
 )
 from constants import ERROR_MESSAGES
@@ -148,24 +158,80 @@ app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE = (
 app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
 )
+app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+)
 
 app.state.MODELS = {}
 
 origins = ["*"]
 
-# Custom middleware to add security headers
-# class SecurityHeadersMiddleware(BaseHTTPMiddleware):
-#     async def dispatch(self, request: Request, call_next):
-#         response: Response = await call_next(request)
-#         response.headers["Cross-Origin-Opener-Policy"] = "same-origin"
-#         response.headers["Cross-Origin-Embedder-Policy"] = "require-corp"
-#         return response
 
+async def get_function_call_response(prompt, tool_id, template, task_model_id, user):
+    tool = Tools.get_tool_by_id(tool_id)
+    tools_specs = json.dumps(tool.specs, indent=2)
+    content = tools_function_calling_generation_template(template, tools_specs)
+
+    payload = {
+        "model": task_model_id,
+        "messages": [
+            {"role": "system", "content": content},
+            {"role": "user", "content": f"Query: {prompt}"},
+        ],
+        "stream": False,
+    }
+
+    payload = filter_pipeline(payload, user)
+    model = app.state.MODELS[task_model_id]
+
+    response = None
+    try:
+        if model["owned_by"] == "ollama":
+            response = await generate_ollama_chat_completion(
+                OpenAIChatCompletionForm(**payload), user=user
+            )
+        else:
+            response = await generate_openai_chat_completion(payload, user=user)
+
+        content = None
+        async for chunk in response.body_iterator:
+            data = json.loads(chunk.decode("utf-8"))
+            content = data["choices"][0]["message"]["content"]
+
+        # Cleanup any remaining background tasks if necessary
+        if response.background is not None:
+            await response.background()
+
+        # Parse the function response
+        if content is not None:
+            result = json.loads(content)
+            print(result)
+
+            # Call the function
+            if "name" in result:
+                if tool_id in webui_app.state.TOOLS:
+                    toolkit_module = webui_app.state.TOOLS[tool_id]
+                else:
+                    toolkit_module = load_toolkit_module_by_id(tool_id)
+                    webui_app.state.TOOLS[tool_id] = toolkit_module
+
+                function = getattr(toolkit_module, result["name"])
+                function_result = None
+                try:
+                    function_result = function(**result["parameters"])
+                except Exception as e:
+                    print(e)
+
+                # Add the function result to the system prompt
+                if function_result:
+                    return function_result
+    except Exception as e:
+        print(f"Error: {e}")
 
-# app.add_middleware(SecurityHeadersMiddleware)
+    return None
 
 
-class RAGMiddleware(BaseHTTPMiddleware):
+class ChatCompletionMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
         return_citations = False
 
@@ -182,12 +248,68 @@ class RAGMiddleware(BaseHTTPMiddleware):
             # Parse string to JSON
             data = json.loads(body_str) if body_str else {}
 
+            # Remove the citations from the body
             return_citations = data.get("citations", False)
             if "citations" in data:
                 del data["citations"]
 
-            # Example: Add a new key-value pair or modify existing ones
-            # data["modified"] = True  # Example modification
+            # Set the task model
+            task_model_id = data["model"]
+            if task_model_id not in app.state.MODELS:
+                raise HTTPException(
+                    status_code=status.HTTP_404_NOT_FOUND,
+                    detail="Model not found",
+                )
+
+            # Check if the user has a custom task model
+            # If the user has a custom task model, use that model
+            if app.state.MODELS[task_model_id]["owned_by"] == "ollama":
+                if (
+                    app.state.config.TASK_MODEL
+                    and app.state.config.TASK_MODEL in app.state.MODELS
+                ):
+                    task_model_id = app.state.config.TASK_MODEL
+            else:
+                if (
+                    app.state.config.TASK_MODEL_EXTERNAL
+                    and app.state.config.TASK_MODEL_EXTERNAL in app.state.MODELS
+                ):
+                    task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+
+            if "tool_ids" in data:
+                user = get_current_user(
+                    get_http_authorization_cred(request.headers.get("Authorization"))
+                )
+                prompt = get_last_user_message(data["messages"])
+                context = ""
+
+                for tool_id in data["tool_ids"]:
+                    print(tool_id)
+                    response = await get_function_call_response(
+                        prompt=prompt,
+                        tool_id=tool_id,
+                        template=app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
+                        task_model_id=task_model_id,
+                        user=user,
+                    )
+
+                    if response:
+                        context += ("\n" if context != "" else "") + response
+
+                if context != "":
+                    system_prompt = rag_template(
+                        rag_app.state.config.RAG_TEMPLATE, context, prompt
+                    )
+
+                    print(system_prompt)
+
+                    data["messages"] = add_or_update_system_message(
+                        f"\n{system_prompt}", data["messages"]
+                    )
+
+                del data["tool_ids"]
+
+            # If docs field is present, generate RAG completions
             if "docs" in data:
                 data = {**data}
                 data["messages"], citations = rag_messages(
@@ -210,7 +332,6 @@ class RAGMiddleware(BaseHTTPMiddleware):
 
             # Replace the request body with the modified one
             request._body = modified_body_bytes
-
             # Set custom header to ensure content-length matches new body length
             request.headers.__dict__["_list"] = [
                 (b"content-length", str(len(modified_body_bytes)).encode("utf-8")),
@@ -253,7 +374,7 @@ class RAGMiddleware(BaseHTTPMiddleware):
             yield data
 
 
-app.add_middleware(RAGMiddleware)
+app.add_middleware(ChatCompletionMiddleware)
 
 
 def filter_pipeline(payload, user):
@@ -515,6 +636,7 @@ async def get_task_config(user=Depends(get_verified_user)):
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
 
@@ -524,6 +646,7 @@ class TaskConfigForm(BaseModel):
     TITLE_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE: str
     SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD: int
+    TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE: str
 
 
 @app.post("/api/task/config/update")
@@ -539,6 +662,9 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
     app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD = (
         form_data.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD
     )
+    app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE = (
+        form_data.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+    )
 
     return {
         "TASK_MODEL": app.state.config.TASK_MODEL,
@@ -546,6 +672,7 @@ async def update_task_config(form_data: TaskConfigForm, user=Depends(get_admin_u
         "TITLE_GENERATION_PROMPT_TEMPLATE": app.state.config.TITLE_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE": app.state.config.SEARCH_QUERY_GENERATION_PROMPT_TEMPLATE,
         "SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD": app.state.config.SEARCH_QUERY_PROMPT_LENGTH_THRESHOLD,
+        "TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE": app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE,
     }
 
 
@@ -659,6 +786,38 @@ async def generate_search_query(form_data: dict, user=Depends(get_verified_user)
         return await generate_openai_chat_completion(payload, user=user)
 
 
+@app.post("/api/task/tools/completions")
+async def get_tools_function_calling(form_data: dict, user=Depends(get_verified_user)):
+    print("get_tools_function_calling")
+
+    model_id = form_data["model"]
+    if model_id not in app.state.MODELS:
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail="Model not found",
+        )
+
+    # Check if the user has a custom task model
+    # If the user has a custom task model, use that model
+    if app.state.MODELS[model_id]["owned_by"] == "ollama":
+        if app.state.config.TASK_MODEL:
+            task_model_id = app.state.config.TASK_MODEL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+    else:
+        if app.state.config.TASK_MODEL_EXTERNAL:
+            task_model_id = app.state.config.TASK_MODEL_EXTERNAL
+            if task_model_id in app.state.MODELS:
+                model_id = task_model_id
+
+    print(model_id)
+    template = app.state.config.TOOLS_FUNCTION_CALLING_PROMPT_TEMPLATE
+
+    return await get_function_call_response(
+        form_data["prompt"], form_data["tool_id"], template, model_id, user
+    )
+
+
 @app.post("/api/chat/completions")
 async def generate_chat_completions(form_data: dict, user=Depends(get_verified_user)):
     model_id = form_data["model"]

+ 5 - 0
backend/utils/task.py

@@ -110,3 +110,8 @@ def search_query_generation_template(
         ),
     )
     return template
+
+
+def tools_function_calling_generation_template(template: str, tools_specs: str) -> str:
+    template = template.replace("{{TOOLS}}", tools_specs)
+    return template

+ 73 - 0
backend/utils/tools.py

@@ -0,0 +1,73 @@
+import inspect
+from typing import get_type_hints, List, Dict, Any
+
+
+def doc_to_dict(docstring):
+    lines = docstring.split("\n")
+    description = lines[1].strip()
+    param_dict = {}
+
+    for line in lines:
+        if ":param" in line:
+            line = line.replace(":param", "").strip()
+            param, desc = line.split(":", 1)
+            param_dict[param.strip()] = desc.strip()
+    ret_dict = {"description": description, "params": param_dict}
+    return ret_dict
+
+
+def get_tools_specs(tools) -> List[dict]:
+    function_list = [
+        {"name": func, "function": getattr(tools, func)}
+        for func in dir(tools)
+        if callable(getattr(tools, func)) and not func.startswith("__")
+    ]
+
+    specs = []
+    for function_item in function_list:
+        function_name = function_item["name"]
+        function = function_item["function"]
+
+        function_doc = doc_to_dict(function.__doc__ or function_name)
+        specs.append(
+            {
+                "name": function_name,
+                # TODO: multi-line desc?
+                "description": function_doc.get("description", function_name),
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        param_name: {
+                            "type": param_annotation.__name__.lower(),
+                            **(
+                                {
+                                    "enum": (
+                                        param_annotation.__args__
+                                        if hasattr(param_annotation, "__args__")
+                                        else None
+                                    )
+                                }
+                                if hasattr(param_annotation, "__args__")
+                                else {}
+                            ),
+                            "description": function_doc.get("params", {}).get(
+                                param_name, param_name
+                            ),
+                        }
+                        for param_name, param_annotation in get_type_hints(
+                            function
+                        ).items()
+                        if param_name != "return"
+                    },
+                    "required": [
+                        name
+                        for name, param in inspect.signature(
+                            function
+                        ).parameters.items()
+                        if param.default is param.empty
+                    ],
+                },
+            }
+        )
+
+    return specs

+ 193 - 0
src/lib/apis/tools/index.ts

@@ -0,0 +1,193 @@
+import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+export const createNewTool = async (token: string, tool: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/create`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...tool
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getTools = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const exportTools = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/export`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getToolById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const updateToolById = async (token: string, id: string, tool: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			...tool
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const deleteToolById = async (token: string, id: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/tools/id/${id}/delete`, {
+		method: 'DELETE',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err.detail;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 17 - 1
src/lib/components/chat/Chat.svelte

@@ -24,7 +24,8 @@
 		banners,
 		user,
 		socket,
-		showCallOverlay
+		showCallOverlay,
+		tools
 	} from '$lib/stores';
 	import {
 		convertMessagesToHistory,
@@ -73,6 +74,7 @@
 	let selectedModels = [''];
 	let atSelectedModel: Model | undefined;
 
+	let selectedToolIds = [];
 	let webSearchEnabled = false;
 
 	let chat = null;
@@ -687,6 +689,7 @@
 			},
 			format: $settings.requestFormat ?? undefined,
 			keep_alive: $settings.keepAlive ?? undefined,
+			tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 			docs: docs.length > 0 ? docs : undefined,
 			citations: docs.length > 0,
 			chat_id: $chatId
@@ -948,6 +951,7 @@
 					top_p: $settings?.params?.top_p ?? undefined,
 					frequency_penalty: $settings?.params?.frequency_penalty ?? undefined,
 					max_tokens: $settings?.params?.max_tokens ?? undefined,
+					tool_ids: selectedToolIds.length > 0 ? selectedToolIds : undefined,
 					docs: docs.length > 0 ? docs : undefined,
 					citations: docs.length > 0,
 					chat_id: $chatId
@@ -1274,8 +1278,20 @@
 				bind:files
 				bind:prompt
 				bind:autoScroll
+				bind:selectedToolIds
 				bind:webSearchEnabled
 				bind:atSelectedModel
+				availableTools={$user.role === 'admin'
+					? $tools.reduce((a, e, i, arr) => {
+							a[e.id] = {
+								name: e.name,
+								description: e.meta.description,
+								enabled: false
+							};
+
+							return a;
+					  }, {})
+					: {}}
 				{selectedModels}
 				{messages}
 				{submitPrompt}

+ 6 - 1
src/lib/components/chat/MessageInput.svelte

@@ -8,7 +8,8 @@
 		showSidebar,
 		models,
 		config,
-		showCallOverlay
+		showCallOverlay,
+		tools
 	} from '$lib/stores';
 	import { blobToFile, calculateSHA256, findWordIndices } from '$lib/utils';
 
@@ -58,6 +59,8 @@
 
 	export let files = [];
 
+	export let availableTools = {};
+	export let selectedToolIds = [];
 	export let webSearchEnabled = false;
 
 	export let prompt = '';
@@ -653,6 +656,8 @@
 								<div class=" ml-0.5 self-end mb-1.5 flex space-x-1">
 									<InputMenu
 										bind:webSearchEnabled
+										bind:selectedToolIds
+										tools={availableTools}
 										uploadFilesHandler={() => {
 											filesInputElement.click();
 										}}

+ 30 - 5
src/lib/components/chat/MessageInput/InputMenu.svelte

@@ -4,22 +4,21 @@
 	import { getContext } from 'svelte';
 
 	import Dropdown from '$lib/components/common/Dropdown.svelte';
-	import GarbageBin from '$lib/components/icons/GarbageBin.svelte';
-	import Pencil from '$lib/components/icons/Pencil.svelte';
 	import Tooltip from '$lib/components/common/Tooltip.svelte';
-	import Tags from '$lib/components/chat/Tags.svelte';
-	import Share from '$lib/components/icons/Share.svelte';
-	import ArchiveBox from '$lib/components/icons/ArchiveBox.svelte';
 	import DocumentArrowUpSolid from '$lib/components/icons/DocumentArrowUpSolid.svelte';
 	import Switch from '$lib/components/common/Switch.svelte';
 	import GlobeAltSolid from '$lib/components/icons/GlobeAltSolid.svelte';
 	import { config } from '$lib/stores';
+	import WrenchSolid from '$lib/components/icons/WrenchSolid.svelte';
 
 	const i18n = getContext('i18n');
 
 	export let uploadFilesHandler: Function;
+
+	export let selectedToolIds: string[] = [];
 	export let webSearchEnabled: boolean;
 
+	export let tools = {};
 	export let onClose: Function;
 
 	let show = false;
@@ -46,6 +45,32 @@
 			align="start"
 			transition={flyAndScale}
 		>
+			{#if Object.keys(tools).length > 0}
+				{#each Object.keys(tools) as toolId}
+					<div
+						class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"
+					>
+						<div class="flex-1 flex items-center gap-2">
+							<WrenchSolid />
+
+							<Tooltip content={tools[toolId]?.description ?? ''}>
+								<div class="flex items-center line-clamp-1">{tools[toolId].name}</div>
+							</Tooltip>
+						</div>
+
+						<Switch
+							bind:state={tools[toolId].enabled}
+							on:change={(e) => {
+								selectedToolIds = e.detail
+									? [...selectedToolIds, toolId]
+									: selectedToolIds.filter((id) => id !== toolId);
+							}}
+						/>
+					</div>
+				{/each}
+				<hr class="border-gray-100 dark:border-gray-800 my-1" />
+			{/if}
+
 			{#if $config?.features?.enable_web_search}
 				<div
 					class="flex gap-2 items-center px-3 py-2 text-sm font-medium cursor-pointer rounded-xl"

+ 13 - 13
src/lib/components/common/CodeEditor.svelte

@@ -60,7 +60,10 @@
 	];
 
 	onMount(() => {
-		value = boilerplate;
+		console.log(value);
+		if (value === '') {
+			value = boilerplate;
+		}
 
 		// Check if html class has dark mode
 		isDarkMode = document.documentElement.classList.contains('dark');
@@ -107,27 +110,24 @@
 			attributeFilter: ['class']
 		});
 
-		// Add a keyboard shortcut to format the code when Ctrl/Cmd + S is pressed
-		// Override the default browser save functionality
-
-		const handleSave = async (e) => {
+		const keydownHandler = async (e) => {
 			if ((e.ctrlKey || e.metaKey) && e.key === 's') {
 				e.preventDefault();
-				const res = await formatPythonCodeHandler().catch((error) => {
-					return null;
-				});
+				dispatch('save');
+			}
 
-				if (res) {
-					dispatch('save');
-				}
+			// Format code when Ctrl + Shift + F is pressed
+			if ((e.ctrlKey || e.metaKey) && e.shiftKey && e.key === 'f') {
+				e.preventDefault();
+				await formatPythonCodeHandler();
 			}
 		};
 
-		document.addEventListener('keydown', handleSave);
+		document.addEventListener('keydown', keydownHandler);
 
 		return () => {
 			observer.disconnect();
-			document.removeEventListener('keydown', handleSave);
+			document.removeEventListener('keydown', keydownHandler);
 		};
 	});
 </script>

+ 11 - 0
src/lib/components/icons/WrenchSolid.svelte

@@ -0,0 +1,11 @@
+<script lang="ts">
+	export let className = 'size-4';
+</script>
+
+<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 24 24" fill="currentColor" class={className}>
+	<path
+		fill-rule="evenodd"
+		d="M12 6.75a5.25 5.25 0 0 1 6.775-5.025.75.75 0 0 1 .313 1.248l-3.32 3.319c.063.475.276.934.641 1.299.365.365.824.578 1.3.64l3.318-3.319a.75.75 0 0 1 1.248.313 5.25 5.25 0 0 1-5.472 6.756c-1.018-.086-1.87.1-2.309.634L7.344 21.3A3.298 3.298 0 1 1 2.7 16.657l8.684-7.151c.533-.44.72-1.291.634-2.309A5.342 5.342 0 0 1 12 6.75ZM4.117 19.125a.75.75 0 0 1 .75-.75h.008a.75.75 0 0 1 .75.75v.008a.75.75 0 0 1-.75.75h-.008a.75.75 0 0 1-.75-.75v-.008Z"
+		clip-rule="evenodd"
+	/>
+</svg>

+ 225 - 1
src/lib/components/workspace/Tools.svelte

@@ -4,12 +4,23 @@
 	const { saveAs } = fileSaver;
 
 	import { onMount, getContext } from 'svelte';
-	import { WEBUI_NAME, prompts } from '$lib/stores';
+	import { WEBUI_NAME, prompts, tools } from '$lib/stores';
 	import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
 
 	import { goto } from '$app/navigation';
+	import {
+		createNewTool,
+		deleteToolById,
+		exportTools,
+		getToolById,
+		getTools
+	} from '$lib/apis/tools';
 
 	const i18n = getContext('i18n');
+
+	let toolsImportInputElement: HTMLInputElement;
+	let importFiles;
+
 	let query = '';
 </script>
 
@@ -65,3 +76,216 @@
 	</div>
 </div>
 <hr class=" dark:border-gray-850 my-2.5" />
+
+<div class="my-3 mb-5">
+	{#each $tools.filter((t) => query === '' || t.name
+				.toLowerCase()
+				.includes(query.toLowerCase()) || t.id.toLowerCase().includes(query.toLowerCase())) as tool}
+		<button
+			class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
+			type="button"
+			on:click={() => {
+				goto(`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`);
+			}}
+		>
+			<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
+				<a
+					href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
+					class="flex items-center text-left"
+				>
+					<div class=" flex-1 self-center pl-5">
+						<div class=" font-semibold flex items-center gap-1.5">
+							<div>
+								{tool.name}
+							</div>
+							<div class=" text-gray-500 text-xs font-medium">{tool.id}</div>
+						</div>
+						<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
+							{tool.meta.description}
+						</div>
+					</div>
+				</a>
+			</div>
+			<div class="flex flex-row space-x-1 self-center">
+				<a
+					class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
+					type="button"
+					href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						fill="none"
+						viewBox="0 0 24 24"
+						stroke-width="1.5"
+						stroke="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							stroke-linecap="round"
+							stroke-linejoin="round"
+							d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L6.832 19.82a4.5 4.5 0 01-1.897 1.13l-2.685.8.8-2.685a4.5 4.5 0 011.13-1.897L16.863 4.487zm0 0L19.5 7.125"
+						/>
+					</svg>
+				</a>
+
+				<button
+					class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
+					type="button"
+					on:click={async () => {
+						const _tool = await getToolById(localStorage.token, tool.id).catch((error) => {
+							toast.error(error);
+							return null;
+						});
+
+						if (_tool) {
+							sessionStorage.tool = JSON.stringify({
+								..._tool,
+								id: `${_tool.id}_clone`,
+								name: `${_tool.name} (Clone)`
+							});
+							goto('/workspace/tools/create');
+						}
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						fill="none"
+						viewBox="0 0 24 24"
+						stroke-width="1.5"
+						stroke="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							stroke-linecap="round"
+							stroke-linejoin="round"
+							d="M15.75 17.25v3.375c0 .621-.504 1.125-1.125 1.125h-9.75a1.125 1.125 0 0 1-1.125-1.125V7.875c0-.621.504-1.125 1.125-1.125H6.75a9.06 9.06 0 0 1 1.5.124m7.5 10.376h3.375c.621 0 1.125-.504 1.125-1.125V11.25c0-4.46-3.243-8.161-7.5-8.876a9.06 9.06 0 0 0-1.5-.124H9.375c-.621 0-1.125.504-1.125 1.125v3.5m7.5 10.375H9.375a1.125 1.125 0 0 1-1.125-1.125v-9.25m12 6.625v-1.875a3.375 3.375 0 0 0-3.375-3.375h-1.5a1.125 1.125 0 0 1-1.125-1.125v-1.5a3.375 3.375 0 0 0-3.375-3.375H9.75"
+						/>
+					</svg>
+				</button>
+
+				<button
+					class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
+					type="button"
+					on:click={async () => {
+						const res = await deleteToolById(localStorage.token, tool.id).catch((error) => {
+							toast.error(error);
+							return null;
+						});
+
+						if (res) {
+							toast.success('Tool deleted successfully');
+							tools.set(await getTools(localStorage.token));
+						}
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						fill="none"
+						viewBox="0 0 24 24"
+						stroke-width="1.5"
+						stroke="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							stroke-linecap="round"
+							stroke-linejoin="round"
+							d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0"
+						/>
+					</svg>
+				</button>
+			</div>
+		</button>
+	{/each}
+</div>
+
+<div class=" flex justify-end w-full mb-2">
+	<div class="flex space-x-2">
+		<input
+			id="documents-import-input"
+			bind:this={toolsImportInputElement}
+			bind:files={importFiles}
+			type="file"
+			accept=".json"
+			hidden
+			on:change={() => {
+				console.log(importFiles);
+
+				const reader = new FileReader();
+				reader.onload = async (event) => {
+					const _tools = JSON.parse(event.target.result);
+					console.log(_tools);
+
+					for (const tool of _tools) {
+						const res = await createNewTool(localStorage.token, tool).catch((error) => {
+							toast.error(error);
+							return null;
+						});
+					}
+
+					toast.success('Tool imported successfully');
+					tools.set(await getTools(localStorage.token));
+				};
+
+				reader.readAsText(importFiles[0]);
+			}}
+		/>
+
+		<button
+			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+			on:click={() => {
+				toolsImportInputElement.click();
+			}}
+		>
+			<div class=" self-center mr-2 font-medium">{$i18n.t('Import Tools')}</div>
+
+			<div class=" self-center">
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 16 16"
+					fill="currentColor"
+					class="w-4 h-4"
+				>
+					<path
+						fill-rule="evenodd"
+						d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 9.5a.75.75 0 0 1-.75-.75V8.06l-.72.72a.75.75 0 0 1-1.06-1.06l2-2a.75.75 0 0 1 1.06 0l2 2a.75.75 0 1 1-1.06 1.06l-.72-.72v2.69a.75.75 0 0 1-.75.75Z"
+						clip-rule="evenodd"
+					/>
+				</svg>
+			</div>
+		</button>
+
+		<button
+			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+			on:click={async () => {
+				const _tools = await exportTools(localStorage.token).catch((error) => {
+					toast.error(error);
+					return null;
+				});
+
+				if (_tools) {
+					let blob = new Blob([JSON.stringify(_tools)], {
+						type: 'application/json'
+					});
+					saveAs(blob, `tools-export-${Date.now()}.json`);
+				}
+			}}
+		>
+			<div class=" self-center mr-2 font-medium">{$i18n.t('Export Tools')}</div>
+
+			<div class=" self-center">
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 16 16"
+					fill="currentColor"
+					class="w-4 h-4"
+				>
+					<path
+						fill-rule="evenodd"
+						d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 3.5a.75.75 0 0 1 .75.75v2.69l.72-.72a.75.75 0 1 1 1.06 1.06l-2 2a.75.75 0 0 1-1.06 0l-2-2a.75.75 0 0 1 1.06-1.06l.72.72V6.25A.75.75 0 0 1 8 5.5Z"
+						clip-rule="evenodd"
+					/>
+				</svg>
+			</div>
+		</button>
+	</div>
+</div>

+ 55 - 6
src/lib/components/workspace/Tools/CodeEditor.svelte

@@ -1,15 +1,15 @@
 <script lang="ts">
 	import CodeEditor from '$lib/components/common/CodeEditor.svelte';
+	import { createEventDispatcher } from 'svelte';
+
+	const dispatch = createEventDispatcher();
 
-	export let saveHandler: Function;
 	export let value = '';
 
 	let codeEditor;
-
-	let boilerplate = `# Tip: Use Ctrl/Cmd + S to format the code
-
-from datetime import datetime
+	let boilerplate = `import os
 import requests
+from datetime import datetime
 
 
 class Tools:
@@ -20,6 +20,20 @@ class Tools:
     # Use Sphinx-style docstrings to document your tools, they will be used for generating tools specifications
     # Please refer to function_calling_filter_pipeline.py file from pipelines project for an example
 
+    def get_environment_variable(self, variable_name: str) -> str:
+        """
+        Get the value of an environment variable.
+        :param variable_name: The name of the environment variable.
+        :return: The value of the environment variable or a message if it doesn't exist.
+        """
+        value = os.getenv(variable_name)
+        if value is not None:
+            return (
+                f"The value of the environment variable '{variable_name}' is '{value}'"
+            )
+        else:
+            return f"The environment variable '{variable_name}' does not exist."
+
     def get_current_time(self) -> str:
         """
         Get the current time.
@@ -45,6 +59,41 @@ class Tools:
             print(e)
             return "Invalid equation"
 
+    def get_current_weather(self, city: str) -> str:
+        """
+        Get the current weather for a given city.
+        :param city: The name of the city to get the weather for.
+        :return: The current weather information or an error message.
+        """
+        api_key = os.getenv("OPENWEATHER_API_KEY")
+        if not api_key:
+            return (
+                "API key is not set in the environment variable 'OPENWEATHER_API_KEY'."
+            )
+
+        base_url = "http://api.openweathermap.org/data/2.5/weather"
+        params = {
+            "q": city,
+            "appid": api_key,
+            "units": "metric",  # Optional: Use 'imperial' for Fahrenheit
+        }
+
+        try:
+            response = requests.get(base_url, params=params)
+            response.raise_for_status()  # Raise HTTPError for bad responses (4xx and 5xx)
+            data = response.json()
+
+            if data.get("cod") != 200:
+                return f"Error fetching weather data: {data.get('message')}"
+
+            weather_description = data["weather"][0]["description"]
+            temperature = data["main"]["temp"]
+            humidity = data["main"]["humidity"]
+            wind_speed = data["wind"]["speed"]
+
+            return f"Weather in {city}: {temperature}°C"
+        except requests.RequestException as e:
+            return f"Error fetching weather data: {str(e)}"
 `;
 
 	export const formatHandler = async () => {
@@ -60,6 +109,6 @@ class Tools:
 	{boilerplate}
 	bind:this={codeEditor}
 	on:save={() => {
-		saveHandler();
+		dispatch('save');
 	}}
 />

+ 38 - 16
src/lib/components/workspace/Tools/ToolkitEditor.svelte

@@ -1,22 +1,27 @@
 <script>
-	import { getContext } from 'svelte';
+	import { getContext, createEventDispatcher, onMount } from 'svelte';
 
 	const i18n = getContext('i18n');
 
 	import CodeEditor from './CodeEditor.svelte';
 	import { goto } from '$app/navigation';
 
+	const dispatch = createEventDispatcher();
+
+	let formElement = null;
 	let loading = false;
 
-	let id = '';
-	let name = '';
-	let meta = {
+	export let edit = false;
+	export let clone = false;
+
+	export let id = '';
+	export let name = '';
+	export let meta = {
 		description: ''
 	};
+	export let content = '';
 
-	let code = '';
-
-	$: if (name) {
+	$: if (name && !edit && !clone) {
 		id = name.replace(/\s+/g, '_').toLowerCase();
 	}
 
@@ -24,8 +29,12 @@
 
 	const saveHandler = async () => {
 		loading = true;
-		// Call the API to save the toolkit
-		console.log('saveHandler');
+		dispatch('save', {
+			id,
+			name,
+			meta,
+			content
+		});
 	};
 
 	const submitHandler = async () => {
@@ -42,13 +51,20 @@
 
 <div class=" flex flex-col justify-between w-full overflow-y-auto h-full">
 	<div class="mx-auto w-full md:px-0 h-full">
-		<div class=" flex flex-col max-h-[100dvh] h-full">
+		<form
+			bind:this={formElement}
+			class=" flex flex-col max-h-[100dvh] h-full"
+			on:submit|preventDefault={() => {
+				submitHandler();
+			}}
+		>
 			<div class="mb-2.5">
 				<button
 					class="flex space-x-1"
 					on:click={() => {
 						goto('/workspace/tools');
 					}}
+					type="button"
 				>
 					<div class=" self-center">
 						<svg
@@ -80,11 +96,12 @@
 						/>
 
 						<input
-							class="w-full px-3 py-2 text-sm font-medium bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
+							class="w-full px-3 py-2 text-sm font-medium disabled:text-gray-300 dark:disabled:text-gray-700 bg-gray-50 dark:bg-gray-850 dark:text-gray-200 rounded-lg outline-none"
 							type="text"
 							placeholder="Toolkit ID (e.g. my_toolkit)"
 							bind:value={id}
 							required
+							disabled={edit}
 						/>
 					</div>
 					<input
@@ -97,20 +114,25 @@
 				</div>
 
 				<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
-					<CodeEditor bind:value={code} bind:this={codeEditor} {saveHandler} />
+					<CodeEditor
+						bind:value={content}
+						bind:this={codeEditor}
+						on:save={() => {
+							if (formElement) {
+								formElement.requestSubmit();
+							}
+						}}
+					/>
 				</div>
 
 				<div class="pb-3 flex justify-end">
 					<button
 						class="px-3 py-1.5 text-sm font-medium bg-emerald-600 hover:bg-emerald-700 text-gray-50 transition rounded-lg"
-						on:click={() => {
-							submitHandler();
-						}}
 					>
 						{$i18n.t('Save')}
 					</button>
 				</div>
 			</div>
-		</div>
+		</form>
 	</div>
 </div>

+ 10 - 16
src/lib/stores/index.ts

@@ -23,24 +23,11 @@ export const chatId = writable('');
 
 export const chats = writable([]);
 export const tags = writable([]);
-export const models: Writable<Model[]> = writable([]);
 
-export const modelfiles = writable([]);
+export const models: Writable<Model[]> = writable([]);
 export const prompts: Writable<Prompt[]> = writable([]);
-export const documents = writable([
-	{
-		collection_name: 'collection_name',
-		filename: 'filename',
-		name: 'name',
-		title: 'title'
-	},
-	{
-		collection_name: 'collection_name1',
-		filename: 'filename1',
-		name: 'name1',
-		title: 'title1'
-	}
-]);
+export const documents: Writable<Document[]> = writable([]);
+export const tools = writable([]);
 
 export const banners: Writable<Banner[]> = writable([]);
 
@@ -135,6 +122,13 @@ type Prompt = {
 	timestamp: number;
 };
 
+type Document = {
+	collection_name: string;
+	filename: string;
+	name: string;
+	title: string;
+};
+
 type Config = {
 	status: boolean;
 	name: string;

+ 11 - 17
src/routes/(app)/+layout.svelte

@@ -8,11 +8,14 @@
 	import { goto } from '$app/navigation';
 
 	import { getModels as _getModels } from '$lib/apis';
-	import { getOllamaVersion } from '$lib/apis/ollama';
-	import { getPrompts } from '$lib/apis/prompts';
+	import { getAllChatTags } from '$lib/apis/chats';
 
+	import { getPrompts } from '$lib/apis/prompts';
 	import { getDocs } from '$lib/apis/documents';
-	import { getAllChatTags } from '$lib/apis/chats';
+	import { getTools } from '$lib/apis/tools';
+
+	import { getBanners } from '$lib/apis/configs';
+	import { getUserSettings } from '$lib/apis/users';
 
 	import {
 		user,
@@ -25,33 +28,21 @@
 		banners,
 		showChangelog,
 		config,
-		showCallOverlay
+		showCallOverlay,
+		tools
 	} from '$lib/stores';
-	import { REQUIRED_OLLAMA_VERSION, WEBUI_API_BASE_URL } from '$lib/constants';
-	import { compareVersion } from '$lib/utils';
 
 	import SettingsModal from '$lib/components/chat/SettingsModal.svelte';
 	import Sidebar from '$lib/components/layout/Sidebar.svelte';
-	import ShortcutsModal from '$lib/components/chat/ShortcutsModal.svelte';
 	import ChangelogModal from '$lib/components/ChangelogModal.svelte';
-	import Tooltip from '$lib/components/common/Tooltip.svelte';
-	import { getBanners } from '$lib/apis/configs';
-	import { getUserSettings } from '$lib/apis/users';
-	import Help from '$lib/components/layout/Help.svelte';
 	import AccountPending from '$lib/components/layout/Overlay/AccountPending.svelte';
-	import { error } from '@sveltejs/kit';
-	import CallOverlay from '$lib/components/chat/MessageInput/CallOverlay.svelte';
 
 	const i18n = getContext('i18n');
 
-	let ollamaVersion = '';
 	let loaded = false;
-	let showShortcutsButtonElement: HTMLButtonElement;
 	let DB = null;
 	let localDBChats = [];
 
-	let showShortcuts = false;
-
 	const getModels = async () => {
 		return _getModels(localStorage.token);
 	};
@@ -99,6 +90,9 @@
 				(async () => {
 					documents.set(await getDocs(localStorage.token));
 				})(),
+				(async () => {
+					tools.set(await getTools(localStorage.token));
+				})(),
 				(async () => {
 					banners.set(await getBanners(localStorage.token));
 				})(),

+ 48 - 1
src/routes/(app)/workspace/tools/create/+page.svelte

@@ -1,5 +1,52 @@
 <script>
+	import { goto } from '$app/navigation';
+	import { createNewTool, getTools } from '$lib/apis/tools';
 	import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
+	import { tools } from '$lib/stores';
+	import { onMount } from 'svelte';
+	import { toast } from 'svelte-sonner';
+
+	let clone = false;
+	let tool = null;
+
+	const saveHandler = async (data) => {
+		console.log(data);
+		const res = await createNewTool(localStorage.token, {
+			id: data.id,
+			name: data.name,
+			meta: data.meta,
+			content: data.content
+		}).catch((error) => {
+			toast.error(error);
+			return null;
+		});
+
+		if (res) {
+			toast.success('Tool created successfully');
+			tools.set(await getTools(localStorage.token));
+
+			await goto('/workspace/tools');
+		}
+	};
+
+	onMount(() => {
+		console.log('mounted');
+
+		if (sessionStorage.tool) {
+			tool = JSON.parse(sessionStorage.tool);
+			sessionStorage.removeItem('tool');
+			clone = true;
+		}
+	});
 </script>
 
-<ToolkitEditor />
+<ToolkitEditor
+	id={tool?.id ?? ''}
+	name={tool?.name ?? ''}
+	meta={tool?.meta ?? { description: '' }}
+	content={tool?.content ?? ''}
+	{clone}
+	on:save={(e) => {
+		saveHandler(e.detail);
+	}}
+/>

+ 62 - 1
src/routes/(app)/workspace/tools/edit/+page.svelte

@@ -1,5 +1,66 @@
 <script>
+	import { goto } from '$app/navigation';
+	import { page } from '$app/stores';
+	import { getToolById, getTools, updateToolById } from '$lib/apis/tools';
+	import Spinner from '$lib/components/common/Spinner.svelte';
 	import ToolkitEditor from '$lib/components/workspace/Tools/ToolkitEditor.svelte';
+	import { tools } from '$lib/stores';
+	import { onMount } from 'svelte';
+	import { toast } from 'svelte-sonner';
+
+	let tool = null;
+
+	const saveHandler = async (data) => {
+		console.log(data);
+		const res = await updateToolById(localStorage.token, tool.id, {
+			id: data.id,
+			name: data.name,
+			meta: data.meta,
+			content: data.content
+		}).catch((error) => {
+			toast.error(error);
+			return null;
+		});
+
+		if (res) {
+			toast.success('Tool updated successfully');
+			tools.set(await getTools(localStorage.token));
+
+			await goto('/workspace/tools');
+		}
+	};
+
+	onMount(async () => {
+		console.log('mounted');
+		const id = $page.url.searchParams.get('id');
+
+		if (id) {
+			tool = await getToolById(localStorage.token, id).catch((error) => {
+				toast.error(error);
+				goto('/workspace/tools');
+				return null;
+			});
+
+			console.log(tool);
+		}
+	});
 </script>
 
-<ToolkitEditor />
+{#if tool}
+	<ToolkitEditor
+		edit={true}
+		id={tool.id}
+		name={tool.name}
+		meta={tool.meta}
+		content={tool.content}
+		on:save={(e) => {
+			saveHandler(e.detail);
+		}}
+	/>
+{:else}
+	<div class="flex items-center justify-center h-full">
+		<div class=" pb-16">
+			<Spinner />
+		</div>
+	</div>
+{/if}