浏览代码

feat: files endpoint

Timothy J. Baek 10 月之前
父节点
当前提交
146e550239
共有 4 个文件被更改,包括 242 次插入0 次删除
  1. 2 0
      backend/apps/webui/main.py
  2. 103 0
      backend/apps/webui/models/files.py
  3. 134 0
      backend/apps/webui/routers/files.py
  4. 3 0
      backend/main.py

+ 2 - 0
backend/apps/webui/main.py

@@ -12,6 +12,7 @@ from apps.webui.routers import (
     configs,
     memories,
     utils,
+    files,
 )
 from config import (
     WEBUI_BUILD_HASH,
@@ -81,6 +82,7 @@ app.include_router(memories.router, prefix="/memories", tags=["memories"])
 
 app.include_router(configs.router, prefix="/configs", tags=["configs"])
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
+app.include_router(files.router, prefix="/files", tags=["files"])
 
 
 @app.get("/")

+ 103 - 0
backend/apps/webui/models/files.py

@@ -0,0 +1,103 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+import time
+import logging
+from apps.webui.internal.db import DB, JSONField
+
+import json
+
+from config import SRC_LOG_LEVELS
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+####################
+# Files DB Schema
+####################
+
+
+class File(Model):
+    id = CharField(unique=True)
+    user_id = CharField()
+    filename = TextField()
+    meta = JSONField()
+    created_at = BigIntegerField()
+
+    class Meta:
+        database = DB
+
+
+class FileModel(BaseModel):
+    id: str
+    user_id: str
+    filename: str
+    meta: dict
+    created_at: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class FileResponse(BaseModel):
+    id: str
+    user_id: str
+    filename: str
+    meta: dict
+    created_at: int  # timestamp in epoch
+
+
+class FileForm(BaseModel):
+    id: str
+    filename: str
+    meta: dict = {}
+
+
+class FilesTable:
+    def __init__(self, db):
+        self.db = db
+        self.db.create_tables([File])
+
+    def insert_new_file(self, user_id: str, form_data: FileForm) -> Optional[FileModel]:
+        file = FileModel(
+            **{
+                **form_data.model_dump(),
+                "user_id": user_id,
+                "created_at": int(time.time()),
+            }
+        )
+
+        try:
+            result = File.create(**file.model_dump())
+            if result:
+                return file
+            else:
+                return None
+        except Exception as e:
+            print(f"Error creating tool: {e}")
+            return None
+
+    def get_file_by_id(self, id: str) -> Optional[FileModel]:
+        try:
+            file = File.get(File.id == id)
+            return FileModel(**model_to_dict(file))
+        except:
+            return None
+
+    def get_files(self) -> List[FileModel]:
+        return [FileModel(**model_to_dict(file)) for file in File.select()]
+
+    def delete_file_by_id(self, id: str) -> bool:
+        try:
+            query = File.delete().where((File.id == id))
+            query.execute()  # Remove the rows, return number of rows removed.
+
+            return True
+        except:
+            return False
+
+
+Files = FilesTable(DB)

+ 134 - 0
backend/apps/webui/routers/files.py

@@ -0,0 +1,134 @@
+from fastapi import (
+    Depends,
+    FastAPI,
+    HTTPException,
+    status,
+    Request,
+    UploadFile,
+    File,
+    Form,
+)
+
+
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.webui.models.files import Files, FileForm, FileModel, FileResponse
+from utils.utils import get_verified_user, get_admin_user
+from constants import ERROR_MESSAGES
+
+from importlib import util
+import os
+import uuid
+
+from config import SRC_LOG_LEVELS, UPLOAD_DIR
+
+
+import logging
+
+log = logging.getLogger(__name__)
+log.setLevel(SRC_LOG_LEVELS["MODELS"])
+
+
+router = APIRouter()
+
+############################
+# Upload File
+############################
+
+
+@router.post("/")
+def upload_file(
+    file: UploadFile = File(...),
+    user=Depends(get_verified_user),
+):
+    log.info(f"file.content_type: {file.content_type}")
+    try:
+        unsanitized_filename = file.filename
+        filename = os.path.basename(unsanitized_filename)
+
+        # replace filename with uuid
+        id = str(uuid.uuid4())
+        file_path = f"{UPLOAD_DIR}/{filename}"
+
+        contents = file.file.read()
+        with open(file_path, "wb") as f:
+            f.write(contents)
+            f.close()
+
+        file = Files.insert_new_file(
+            user.id, FileForm(**{"id": id, "filename": filename})
+        )
+
+        if file:
+            return file
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error uploading file"),
+            )
+
+    except Exception as e:
+        log.exception(e)
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.DEFAULT(e),
+        )
+
+
+############################
+# List Files
+############################
+
+
+@router.get("/", response_model=List[FileModel])
+async def list_files(user=Depends(get_verified_user)):
+    files = Files.get_files()
+    return files
+
+
+############################
+# Get File By Id
+############################
+
+
+@router.get("/{id}", response_model=Optional[FileModel])
+async def get_file_by_id(id: str, user=Depends(get_verified_user)):
+    file = Files.get_file_by_id(id)
+
+    if file:
+        return file
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+
+############################
+# Delete File By Id
+############################
+
+
+@router.delete("/{id}")
+async def delete_file_by_id(id: str, user=Depends(get_verified_user)):
+    file = Files.get_file_by_id(id)
+
+    if file:
+        result = Files.delete_file_by_id(id)
+        if result:
+            return {"message": "File deleted successfully"}
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=ERROR_MESSAGES.DEFAULT("Error deleting file"),
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )

+ 3 - 0
backend/main.py

@@ -11,6 +11,7 @@ import requests
 import mimetypes
 import shutil
 import os
+import uuid
 import inspect
 import asyncio
 
@@ -76,6 +77,7 @@ from config import (
     VERSION,
     CHANGELOG,
     FRONTEND_BUILD_DIR,
+    UPLOAD_DIR,
     CACHE_DIR,
     STATIC_DIR,
     ENABLE_OPENAI_API,
@@ -1378,6 +1380,7 @@ async def update_pipeline_valves(
         )
 
 
+
 @app.get("/api/config")
 async def get_app_config():
     # Checking and Handling the Absence of 'ui' in CONFIG_DATA