Browse Source

feat: modelfiles backend

Timothy J. Baek 1 year ago
parent
commit
a2b1e3756b

+ 122 - 0
backend/apps/web/models/modelfiles.py

@@ -0,0 +1,122 @@
+from pydantic import BaseModel
+from peewee import *
+from playhouse.shortcuts import model_to_dict
+from typing import List, Union, Optional
+import time
+
+from utils.utils import decode_token
+from utils.misc import get_gravatar_url
+
+from apps.web.internal.db import DB
+
+import json
+
+####################
+# User DB Schema
+####################
+
+
+class Modelfile(Model):
+    tag_name = CharField(unique=True)
+    user_id = CharField()
+    modelfile = TextField()
+    timestamp = DateField()
+
+    class Meta:
+        database = DB
+
+
+class ModelfileModel(BaseModel):
+    tag_name: str
+    user_id: str
+    modelfile: str
+    timestamp: int  # timestamp in epoch
+
+
+####################
+# Forms
+####################
+
+
+class ModelfileForm(BaseModel):
+    modelfile: dict
+
+
+class ModelfileResponse(BaseModel):
+    tag_name: str
+    user_id: str
+    modelfile: dict
+    timestamp: int  # timestamp in epoch
+
+
+class ModelfilesTable:
+    def __init__(self, db):
+        self.db = db
+        self.db.create_tables([Modelfile])
+
+    def insert_new_modelfile(
+        self, user_id: str, form_data: ModelfileForm
+    ) -> Optional[ModelfileModel]:
+        if "title" in form_data.modelfile:
+            modelfile = ModelfileModel(
+                **{
+                    "user_id": user_id,
+                    "tag_name": form_data.modelfile["title"],
+                    "modelfile": json.dumps(form_data.modelfile),
+                    "timestamp": int(time.time()),
+                }
+            )
+            result = Modelfile.create(**modelfile.model_dump())
+            if result:
+                return modelfile
+            else:
+                return None
+        else:
+            return None
+
+    def get_modelfile_by_tag_name(self, tag_name: str) -> Optional[ModelfileModel]:
+        try:
+            modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
+            return ModelfileModel(**model_to_dict(modelfile))
+        except:
+            return None
+
+    def get_modelfiles(self, skip: int = 0, limit: int = 50) -> List[ModelfileResponse]:
+        return [
+            ModelfileResponse(
+                **{
+                    **model_to_dict(modelfile),
+                    "modelfile": json.loads(modelfile.modelfile),
+                }
+            )
+            for modelfile in Modelfile.select()
+            # .limit(limit).offset(skip)
+        ]
+
+    def update_modelfile_by_tag_name(
+        self, tag_name: str, modelfile: dict
+    ) -> Optional[ModelfileModel]:
+        try:
+            query = Modelfile.update(
+                modelfile=json.dumps(modelfile),
+                timestamp=int(time.time()),
+            ).where(Modelfile.tag_name == tag_name)
+
+            query.execute()
+
+            modelfile = Modelfile.get(Modelfile.tag_name == tag_name)
+            return ModelfileModel(**model_to_dict(modelfile))
+        except:
+            return None
+
+    def delete_modelfile_by_tag_name(self, tag_name: str) -> bool:
+        try:
+            query = Modelfile.delete().where((Modelfile.tag_name == tag_name))
+            query.execute()  # Remove the rows, return number of rows removed.
+
+            return True
+        except:
+            return False
+
+
+Modelfiles = ModelfilesTable(DB)

+ 178 - 0
backend/apps/web/routers/modelfiles.py

@@ -0,0 +1,178 @@
+from fastapi import Response
+from fastapi import Depends, FastAPI, HTTPException, status
+from datetime import datetime, timedelta
+from typing import List, Union, Optional
+
+from fastapi import APIRouter
+from pydantic import BaseModel
+import json
+
+from apps.web.models.users import Users
+from apps.web.models.modelfiles import (
+    Modelfiles,
+    ModelfileForm,
+    ModelfileResponse,
+)
+
+from utils.utils import (
+    bearer_scheme,
+)
+from constants import ERROR_MESSAGES
+
+router = APIRouter()
+
+############################
+# GetModelfiles
+############################
+
+
+@router.get("/", response_model=List[ModelfileResponse])
+async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
+    token = cred.credentials
+    user = Users.get_user_by_token(token)
+
+    if user:
+        return Modelfiles.get_modelfiles(skip, limit)
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.INVALID_TOKEN,
+        )
+
+
+############################
+# CreateNewModelfile
+############################
+
+
+@router.post("/create", response_model=Optional[ModelfileResponse])
+async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
+    token = cred.credentials
+    user = Users.get_user_by_token(token)
+
+    if user:
+        # Admin Only
+        if user.role == "admin":
+            modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
+            return ModelfileResponse(
+                **{
+                    **modelfile.model_dump(),
+                    "modelfile": json.loads(modelfile.modelfile),
+                }
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.INVALID_TOKEN,
+        )
+
+
+############################
+# GetModelfileByTagName
+############################
+
+
+@router.get("/{tag_name}", response_model=Optional[ModelfileResponse])
+async def get_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
+    token = cred.credentials
+    user = Users.get_user_by_token(token)
+
+    if user:
+        modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name)
+
+        if modelfile:
+            return ModelfileResponse(
+                **{
+                    **modelfile.model_dump(),
+                    "modelfile": json.loads(modelfile.modelfile),
+                }
+            )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.NOT_FOUND,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.INVALID_TOKEN,
+        )
+
+
+############################
+# UpdateModelfileByTagName
+############################
+
+
+@router.post("/{tag_name}", response_model=Optional[ModelfileResponse])
+async def update_modelfile_by_tag_name(
+    tag_name: str, form_data: ModelfileForm, cred=Depends(bearer_scheme)
+):
+    token = cred.credentials
+    user = Users.get_user_by_token(token)
+
+    if user:
+        if user.role == "admin":
+            modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name)
+            if modelfile:
+                updated_modelfile = {
+                    **json.loads(modelfile.modelfile),
+                    **form_data.modelfile,
+                }
+
+                modelfile = Modelfiles.update_modelfile_by_tag_name(
+                    tag_name, updated_modelfile
+                )
+
+                return ModelfileResponse(
+                    **{
+                        **modelfile.model_dump(),
+                        "modelfile": json.loads(modelfile.modelfile),
+                    }
+                )
+            else:
+                raise HTTPException(
+                    status_code=status.HTTP_401_UNAUTHORIZED,
+                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+                )
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.INVALID_TOKEN,
+        )
+
+
+############################
+# DeleteModelfileByTagName
+############################
+
+
+@router.delete("/{tag_name}", response_model=bool)
+async def delete_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
+    token = cred.credentials
+    user = Users.get_user_by_token(token)
+
+    if user:
+        if user.role == "admin":
+            result = Modelfiles.delete_modelfile_by_tag_name(tag_name)
+            return result
+        else:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+            )
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.INVALID_TOKEN,
+        )

+ 30 - 28
src/lib/components/layout/Sidebar.svelte

@@ -98,35 +98,37 @@
 			</button>
 		</div>
 
-		<div class="px-2.5 flex justify-center my-1">
-			<button
-				class="flex-grow flex space-x-3 rounded-md px-3 py-2 hover:bg-gray-900 transition"
-				on:click={async () => {
-					goto('/modelfiles');
-				}}
-			>
-				<div class="self-center">
-					<svg
-						xmlns="http://www.w3.org/2000/svg"
-						fill="none"
-						viewBox="0 0 24 24"
-						stroke-width="1.5"
-						stroke="currentColor"
-						class="w-4 h-4"
-					>
-						<path
-							stroke-linecap="round"
-							stroke-linejoin="round"
-							d="M13.5 16.875h3.375m0 0h3.375m-3.375 0V13.5m0 3.375v3.375M6 10.5h2.25a2.25 2.25 0 002.25-2.25V6a2.25 2.25 0 00-2.25-2.25H6A2.25 2.25 0 003.75 6v2.25A2.25 2.25 0 006 10.5zm0 9.75h2.25A2.25 2.25 0 0010.5 18v-2.25a2.25 2.25 0 00-2.25-2.25H6a2.25 2.25 0 00-2.25 2.25V18A2.25 2.25 0 006 20.25zm9.75-9.75H18a2.25 2.25 0 002.25-2.25V6A2.25 2.25 0 0018 3.75h-2.25A2.25 2.25 0 0013.5 6v2.25a2.25 2.25 0 002.25 2.25z"
-						/>
-					</svg>
-				</div>
+		{#if $user?.role === 'admin'}
+			<div class="px-2.5 flex justify-center my-1">
+				<button
+					class="flex-grow flex space-x-3 rounded-md px-3 py-2 hover:bg-gray-900 transition"
+					on:click={async () => {
+						goto('/modelfiles');
+					}}
+				>
+					<div class="self-center">
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							fill="none"
+							viewBox="0 0 24 24"
+							stroke-width="1.5"
+							stroke="currentColor"
+							class="w-4 h-4"
+						>
+							<path
+								stroke-linecap="round"
+								stroke-linejoin="round"
+								d="M13.5 16.875h3.375m0 0h3.375m-3.375 0V13.5m0 3.375v3.375M6 10.5h2.25a2.25 2.25 0 002.25-2.25V6a2.25 2.25 0 00-2.25-2.25H6A2.25 2.25 0 003.75 6v2.25A2.25 2.25 0 006 10.5zm0 9.75h2.25A2.25 2.25 0 0010.5 18v-2.25a2.25 2.25 0 00-2.25-2.25H6a2.25 2.25 0 00-2.25 2.25V18A2.25 2.25 0 006 20.25zm9.75-9.75H18a2.25 2.25 0 002.25-2.25V6A2.25 2.25 0 0018 3.75h-2.25A2.25 2.25 0 0013.5 6v2.25a2.25 2.25 0 002.25 2.25z"
+							/>
+						</svg>
+					</div>
 
-				<div class="flex self-center">
-					<div class=" self-center font-medium text-sm">Modelfiles</div>
-				</div>
-			</button>
-		</div>
+					<div class="flex self-center">
+						<div class=" self-center font-medium text-sm">Modelfiles</div>
+					</div>
+				</button>
+			</div>
+		{/if}
 
 		<div class="px-2.5 mt-1 mb-2 flex justify-center space-x-2">
 			<div class="flex w-full">