Browse Source

feat: functions router

Timothy J. Baek 10 months ago
parent
commit
f68aba687e

+ 1 - 1
backend/apps/webui/main.py

@@ -60,7 +60,7 @@ app.state.config.ENABLE_COMMUNITY_SHARING = ENABLE_COMMUNITY_SHARING
 
 app.state.MODELS = {}
 app.state.TOOLS = {}
-
+app.state.FUNCTIONS = {}
 
 app.add_middleware(
     CORSMiddleware,

+ 2 - 2
backend/apps/webui/models/functions.py

@@ -69,7 +69,7 @@ class FunctionForm(BaseModel):
     meta: FunctionMeta
 
 
-class ToolsTable:
+class FunctionsTable:
     def __init__(self, db):
         self.db = db
         self.db.create_tables([Function])
@@ -137,4 +137,4 @@ class ToolsTable:
             return False
 
 
-Tools = ToolsTable(DB)
+Functions = FunctionsTable(DB)

+ 180 - 0
backend/apps/webui/routers/functions.py

@@ -0,0 +1,180 @@
+from fastapi import Depends, FastAPI, HTTPException, status, Request
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.functions import (
+    Functions,
+    FunctionForm,
+    FunctionModel,
+    FunctionResponse,
+)
+from apps.webui.utils import load_function_module_by_id
+from utils.utils import get_verified_user, get_admin_user
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+from pathlib import Path
+
+from config import DATA_DIR, CACHE_DIR, FUNCTIONS_DIR
+
+
+router = APIRouter()
+
+############################
+# GetFunctions
+############################
+
+
+@router.get("/", response_model=List[FunctionResponse])
+async def get_functions(user=Depends(get_verified_user)):
+    return Functions.get_functions()
+
+
+############################
+# ExportFunctions
+############################
+
+
+@router.get("/export", response_model=List[FunctionModel])
+async def get_functions(user=Depends(get_admin_user)):
+    return Functions.get_functions()
+
+
+############################
+# CreateNewFunction
+############################
+
+
+@router.post("/create", response_model=Optional[FunctionResponse])
+async def create_new_function(
+    request: Request, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+    if not form_data.id.isidentifier():
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail="Only alphanumeric characters and underscores are allowed in the id",
+        )
+
+    form_data.id = form_data.id.lower()
+
+    function = Functions.get_function_by_id(form_data.id)
+    if function == None:
+        function_path = os.path.join(FUNCTIONS_DIR, f"{form_data.id}.py")
+        try:
+            with open(function_path, "w") as function_file:
+                function_file.write(form_data.content)
+
+            function_module = load_function_module_by_id(form_data.id)
+
+            FUNCTIONS = request.app.state.FUNCTIONS
+            FUNCTIONS[form_data.id] = function_module
+
+            function = Functions.insert_new_function(user.id, form_data)
+
+            function_cache_dir = Path(CACHE_DIR) / "functions" / form_data.id
+            function_cache_dir.mkdir(parents=True, exist_ok=True)
+
+            if function:
+                return function
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_400_BAD_REQUEST,
+                    detail=ERROR_MESSAGES.DEFAULT("Error creating function"),
+                )
+        except Exception as e:
+            print(e)
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT(e),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ID_TAKEN,
+        )
+
+
+############################
+# GetFunctionById
+############################
+
+
+@router.get("/id/{id}", response_model=Optional[FunctionModel])
+async def get_function_by_id(id: str, user=Depends(get_admin_user)):
+    function = Functions.get_function_by_id(id)
+
+    if function:
+        return function
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# UpdateFunctionById
+############################
+
+
+@router.post("/id/{id}/update", response_model=Optional[FunctionModel])
+async def update_toolkit_by_id(
+    request: Request, id: str, form_data: FunctionForm, user=Depends(get_admin_user)
+):
+    function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+
+    try:
+        with open(function_path, "w") as function_file:
+            function_file.write(form_data.content)
+
+        function_module = load_function_module_by_id(id)
+
+        FUNCTIONS = request.app.state.FUNCTIONS
+        FUNCTIONS[id] = function_module
+
+        updated = {**form_data.model_dump(exclude={"id"})}
+        print(updated)
+
+        function = Functions.update_function_by_id(id, updated)
+
+        if function:
+            return function
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error updating function"),
+            )
+
+    except Exception as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# DeleteFunctionById
+############################
+
+
+@router.delete("/id/{id}/delete", response_model=bool)
+async def delete_function_by_id(
+    request: Request, id: str, user=Depends(get_admin_user)
+):
+    result = Functions.delete_function_by_id(id)
+
+    if result:
+        FUNCTIONS = request.app.state.FUNCTIONS
+        if id in FUNCTIONS:
+            del FUNCTIONS[id]
+
+        # delete the function file
+        function_path = os.path.join(FUNCTIONS_DIR, f"{id}.py")
+        os.remove(function_path)
+
+    return result

+ 23 - 1
backend/apps/webui/utils.py

@@ -1,7 +1,7 @@
 from importlib import util
 import os
 
-from config import TOOLS_DIR
+from config import TOOLS_DIR, FUNCTIONS_DIR
 
 
 def load_toolkit_module_by_id(toolkit_id):
@@ -21,3 +21,25 @@ def load_toolkit_module_by_id(toolkit_id):
         # Move the file to the error folder
         os.rename(toolkit_path, f"{toolkit_path}.error")
         raise e
+
+
+def load_function_module_by_id(function_id):
+    function_path = os.path.join(FUNCTIONS_DIR, f"{function_id}.py")
+
+    spec = util.spec_from_file_location(function_id, function_path)
+    module = util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+        print(f"Loaded module: {module.__name__}")
+        if hasattr(module, "Pipe"):
+            return module.Pipe()
+        elif hasattr(module, "Filter"):
+            return module.Filter()
+        else:
+            raise Exception("No Function class found")
+    except Exception as e:
+        print(f"Error loading module: {function_id}")
+        # Move the file to the error folder
+        os.rename(function_path, f"{function_path}.error")
+        raise e

+ 8 - 0
backend/config.py

@@ -377,6 +377,14 @@ TOOLS_DIR = os.getenv("TOOLS_DIR", f"{DATA_DIR}/tools")
 Path(TOOLS_DIR).mkdir(parents=True, exist_ok=True)
 
 
+####################################
+# Functions DIR
+####################################
+
+FUNCTIONS_DIR = os.getenv("FUNCTIONS_DIR", f"{DATA_DIR}/functions")
+Path(FUNCTIONS_DIR).mkdir(parents=True, exist_ok=True)
+
+
 ####################################
 # LITELLM_CONFIG
 ####################################