Browse Source

feat: tools backend

Timothy J. Baek 10 months ago
parent
commit
3a96e1f109

+ 61 - 0
backend/apps/webui/internal/migrations/012_add_tools.py

@@ -0,0 +1,61 @@
+"""Peewee migrations -- 009_add_models.py.
+
+Some examples (model - class or model name)::
+
+    > Model = migrator.orm['table_name']            # Return model in current state by name
+    > Model = migrator.ModelClass                   # Return model in current state by name
+
+    > migrator.sql(sql)                             # Run custom SQL
+    > migrator.run(func, *args, **kwargs)           # Run python function with the given args
+    > migrator.create_model(Model)                  # Create a model (could be used as decorator)
+    > migrator.remove_model(model, cascade=True)    # Remove a model
+    > migrator.add_fields(model, **fields)          # Add fields to a model
+    > migrator.change_fields(model, **fields)       # Change fields
+    > migrator.remove_fields(model, *field_names, cascade=True)
+    > migrator.rename_field(model, old_field_name, new_field_name)
+    > migrator.rename_table(model, new_table_name)
+    > migrator.add_index(model, *col_names, unique=False)
+    > migrator.add_not_null(model, *field_names)
+    > migrator.add_default(model, field_name, default)
+    > migrator.add_constraint(model, name, sql)
+    > migrator.drop_index(model, *col_names)
+    > migrator.drop_not_null(model, *field_names)
+    > migrator.drop_constraints(model, *constraints)
+
+"""
+
+from contextlib import suppress
+
+import peewee as pw
+from peewee_migrate import Migrator
+
+
+with suppress(ImportError):
+    import playhouse.postgres_ext as pw_pext
+
+
+def migrate(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your migrations here."""
+
+    @migrator.create_model
+    class Tool(pw.Model):
+        id = pw.TextField(unique=True)
+        user_id = pw.TextField()
+
+        name = pw.TextField()
+        content = pw.TextField()
+        specs = pw.TextField()
+
+        meta = pw.TextField()
+
+        created_at = pw.BigIntegerField(null=False)
+        updated_at = pw.BigIntegerField(null=False)
+
+        class Meta:
+            table_name = "tool"
+
+
+def rollback(migrator: Migrator, database: pw.Database, *, fake=False):
+    """Write your rollback migrations here."""
+
+    migrator.remove_model("tool")

+ 3 - 1
backend/apps/webui/main.py

@@ -6,6 +6,7 @@ from apps.webui.routers import (
     users,
     chats,
     documents,
+    tools,
     models,
     prompts,
     configs,
@@ -26,8 +27,8 @@ from config import (
     WEBUI_AUTH_TRUSTED_EMAIL_HEADER,
     JWT_EXPIRES_IN,
     WEBUI_BANNERS,
-    AppConfig,
     ENABLE_COMMUNITY_SHARING,
+    AppConfig,
 )
 
 app = FastAPI()
@@ -70,6 +71,7 @@ app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 
 app.include_router(documents.router, prefix="/documents", tags=["documents"])
+app.include_router(tools.router, prefix="/tools", tags=["tools"])
 app.include_router(models.router, prefix="/models", tags=["models"])
 app.include_router(prompts.router, prefix="/prompts", tags=["prompts"])
 app.include_router(memories.router, prefix="/memories", tags=["memories"])

+ 131 - 0
backend/apps/webui/models/tools.py

@@ -0,0 +1,131 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+import time
+import logging
+from apps.webui.internal.db import DB, JSONField
+
+import json
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# Tools DB Schema
+####################
+
+
+class Tool(Model):
+    id = CharField(unique=True)
+    user_id = CharField()
+    name = TextField()
+    content = TextField()
+    specs = JSONField()
+    meta = JSONField()
+    updated_at = BigIntegerField()
+    created_at = BigIntegerField()
+
+    class Meta:
+        database = DB
+
+
+class ToolMeta(BaseModel):
+    description: Optional[str] = None
+
+
+class ToolModel(BaseModel):
+    id: str
+    user_id: str
+    name: str
+    content: str
+    specs: dict
+    meta: ToolMeta
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class ToolResponse(BaseModel):
+    id: str
+    user_id: str
+    name: str
+    meta: ToolMeta
+    updated_at: int  # timestamp in epoch
+    created_at: int  # timestamp in epoch
+
+
+class ToolForm(BaseModel):
+    id: str
+    name: str
+    content: str
+    meta: ToolMeta
+
+
+class ToolsTable:
+    def __init__(self, db):
+        self.db = db
+        self.db.create_tables([Tool])
+
+    def insert_new_tool(
+        self, user_id: str, form_data: ToolForm, specs: dict
+    ) -> Optional[ToolModel]:
+        tool = ToolModel(
+            **{
+                **form_data.model_dump(),
+                "specs": specs,
+                "user_id": user_id,
+                "updated_at": int(time.time()),
+                "created_at": int(time.time()),
+            }
+        )
+
+        try:
+            result = Tool.create(**tool.model_dump())
+            if result:
+                return tool
+            else:
+                return None
+        except:
+            return None
+
+    def get_tool_by_id(self, id: str) -> Optional[ToolModel]:
+        try:
+            tool = Tool.get(Tool.id == id)
+            return ToolModel(**model_to_dict(tool))
+        except:
+            return None
+
+    def get_tools(self) -> List[ToolModel]:
+        return [ToolModel(**model_to_dict(tool)) for tool in Tool.select()]
+
+    def update_tool_by_id(self, id: str, updated: dict) -> Optional[ToolModel]:
+        try:
+            query = Tool.update(
+                **updated,
+                updated_at=int(time.time()),
+            ).where(Tool.id == id)
+            query.execute()
+
+            tool = Tool.get(Tool.id == id)
+            return ToolModel(**model_to_dict(tool))
+        except:
+            return None
+
+    def delete_tool_by_id(self, id: str) -> bool:
+        try:
+            query = Tool.delete().where((Tool.id == id))
+            query.execute()  # Remove the rows, return number of rows removed.
+
+            return True
+        except:
+            return False
+
+
+Tools = ToolsTable(DB)

+ 162 - 0
backend/apps/webui/routers/tools.py

@@ -0,0 +1,162 @@
+from fastapi import Depends, FastAPI, HTTPException, status
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.tools import Tools, ToolForm, ToolModel, ToolResponse
+
+from utils.utils import get_current_user, get_admin_user
+from utils.tools import get_tools_specs
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+
+from config import DATA_DIR
+
+TOOLS_DIR = f"{DATA_DIR}/tools"
+os.makedirs(TOOLS_DIR, exist_ok=True)
+
+TOOLS = {}
+
+
+router = APIRouter()
+
+
+def load_toolkit_module_from_path(tools_id, tools_path):
+    spec = util.spec_from_file_location(tools_id, tools_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Tools"):
+            return module.Tools()
+        else:
+            raise Exception("No Tools class found")
+    except Exception as e:
+        print(f"Error loading module: {tools_id}")
+
+        # Move the file to the error folder
+        os.rename(tools_path, f"{tools_path}.error")
+        raise e
+
+
+############################
+# GetToolkits
+############################
+
+
+@router.get("/", response_model=List[ToolResponse])
+async def get_toolkits(user=Depends(get_current_user)):
+    toolkits = [ToolResponse(**toolkit) for toolkit in Tools.get_tools()]
+    return toolkits
+
+
+############################
+# CreateNewToolKit
+############################
+
+
+@router.post("/create", response_model=Optional[ToolResponse])
+async def create_new_toolkit(form_data: ToolForm, user=Depends(get_admin_user)):
+    toolkit = Tools.get_tool_by_id(form_data.id)
+    if toolkit == None:
+        toolkit_path = os.path.join(TOOLS_DIR, f"{form_data.id}.py")
+        try:
+            with open(toolkit_path, "w") as tool_file:
+                tool_file.write(form_data.content)
+
+            toolkit_module = load_toolkit_module_from_path(form_data.id, toolkit_path)
+            TOOLS[form_data.id] = toolkit_module
+
+            specs = get_tools_specs(TOOLS[form_data.id])
+            toolkit = Tools.insert_new_tool(user.id, form_data, specs)
+
+            if toolkit:
+                return ToolResponse(**toolkit)
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.FILE_EXISTS,
+                )
+        except Exception as e:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NAME_TAG_TAKEN,
+        )
+
+
+############################
+# GetToolkitById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[ToolResponse])
+async def get_toolkit_by_id(id: str, user=Depends(get_admin_user)):
+    toolkit = Tools.get_tool_by_id(id)
+
+    if toolkit:
+        return ToolResponse(**toolkit)
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateToolkitById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[ToolResponse])
+async def update_toolkit_by_id(
+    id: str, form_data: ToolForm, user=Depends(get_admin_user)
+):
+    toolkit_path = os.path.join(TOOLS_DIR, f"{id}.py")
+
+    try:
+        with open(toolkit_path, "w") as tool_file:
+            tool_file.write(form_data.content)
+
+        toolkit_module = load_toolkit_module_from_path(id, toolkit_path)
+        TOOLS[id] = toolkit_module
+
+        specs = get_tools_specs(TOOLS[id])
+        toolkit = Tools.update_tool_by_id(
+            id, {**form_data.model_dump(), "specs": specs}
+        )
+
+        if toolkit:
+            return ToolResponse(**toolkit)
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating toolkit"),
+            )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# DeleteToolkitById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_toolkit_by_id(id: str, user=Depends(get_admin_user)):
+    result = Tools.delete_tool_by_id(id)
+    return result

+ 73 - 0
backend/utils/tools.py

@@ -0,0 +1,73 @@
+import inspect
+from typing import get_type_hints, List, Dict, Any
+
+
+def doc_to_dict(docstring):
+    lines = docstring.split("\n")
+    description = lines[1].strip()
+    param_dict = {}
+
+    for line in lines:
+        if ":param" in line:
+            line = line.replace(":param", "").strip()
+            param, desc = line.split(":", 1)
+            param_dict[param.strip()] = desc.strip()
+    ret_dict = {"description": description, "params": param_dict}
+    return ret_dict
+
+
+def get_tools_specs(tools) -> List[dict]:
+    function_list = [
+        {"name": func, "function": getattr(tools, func)}
+        for func in dir(tools)
+        if callable(getattr(tools, func)) and not func.startswith("__")
+    ]
+
+    specs = []
+    for function_item in function_list:
+        function_name = function_item["name"]
+        function = function_item["function"]
+
+        function_doc = doc_to_dict(function.__doc__ or function_name)
+        specs.append(
+            {
+                "name": function_name,
+                # TODO: multi-line desc?
+                "description": function_doc.get("description", function_name),
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        param_name: {
+                            "type": param_annotation.__name__.lower(),
+                            **(
+                                {
+                                    "enum": (
+                                        param_annotation.__args__
+                                        if hasattr(param_annotation, "__args__")
+                                        else None
+                                    )
+                                }
+                                if hasattr(param_annotation, "__args__")
+                                else {}
+                            ),
+                            "description": function_doc.get("params", {}).get(
+                                param_name, param_name
+                            ),
+                        }
+                        for param_name, param_annotation in get_type_hints(
+                            function
+                        ).items()
+                        if param_name != "return"
+                    },
+                    "required": [
+                        name
+                        for name, param in inspect.signature(
+                            function
+                        ).parameters.items()
+                        if param.default is param.empty
+                    ],
+                },
+            }
+        )
+
+    return specs

+ 169 - 1
src/lib/components/workspace/Tools.svelte

@@ -4,12 +4,16 @@
 	const { saveAs } = fileSaver;
 
 	import { onMount, getContext } from 'svelte';
-	import { WEBUI_NAME, prompts } from '$lib/stores';
+	import { WEBUI_NAME, prompts, tools } from '$lib/stores';
 	import { createNewPrompt, deletePromptByCommand, getPrompts } from '$lib/apis/prompts';
 
 	import { goto } from '$app/navigation';
 
 	const i18n = getContext('i18n');
+
+	let toolsImportInputElement: HTMLInputElement;
+	let importFiles;
+
 	let query = '';
 </script>
 
@@ -65,3 +69,167 @@
 	</div>
 </div>
 <hr class=" dark:border-gray-850 my-2.5" />
+
+<div class="my-3 mb-5">
+	{#each $tools.filter((t) => query === '' || t.name.includes(query)) as tool}
+		<div
+			class=" flex space-x-4 cursor-pointer w-full px-3 py-2 dark:hover:bg-white/5 hover:bg-black/5 rounded-xl"
+		>
+			<div class=" flex flex-1 space-x-4 cursor-pointer w-full">
+				<a href={`/workspace/tools/edit?id=${encodeURIComponent(tool.id)}`}>
+					<div class=" flex-1 self-center pl-5">
+						<div class=" font-bold">{tool.name}</div>
+						<div class=" text-xs overflow-hidden text-ellipsis line-clamp-1">
+							{tool.meta.description}
+						</div>
+					</div>
+				</a>
+			</div>
+			<div class="flex flex-row space-x-1 self-center">
+				<a
+					class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
+					type="button"
+					href={`/workspace/tools/edit?command=${encodeURIComponent(tool.id)}`}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						fill="none"
+						viewBox="0 0 24 24"
+						stroke-width="1.5"
+						stroke="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							stroke-linecap="round"
+							stroke-linejoin="round"
+							d="M16.862 4.487l1.687-1.688a1.875 1.875 0 112.652 2.652L6.832 19.82a4.5 4.5 0 01-1.897 1.13l-2.685.8.8-2.685a4.5 4.5 0 011.13-1.897L16.863 4.487zm0 0L19.5 7.125"
+						/>
+					</svg>
+				</a>
+
+				<button
+					class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
+					type="button"
+					on:click={() => {
+						sessionStorage.tool = JSON.stringify(tool);
+						goto('/workspace/tools/create');
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						fill="none"
+						viewBox="0 0 24 24"
+						stroke-width="1.5"
+						stroke="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							stroke-linecap="round"
+							stroke-linejoin="round"
+							d="M15.75 17.25v3.375c0 .621-.504 1.125-1.125 1.125h-9.75a1.125 1.125 0 0 1-1.125-1.125V7.875c0-.621.504-1.125 1.125-1.125H6.75a9.06 9.06 0 0 1 1.5.124m7.5 10.376h3.375c.621 0 1.125-.504 1.125-1.125V11.25c0-4.46-3.243-8.161-7.5-8.876a9.06 9.06 0 0 0-1.5-.124H9.375c-.621 0-1.125.504-1.125 1.125v3.5m7.5 10.375H9.375a1.125 1.125 0 0 1-1.125-1.125v-9.25m12 6.625v-1.875a3.375 3.375 0 0 0-3.375-3.375h-1.5a1.125 1.125 0 0 1-1.125-1.125v-1.5a3.375 3.375 0 0 0-3.375-3.375H9.75"
+						/>
+					</svg>
+				</button>
+
+				<button
+					class="self-center w-fit text-sm px-2 py-2 dark:text-gray-300 dark:hover:text-white hover:bg-black/5 dark:hover:bg-white/5 rounded-xl"
+					type="button"
+					on:click={() => {
+						// deletePrompt(prompt.command);
+						// deleteTool
+					}}
+				>
+					<svg
+						xmlns="http://www.w3.org/2000/svg"
+						fill="none"
+						viewBox="0 0 24 24"
+						stroke-width="1.5"
+						stroke="currentColor"
+						class="w-4 h-4"
+					>
+						<path
+							stroke-linecap="round"
+							stroke-linejoin="round"
+							d="M14.74 9l-.346 9m-4.788 0L9.26 9m9.968-3.21c.342.052.682.107 1.022.166m-1.022-.165L18.16 19.673a2.25 2.25 0 01-2.244 2.077H8.084a2.25 2.25 0 01-2.244-2.077L4.772 5.79m14.456 0a48.108 48.108 0 00-3.478-.397m-12 .562c.34-.059.68-.114 1.022-.165m0 0a48.11 48.11 0 013.478-.397m7.5 0v-.916c0-1.18-.91-2.164-2.09-2.201a51.964 51.964 0 00-3.32 0c-1.18.037-2.09 1.022-2.09 2.201v.916m7.5 0a48.667 48.667 0 00-7.5 0"
+						/>
+					</svg>
+				</button>
+			</div>
+		</div>
+	{/each}
+</div>
+
+<div class=" flex justify-end w-full mb-2">
+	<div class="flex space-x-2">
+		<input
+			id="documents-import-input"
+			bind:this={toolsImportInputElement}
+			bind:files={importFiles}
+			type="file"
+			accept=".json"
+			hidden
+			on:change={() => {
+				console.log(importFiles);
+
+				const reader = new FileReader();
+				reader.onload = async (event) => {
+					const tools = JSON.parse(event.target.result);
+					console.log(tools);
+				};
+
+				reader.readAsText(importFiles[0]);
+			}}
+		/>
+
+		<button
+			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+			on:click={() => {
+				toolsImportInputElement.click();
+			}}
+		>
+			<div class=" self-center mr-2 font-medium">{$i18n.t('Import Tools')}</div>
+
+			<div class=" self-center">
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 16 16"
+					fill="currentColor"
+					class="w-4 h-4"
+				>
+					<path
+						fill-rule="evenodd"
+						d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 9.5a.75.75 0 0 1-.75-.75V8.06l-.72.72a.75.75 0 0 1-1.06-1.06l2-2a.75.75 0 0 1 1.06 0l2 2a.75.75 0 1 1-1.06 1.06l-.72-.72v2.69a.75.75 0 0 1-.75.75Z"
+						clip-rule="evenodd"
+					/>
+				</svg>
+			</div>
+		</button>
+
+		<button
+			class="flex text-xs items-center space-x-1 px-3 py-1.5 rounded-xl bg-gray-50 hover:bg-gray-100 dark:bg-gray-800 dark:hover:bg-gray-700 dark:text-gray-200 transition"
+			on:click={async () => {
+				let blob = new Blob([JSON.stringify($tools)], {
+					type: 'application/json'
+				});
+				saveAs(blob, `tools-export-${Date.now()}.json`);
+			}}
+		>
+			<div class=" self-center mr-2 font-medium">{$i18n.t('Export Tools')}</div>
+
+			<div class=" self-center">
+				<svg
+					xmlns="http://www.w3.org/2000/svg"
+					viewBox="0 0 16 16"
+					fill="currentColor"
+					class="w-4 h-4"
+				>
+					<path
+						fill-rule="evenodd"
+						d="M4 2a1.5 1.5 0 0 0-1.5 1.5v9A1.5 1.5 0 0 0 4 14h8a1.5 1.5 0 0 0 1.5-1.5V6.621a1.5 1.5 0 0 0-.44-1.06L9.94 2.439A1.5 1.5 0 0 0 8.878 2H4Zm4 3.5a.75.75 0 0 1 .75.75v2.69l.72-.72a.75.75 0 1 1 1.06 1.06l-2 2a.75.75 0 0 1-1.06 0l-2-2a.75.75 0 0 1 1.06-1.06l.72.72V6.25A.75.75 0 0 1 8 5.5Z"
+						clip-rule="evenodd"
+					/>
+				</svg>
+			</div>
+		</button>
+	</div>
+</div>

+ 2 - 2
src/lib/components/workspace/Tools/ToolkitEditor.svelte

@@ -14,7 +14,7 @@
 		description: ''
 	};
 
-	let code = '';
+	let content = '';
 
 	$: if (name) {
 		id = name.replace(/\s+/g, '_').toLowerCase();
@@ -97,7 +97,7 @@
 				</div>
 
 				<div class="mb-2 flex-1 overflow-auto h-0 rounded-lg">
-					<CodeEditor bind:value={code} bind:this={codeEditor} {saveHandler} />
+					<CodeEditor bind:value={content} bind:this={codeEditor} {saveHandler} />
 				</div>
 
 				<div class="pb-3 flex justify-end">

+ 10 - 16
src/lib/stores/index.ts

@@ -23,24 +23,11 @@ export const chatId = writable('');
 
 export const chats = writable([]);
 export const tags = writable([]);
-export const models: Writable<Model[]> = writable([]);
 
-export const modelfiles = writable([]);
+export const models: Writable<Model[]> = writable([]);
 export const prompts: Writable<Prompt[]> = writable([]);
-export const documents = writable([
-	{
-		collection_name: 'collection_name',
-		filename: 'filename',
-		name: 'name',
-		title: 'title'
-	},
-	{
-		collection_name: 'collection_name1',
-		filename: 'filename1',
-		name: 'name1',
-		title: 'title1'
-	}
-]);
+export const documents: Writable<Document[]> = writable([]);
+export const tools = writable([]);
 
 export const banners: Writable<Banner[]> = writable([]);
 
@@ -135,6 +122,13 @@ type Prompt = {
 	timestamp: number;
 };
 
+type Document = {
+	collection_name: string;
+	filename: string;
+	name: string;
+	title: string;
+};
+
 type Config = {
 	status: boolean;
 	name: string;