ソースを参照

Merge branch 'main' into dev

Timothy Jaeryang Baek 1 年間 前
コミット
127886db14

+ 2 - 1
backend/.gitignore

@@ -4,4 +4,5 @@ _old
 uploads
 .ipynb_checkpoints
 *.db
-_test
+_test
+Pipfile

+ 7 - 3
backend/apps/ollama/main.py

@@ -8,7 +8,7 @@ import json
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
-from utils.utils import extract_token_from_auth_header
+from utils.utils import decode_token
 from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
 
 app = Flask(__name__)
@@ -34,8 +34,12 @@ def proxy(path):
     # Basic RBAC support
     if WEBUI_AUTH:
         if "Authorization" in headers:
-            token = extract_token_from_auth_header(headers["Authorization"])
-            user = Users.get_user_by_token(token)
+            _, credentials = headers["Authorization"].split()
+            token_data = decode_token(credentials)
+            if token_data is None or "email" not in token_data:
+                return jsonify({"detail": ERROR_MESSAGES.UNAUTHORIZED}), 401
+
+            user = Users.get_user_by_email(token_data["email"])
             if user:
                 # Only user and admin roles can access
                 if user.role in ["user", "admin"]:

+ 3 - 5
backend/apps/web/main.py

@@ -1,6 +1,6 @@
-from fastapi import FastAPI, Request, Depends, HTTPException
+from fastapi import FastAPI, Depends
+from fastapi.routing import APIRoute
 from fastapi.middleware.cors import CORSMiddleware
-
 from apps.web.routers import auths, users, chats, modelfiles, utils
 from config import WEBUI_VERSION, WEBUI_AUTH
 
@@ -16,13 +16,11 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-
 app.include_router(auths.router, prefix="/auths", tags=["auths"])
+
 app.include_router(users.router, prefix="/users", tags=["users"])
 app.include_router(chats.router, prefix="/chats", tags=["chats"])
 app.include_router(modelfiles.router, prefix="/modelfiles", tags=["modelfiles"])
-
-
 app.include_router(utils.router, prefix="/utils", tags=["utils"])
 
 

+ 0 - 10
backend/apps/web/models/users.py

@@ -3,8 +3,6 @@ from peewee import *
 from playhouse.shortcuts import model_to_dict
 from typing import List, Union, Optional
 import time
-
-from utils.utils import decode_token
 from utils.misc import get_gravatar_url
 
 from apps.web.internal.db import DB
@@ -85,14 +83,6 @@ class UsersTable:
         except:
             return None
 
-    def get_user_by_token(self, token: str) -> Optional[UserModel]:
-        data = decode_token(token)
-
-        if data != None and "email" in data:
-            return self.get_user_by_email(data["email"])
-        else:
-            return None
-
     def get_users(self, skip: int = 0, limit: int = 50) -> List[UserModel]:
         return [
             UserModel(**model_to_dict(user))

+ 12 - 25
backend/apps/web/routers/auths.py

@@ -19,11 +19,7 @@ from apps.web.models.auths import (
 from apps.web.models.users import Users
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_password_hash, get_current_user, create_token
 from utils.misc import get_gravatar_url
 from constants import ERROR_MESSAGES
 
@@ -36,22 +32,14 @@ router = APIRouter()
 
 
 @router.get("/", response_model=UserResponse)
-async def get_session_user(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-    if user:
-        return {
-            "id": user.id,
-            "email": user.email,
-            "name": user.name,
-            "role": user.role,
-            "profile_image_url": user.profile_image_url,
-        }
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_session_user(user=Depends(get_current_user)):
+    return {
+        "id": user.id,
+        "email": user.email,
+        "name": user.name,
+        "role": user.role,
+        "profile_image_url": user.profile_image_url,
+    }
 
 
 ############################
@@ -60,10 +48,9 @@ async def get_session_user(cred=Depends(bearer_scheme)):
 
 
 @router.post("/update/password", response_model=bool)
-async def update_password(form_data: UpdatePasswordForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    session_user = Users.get_user_by_token(token)
-
+async def update_password(
+    form_data: UpdatePasswordForm, session_user=Depends(get_current_user)
+):
     if session_user:
         user = Auths.authenticate_user(session_user.email, form_data.password)
 

+ 36 - 97
backend/apps/web/routers/chats.py

@@ -1,8 +1,7 @@
-from fastapi import Response
-from fastapi import Depends, FastAPI, HTTPException, status
+from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
-
+from utils.utils import get_current_user
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
@@ -30,17 +29,10 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ChatTitleIdResponse])
-async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_user_chats(
+    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+):
+    return Chats.get_chat_lists_by_user_id(user.id, skip, limit)
 
 
 ############################
@@ -49,20 +41,11 @@ async def get_user_chats(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 @router.get("/all", response_model=List[ChatResponse])
-async def get_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return [
-            ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-            for chat in Chats.get_all_chats_by_user_id(user.id)
-        ]
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_all_user_chats(user=Depends(get_current_user)):
+    return [
+        ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
+        for chat in Chats.get_all_chats_by_user_id(user.id)
+    ]
 
 
 ############################
@@ -71,18 +54,9 @@ async def get_all_user_chats(cred=Depends(bearer_scheme)):
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.insert_new_chat(user.id, form_data)
-        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+    chat = Chats.insert_new_chat(user.id, form_data)
+    return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
 
 
 ############################
@@ -91,24 +65,14 @@ async def create_new_chat(form_data: ChatForm, cred=Depends(bearer_scheme)):
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-
-        if chat:
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+
+    if chat:
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
         )
 
 
@@ -118,26 +82,19 @@ async def get_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
-async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        chat = Chats.get_chat_by_id_and_user_id(id, user.id)
-        if chat:
-            updated_chat = {**json.loads(chat.chat), **form_data.chat}
-
-            chat = Chats.update_chat_by_id(id, updated_chat)
-            return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def update_chat_by_id(
+    id: str, form_data: ChatForm, user=Depends(get_current_user)
+):
+    chat = Chats.get_chat_by_id_and_user_id(id, user.id)
+    if chat:
+        updated_chat = {**json.loads(chat.chat), **form_data.chat}
+
+        chat = Chats.update_chat_by_id(id, updated_chat)
+        return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
 
 
@@ -147,18 +104,9 @@ async def update_chat_by_id(id: str, form_data: ChatForm, cred=Depends(bearer_sc
 
 
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chat_by_id_and_user_id(id, user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_chat_by_id(id: str, user=Depends(get_current_user)):
+    result = Chats.delete_chat_by_id_and_user_id(id, user.id)
+    return result
 
 
 ############################
@@ -167,15 +115,6 @@ async def delete_chat_by_id(id: str, cred=Depends(bearer_scheme)):
 
 
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        result = Chats.delete_chats_by_user_id(user.id)
-        return result
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def delete_all_user_chats(user=Depends(get_current_user)):
+    result = Chats.delete_chats_by_user_id(user.id)
+    return result

+ 63 - 115
backend/apps/web/routers/modelfiles.py

@@ -1,4 +1,3 @@
-from fastapi import Response
 from fastapi import Depends, FastAPI, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
@@ -6,8 +5,6 @@ from typing import List, Union, Optional
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
-
-from apps.web.models.users import Users
 from apps.web.models.modelfiles import (
     Modelfiles,
     ModelfileForm,
@@ -16,9 +13,7 @@ from apps.web.models.modelfiles import (
     ModelfileResponse,
 )
 
-from utils.utils import (
-    bearer_scheme,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -29,17 +24,8 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[ModelfileResponse])
-async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        return Modelfiles.get_modelfiles(skip, limit)
-    else:
-        raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
-        )
+async def get_modelfiles(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    return Modelfiles.get_modelfiles(skip, limit)
 
 
 ############################
@@ -48,36 +34,28 @@ async def get_modelfiles(skip: int = 0, limit: int = 50, cred=Depends(bearer_sch
 
 
 @router.post("/create", response_model=Optional[ModelfileResponse])
-async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        # Admin Only
-        if user.role == "admin":
-            modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
-
-            if modelfile:
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.DEFAULT(),
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def create_new_modelfile(
+    form_data: ModelfileForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    modelfile = Modelfiles.insert_new_modelfile(user.id, form_data)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.DEFAULT(),
         )
 
 
@@ -87,31 +65,20 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
 
 
 @router.post("/", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
-):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-
-        if modelfile:
-            return ModelfileResponse(
-                **{
-                    **modelfile.model_dump(),
-                    "modelfile": json.loads(modelfile.modelfile),
-                }
-            )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.NOT_FOUND,
-            )
+async def get_modelfile_by_tag_name(form_data: ModelfileTagNameForm, user=Depends(get_current_user)):
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+
+    if modelfile:
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.NOT_FOUND,
         )
 
 
@@ -122,44 +89,34 @@ async def get_modelfile_by_tag_name(
 
 @router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(
-    form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileUpdateForm, user=Depends(get_current_user)
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
-            if modelfile:
-                updated_modelfile = {
-                    **json.loads(modelfile.modelfile),
-                    **form_data.modelfile,
-                }
-
-                modelfile = Modelfiles.update_modelfile_by_tag_name(
-                    form_data.tag_name, updated_modelfile
-                )
-
-                return ModelfileResponse(
-                    **{
-                        **modelfile.model_dump(),
-                        "modelfile": json.loads(modelfile.modelfile),
-                    }
-                )
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_401_UNAUTHORIZED,
-                    detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+    modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
+    if modelfile:
+        updated_modelfile = {
+            **json.loads(modelfile.modelfile),
+            **form_data.modelfile,
+        }
+
+        modelfile = Modelfiles.update_modelfile_by_tag_name(
+            form_data.tag_name, updated_modelfile
+        )
+
+        return ModelfileResponse(
+            **{
+                **modelfile.model_dump(),
+                "modelfile": json.loads(modelfile.modelfile),
+            }
+        )
     else:
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
 
 
@@ -170,22 +127,13 @@ async def update_modelfile_by_tag_name(
 
 @router.delete("/delete", response_model=bool)
 async def delete_modelfile_by_tag_name(
-    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileTagNameForm, user=Depends(get_current_user)
 ):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
-            return result
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_401_UNAUTHORIZED,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+    if user.role != "admin":
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
+
+    result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
+    return result

+ 31 - 61
backend/apps/web/routers/users.py

@@ -12,11 +12,7 @@ from apps.web.models.users import UserModel, UserRoleUpdateForm, Users
 from apps.web.models.auths import Auths
 
 
-from utils.utils import (
-    get_password_hash,
-    bearer_scheme,
-    create_token,
-)
+from utils.utils import get_current_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -27,23 +23,13 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[UserModel])
-async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            return Users.get_users(skip, limit)
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
-    else:
+async def get_users(skip: int = 0, limit: int = 50, user=Depends(get_current_user)):
+    if user.role != "admin":
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )
+    return Users.get_users(skip, limit)
 
 
 ############################
@@ -52,28 +38,21 @@ async def get_users(skip: int = 0, limit: int = 50, cred=Depends(bearer_scheme))
 
 
 @router.post("/update/role", response_model=Optional[UserModel])
-async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            if user.id != form_data.id:
-                return Users.update_user_role_by_id(form_data.id, form_data.role)
-            else:
-                raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
-                )
-        else:
-            raise HTTPException(
-                status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
-            )
+async def update_user_role(
+    form_data: UserRoleUpdateForm, user=Depends(get_current_user)
+):
+    if user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    if user.id != form_data.id:
+        return Users.update_user_role_by_id(form_data.id, form_data.role)
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACTION_PROHIBITED,
         )
 
 
@@ -83,34 +62,25 @@ async def update_user_role(form_data: UserRoleUpdateForm, cred=Depends(bearer_sc
 
 
 @router.delete("/{user_id}", response_model=bool)
-async def delete_user_by_id(user_id: str, cred=Depends(bearer_scheme)):
-    token = cred.credentials
-    user = Users.get_user_by_token(token)
-
-    if user:
-        if user.role == "admin":
-            if user.id != user_id:
-                result = Auths.delete_auth_by_id(user_id)
-
-                if result:
-                    return True
-                else:
-                    raise HTTPException(
-                        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-                        detail=ERROR_MESSAGES.DELETE_USER_ERROR,
-                    )
+async def delete_user_by_id(user_id: str, user=Depends(get_current_user)):
+    if user.role == "admin":
+        if user.id != user_id:
+            result = Auths.delete_auth_by_id(user_id)
+
+            if result:
+                return True
             else:
                 raise HTTPException(
-                    status_code=status.HTTP_403_FORBIDDEN,
-                    detail=ERROR_MESSAGES.ACTION_PROHIBITED,
+                    status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+                    detail=ERROR_MESSAGES.DELETE_USER_ERROR,
                 )
         else:
             raise HTTPException(
                 status_code=status.HTTP_403_FORBIDDEN,
-                detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+                detail=ERROR_MESSAGES.ACTION_PROHIBITED,
             )
     else:
         raise HTTPException(
-            status_code=status.HTTP_401_UNAUTHORIZED,
-            detail=ERROR_MESSAGES.INVALID_TOKEN,
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
         )

+ 2 - 0
backend/requirements.txt

@@ -18,3 +18,5 @@ bcrypt
 
 PyJWT
 pyjwt[crypto]
+
+black

+ 18 - 14
backend/utils/utils.py

@@ -1,7 +1,9 @@
-from fastapi.security import HTTPBasicCredentials, HTTPBearer
+from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
+from fastapi import HTTPException, status, Depends
+from apps.web.models.users import Users
 from pydantic import BaseModel
 from typing import Union, Optional
-
+from constants import ERROR_MESSAGES
 from passlib.context import CryptContext
 from datetime import datetime, timedelta
 import requests
@@ -53,16 +55,18 @@ def extract_token_from_auth_header(auth_header: str):
     return auth_header[len("Bearer ") :]
 
 
-def verify_token(request):
-    try:
-        bearer = request.headers["authorization"]
-        if bearer:
-            token = bearer[len("Bearer ") :]
-            decoded = jwt.decode(
-                token, JWT_SECRET_KEY, options={"verify_signature": False}
+def get_current_user(auth_token: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
+    data = decode_token(auth_token.credentials)
+    if data != None and "email" in data:
+        user = Users.get_user_by_email(data["email"])
+        if user is None:
+            raise HTTPException(
+                status_code=status.HTTP_401_UNAUTHORIZED,
+                detail=ERROR_MESSAGES.INVALID_TOKEN,
             )
-            return decoded
-        else:
-            return None
-    except Exception as e:
-        return None
+        return user
+    else:
+        raise HTTPException(
+            status_code=status.HTTP_401_UNAUTHORIZED,
+            detail=ERROR_MESSAGES.UNAUTHORIZED,
+        )

+ 1 - 1
src/app.css

@@ -16,7 +16,7 @@ html {
 
 code {
 	/* white-space-collapse: preserve !important; */
-	white-space: pre;
+	overflow-x: auto;
 	width: auto;
 }
 

+ 18 - 0
src/lib/components/chat/MessageInput.svelte

@@ -298,6 +298,24 @@
 									submitPrompt(prompt);
 								}
 							}}
+							on:keydown={(e) => {
+								if (prompt === '' && e.key == 'ArrowUp') {
+									e.preventDefault();
+
+									const userMessageElement = [
+										...document.getElementsByClassName('user-message')
+									]?.at(-1);
+
+									const editButton = [
+										...document.getElementsByClassName('edit-user-message-button')
+									]?.at(-1);
+
+									console.log(userMessageElement);
+
+									userMessageElement.scrollIntoView({ block: 'center' });
+									editButton?.click();
+								}
+							}}
 							rows="1"
 							on:input={(e) => {
 								e.target.style.height = '';

+ 1 - 0
src/lib/components/chat/Messages/ResponseMessage.svelte

@@ -88,6 +88,7 @@
 				let code = block.querySelector('code');
 				code.style.borderTopRightRadius = 0;
 				code.style.borderTopLeftRadius = 0;
+				code.style.whiteSpace = 'pre';
 
 				let topBarDiv = document.createElement('div');
 				topBarDiv.style.backgroundColor = '#202123';

+ 6 - 2
src/lib/components/chat/Messages/UserMessage.svelte

@@ -24,6 +24,8 @@
 
 		editElement.style.height = '';
 		editElement.style.height = `${editElement.scrollHeight}px`;
+
+		editElement?.focus();
 	};
 
 	const editMessageConfirmHandler = async () => {
@@ -43,7 +45,9 @@
 	<ProfileImage src={user?.profile_image_url ?? '/user.png'} />
 
 	<div class="w-full overflow-hidden">
-		<Name>You</Name>
+		<div class="user-message">
+			<Name>You</Name>
+		</div>
 
 		<div
 			class="prose chat-{message.role} w-full max-w-full dark:prose-invert prose-headings:my-0 prose-p:my-0 prose-p:-mb-4 prose-pre:my-0 prose-table:my-0 prose-blockquote:my-0 prose-img:my-0 prose-ul:-my-4 prose-ol:-my-4 prose-li:-my-3 prose-ul:-mb-6 prose-ol:-mb-6 prose-li:-mb-4 whitespace-pre-line"
@@ -145,7 +149,7 @@
 						{/if}
 
 						<button
-							class="invisible group-hover:visible p-1 rounded dark:hover:bg-gray-800 transition"
+							class="invisible group-hover:visible p-1 rounded dark:hover:bg-gray-800 transition edit-user-message-button"
 							on:click={() => {
 								editMessageHandler();
 							}}

+ 46 - 5
src/lib/components/layout/Navbar.svelte

@@ -1,7 +1,10 @@
 <script lang="ts">
+	import toast from 'svelte-french-toast';
+	import fileSaver from 'file-saver';
+	const { saveAs } = fileSaver;
+
 	import { getChatById } from '$lib/apis/chats';
 	import { chatId, db, modelfiles } from '$lib/stores';
-	import toast from 'svelte-french-toast';
 
 	export let initNewChat: Function;
 	export let title: string = 'Ollama Web UI';
@@ -33,6 +36,21 @@
 			false
 		);
 	};
+
+	const downloadChat = async () => {
+		const chat = (await getChatById(localStorage.token, $chatId)).chat;
+		console.log('download', chat);
+
+		const chatText = chat.messages.reduce((a, message, i, arr) => {
+			return `${a}### ${message.role.toUpperCase()}\n${message.content}\n\n`;
+		}, '');
+
+		let blob = new Blob([chatText], {
+			type: 'text/plain'
+		});
+
+		saveAs(blob, `chat-${chat.title}.txt`);
+	};
 </script>
 
 <nav
@@ -69,7 +87,30 @@
 			</div>
 
 			{#if shareEnabled}
-				<div class="pl-2">
+				<div class="pl-2 flex space-x-1.5">
+					<button
+						class=" cursor-pointer p-2 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"
+						on:click={async () => {
+							downloadChat();
+						}}
+					>
+						<div class=" m-auto self-center">
+							<svg
+								xmlns="http://www.w3.org/2000/svg"
+								viewBox="0 0 16 16"
+								fill="currentColor"
+								class="w-4 h-4"
+							>
+								<path
+									d="M8.75 2.75a.75.75 0 0 0-1.5 0v5.69L5.03 6.22a.75.75 0 0 0-1.06 1.06l3.5 3.5a.75.75 0 0 0 1.06 0l3.5-3.5a.75.75 0 0 0-1.06-1.06L8.75 8.44V2.75Z"
+								/>
+								<path
+									d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
+								/>
+							</svg>
+						</div>
+					</button>
+
 					<button
 						class=" cursor-pointer p-2 flex dark:hover:bg-gray-700 rounded-lg transition border dark:border-gray-600"
 						on:click={async () => {
@@ -79,15 +120,15 @@
 						<div class=" m-auto self-center">
 							<svg
 								xmlns="http://www.w3.org/2000/svg"
-								viewBox="0 0 20 20"
+								viewBox="0 0 16 16"
 								fill="currentColor"
 								class="w-4 h-4"
 							>
 								<path
-									d="M9.25 13.25a.75.75 0 001.5 0V4.636l2.955 3.129a.75.75 0 001.09-1.03l-4.25-4.5a.75.75 0 00-1.09 0l-4.25 4.5a.75.75 0 101.09 1.03L9.25 4.636v8.614z"
+									d="M7.25 10.25a.75.75 0 0 0 1.5 0V4.56l2.22 2.22a.75.75 0 1 0 1.06-1.06l-3.5-3.5a.75.75 0 0 0-1.06 0l-3.5 3.5a.75.75 0 0 0 1.06 1.06l2.22-2.22v5.69Z"
 								/>
 								<path
-									d="M3.5 12.75a.75.75 0 00-1.5 0v2.5A2.75 2.75 0 004.75 18h10.5A2.75 2.75 0 0018 15.25v-2.5a.75.75 0 00-1.5 0v2.5c0 .69-.56 1.25-1.25 1.25H4.75c-.69 0-1.25-.56-1.25-1.25v-2.5z"
+									d="M3.5 9.75a.75.75 0 0 0-1.5 0v1.5A2.75 2.75 0 0 0 4.75 14h6.5A2.75 2.75 0 0 0 14 11.25v-1.5a.75.75 0 0 0-1.5 0v1.5c0 .69-.56 1.25-1.25 1.25h-6.5c-.69 0-1.25-.56-1.25-1.25v-1.5Z"
 								/>
 							</svg>
 						</div>