Timothy J. Baek 10 月之前
父節點
當前提交
3f5f410453

+ 4 - 4
backend/apps/images/main.py

@@ -16,7 +16,7 @@ from faster_whisper import WhisperModel
 
 from constants import ERROR_MESSAGES
 from utils.utils import (
-    get_current_user,
+    get_verified_user,
     get_admin_user,
 )
 
@@ -258,7 +258,7 @@ async def update_image_size(
 
 
 @app.get("/models")
-def get_models(user=Depends(get_current_user)):
+def get_models(user=Depends(get_verified_user)):
     try:
         if app.state.config.ENGINE == "openai":
             return [
@@ -347,7 +347,7 @@ def set_model_handler(model: str):
 @app.post("/models/default/update")
 def update_default_model(
     form_data: UpdateModelForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     return set_model_handler(form_data.model)
 
@@ -424,7 +424,7 @@ def save_url_image(url):
 @app.post("/generations")
 def generate_image(
     form_data: GenerateImageForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     width, height = tuple(map(int, app.state.config.IMAGE_SIZE.split("x")))
 

+ 2 - 2
backend/apps/openai/main.py

@@ -16,7 +16,7 @@ from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import (
     decode_token,
-    get_current_user,
+    get_verified_user,
     get_verified_user,
     get_admin_user,
 )
@@ -296,7 +296,7 @@ async def get_all_models(raw: bool = False):
 
 @app.get("/models")
 @app.get("/models/{url_idx}")
-async def get_models(url_idx: Optional[int] = None, user=Depends(get_current_user)):
+async def get_models(url_idx: Optional[int] = None, user=Depends(get_verified_user)):
     if url_idx == None:
         models = await get_all_models()
         if app.state.config.ENABLE_MODEL_FILTER:

+ 10 - 10
backend/apps/rag/main.py

@@ -85,7 +85,7 @@ from utils.misc import (
     sanitize_filename,
     extract_folders_after_data_docs,
 )
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
 
 from config import (
     AppConfig,
@@ -529,7 +529,7 @@ async def update_rag_config(form_data: ConfigUpdateForm, user=Depends(get_admin_
 
 
 @app.get("/template")
-async def get_rag_template(user=Depends(get_current_user)):
+async def get_rag_template(user=Depends(get_verified_user)):
     return {
         "status": True,
         "template": app.state.config.RAG_TEMPLATE,
@@ -586,7 +586,7 @@ class QueryDocForm(BaseModel):
 @app.post("/query/doc")
 def query_doc_handler(
     form_data: QueryDocForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     try:
         if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@@ -626,7 +626,7 @@ class QueryCollectionsForm(BaseModel):
 @app.post("/query/collection")
 def query_collection_handler(
     form_data: QueryCollectionsForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     try:
         if app.state.config.ENABLE_RAG_HYBRID_SEARCH:
@@ -657,7 +657,7 @@ def query_collection_handler(
 
 
 @app.post("/youtube")
-def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
+def store_youtube_video(form_data: UrlForm, user=Depends(get_verified_user)):
     try:
         loader = YoutubeLoader.from_youtube_url(
             form_data.url,
@@ -686,7 +686,7 @@ def store_youtube_video(form_data: UrlForm, user=Depends(get_current_user)):
 
 
 @app.post("/web")
-def store_web(form_data: UrlForm, user=Depends(get_current_user)):
+def store_web(form_data: UrlForm, user=Depends(get_verified_user)):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
     try:
         loader = get_web_loader(
@@ -864,7 +864,7 @@ def search_web(engine: str, query: str) -> list[SearchResult]:
 
 
 @app.post("/web/search")
-def store_web_search(form_data: SearchForm, user=Depends(get_current_user)):
+def store_web_search(form_data: SearchForm, user=Depends(get_verified_user)):
     try:
         logging.info(
             f"trying to web search with {app.state.config.RAG_WEB_SEARCH_ENGINE, form_data.query}"
@@ -1084,7 +1084,7 @@ def get_loader(filename: str, file_content_type: str, file_path: str):
 def store_doc(
     collection_name: Optional[str] = Form(None),
     file: UploadFile = File(...),
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     # "https://www.gutenberg.org/files/1727/1727-h/1727-h.htm"
 
@@ -1145,7 +1145,7 @@ class ProcessDocForm(BaseModel):
 @app.post("/process/doc")
 def process_doc(
     form_data: ProcessDocForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     try:
         file = Files.get_file_by_id(form_data.file_id)
@@ -1200,7 +1200,7 @@ class TextRAGForm(BaseModel):
 @app.post("/text")
 def store_text(
     form_data: TextRAGForm,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
 
     collection_name = form_data.collection_name

+ 22 - 22
backend/apps/webui/routers/chats.py

@@ -1,7 +1,7 @@
 from fastapi import Depends, Request, HTTPException, status
 from datetime import datetime, timedelta
 from typing import List, Union, Optional
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
 from fastapi import APIRouter
 from pydantic import BaseModel
 import json
@@ -43,7 +43,7 @@ router = APIRouter()
 @router.get("/", response_model=List[ChatTitleIdResponse])
 @router.get("/list", response_model=List[ChatTitleIdResponse])
 async def get_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+    user=Depends(get_verified_user), skip: int = 0, limit: int = 50
 ):
     return Chats.get_chat_list_by_user_id(user.id, skip, limit)
 
@@ -54,7 +54,7 @@ async def get_session_user_chat_list(
 
 
 @router.delete("/", response_model=bool)
-async def delete_all_user_chats(request: Request, user=Depends(get_current_user)):
+async def delete_all_user_chats(request: Request, user=Depends(get_verified_user)):
 
     if (
         user.role == "user"
@@ -89,7 +89,7 @@ async def get_user_chat_list_by_user_id(
 
 
 @router.post("/new", response_model=Optional[ChatResponse])
-async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
+async def create_new_chat(form_data: ChatForm, user=Depends(get_verified_user)):
     try:
         chat = Chats.insert_new_chat(user.id, form_data)
         return ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
@@ -106,7 +106,7 @@ async def create_new_chat(form_data: ChatForm, user=Depends(get_current_user)):
 
 
 @router.get("/all", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_chats(user=Depends(get_verified_user)):
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         for chat in Chats.get_chats_by_user_id(user.id)
@@ -119,7 +119,7 @@ async def get_user_chats(user=Depends(get_current_user)):
 
 
 @router.get("/all/archived", response_model=List[ChatResponse])
-async def get_user_chats(user=Depends(get_current_user)):
+async def get_user_chats(user=Depends(get_verified_user)):
     return [
         ChatResponse(**{**chat.model_dump(), "chat": json.loads(chat.chat)})
         for chat in Chats.get_archived_chats_by_user_id(user.id)
@@ -151,7 +151,7 @@ async def get_all_user_chats_in_db(user=Depends(get_admin_user)):
 
 @router.get("/archived", response_model=List[ChatTitleIdResponse])
 async def get_archived_session_user_chat_list(
-    user=Depends(get_current_user), skip: int = 0, limit: int = 50
+    user=Depends(get_verified_user), skip: int = 0, limit: int = 50
 ):
     return Chats.get_archived_chat_list_by_user_id(user.id, skip, limit)
 
@@ -162,7 +162,7 @@ async def get_archived_session_user_chat_list(
 
 
 @router.post("/archive/all", response_model=bool)
-async def archive_all_chats(user=Depends(get_current_user)):
+async def archive_all_chats(user=Depends(get_verified_user)):
     return Chats.archive_all_chats_by_user_id(user.id)
 
 
@@ -172,7 +172,7 @@ async def archive_all_chats(user=Depends(get_current_user)):
 
 
 @router.get("/share/{share_id}", response_model=Optional[ChatResponse])
-async def get_shared_chat_by_id(share_id: str, user=Depends(get_current_user)):
+async def get_shared_chat_by_id(share_id: str, user=Depends(get_verified_user)):
     if user.role == "pending":
         raise HTTPException(
             status_code=status.HTTP_401_UNAUTHORIZED, detail=ERROR_MESSAGES.NOT_FOUND
@@ -204,7 +204,7 @@ class TagNameForm(BaseModel):
 
 @router.post("/tags", response_model=List[ChatTitleIdResponse])
 async def get_user_chat_list_by_tag_name(
-    form_data: TagNameForm, user=Depends(get_current_user)
+    form_data: TagNameForm, user=Depends(get_verified_user)
 ):
 
     print(form_data)
@@ -229,7 +229,7 @@ async def get_user_chat_list_by_tag_name(
 
 
 @router.get("/tags/all", response_model=List[TagModel])
-async def get_all_tags(user=Depends(get_current_user)):
+async def get_all_tags(user=Depends(get_verified_user)):
     try:
         tags = Tags.get_tags_by_user_id(user.id)
         return tags
@@ -246,7 +246,7 @@ async def get_all_tags(user=Depends(get_current_user)):
 
 
 @router.get("/{id}", response_model=Optional[ChatResponse])
-async def get_chat_by_id(id: str, user=Depends(get_current_user)):
+async def get_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
 
     if chat:
@@ -264,7 +264,7 @@ async def get_chat_by_id(id: str, user=Depends(get_current_user)):
 
 @router.post("/{id}", response_model=Optional[ChatResponse])
 async def update_chat_by_id(
-    id: str, form_data: ChatForm, user=Depends(get_current_user)
+    id: str, form_data: ChatForm, user=Depends(get_verified_user)
 ):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
@@ -285,7 +285,7 @@ async def update_chat_by_id(
 
 
 @router.delete("/{id}", response_model=bool)
-async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_user)):
+async def delete_chat_by_id(request: Request, id: str, user=Depends(get_verified_user)):
 
     if user.role == "admin":
         result = Chats.delete_chat_by_id(id)
@@ -307,7 +307,7 @@ async def delete_chat_by_id(request: Request, id: str, user=Depends(get_current_
 
 
 @router.get("/{id}/clone", response_model=Optional[ChatResponse])
-async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
+async def clone_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
 
@@ -333,7 +333,7 @@ async def clone_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.get("/{id}/archive", response_model=Optional[ChatResponse])
-async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
+async def archive_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
         chat = Chats.toggle_chat_archive_by_id(id)
@@ -350,7 +350,7 @@ async def archive_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.post("/{id}/share", response_model=Optional[ChatResponse])
-async def share_chat_by_id(id: str, user=Depends(get_current_user)):
+async def share_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
         if chat.share_id:
@@ -382,7 +382,7 @@ async def share_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.delete("/{id}/share", response_model=Optional[bool])
-async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
+async def delete_shared_chat_by_id(id: str, user=Depends(get_verified_user)):
     chat = Chats.get_chat_by_id_and_user_id(id, user.id)
     if chat:
         if not chat.share_id:
@@ -405,7 +405,7 @@ async def delete_shared_chat_by_id(id: str, user=Depends(get_current_user)):
 
 
 @router.get("/{id}/tags", response_model=List[TagModel])
-async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
+async def get_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
     if tags != None:
@@ -423,7 +423,7 @@ async def get_chat_tags_by_id(id: str, user=Depends(get_current_user)):
 
 @router.post("/{id}/tags", response_model=Optional[ChatIdTagModel])
 async def add_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
 ):
     tags = Tags.get_tags_by_chat_id_and_user_id(id, user.id)
 
@@ -450,7 +450,7 @@ async def add_chat_tag_by_id(
 
 @router.delete("/{id}/tags", response_model=Optional[bool])
 async def delete_chat_tag_by_id(
-    id: str, form_data: ChatIdTagForm, user=Depends(get_current_user)
+    id: str, form_data: ChatIdTagForm, user=Depends(get_verified_user)
 ):
     result = Tags.delete_tag_by_tag_name_and_chat_id_and_user_id(
         form_data.tag_name, id, user.id
@@ -470,7 +470,7 @@ async def delete_chat_tag_by_id(
 
 
 @router.delete("/{id}/tags/all", response_model=Optional[bool])
-async def delete_all_chat_tags_by_id(id: str, user=Depends(get_current_user)):
+async def delete_all_chat_tags_by_id(id: str, user=Depends(get_verified_user)):
     result = Tags.delete_tags_by_chat_id_and_user_id(id, user.id)
 
     if result:

+ 2 - 2
backend/apps/webui/routers/configs.py

@@ -14,7 +14,7 @@ from apps.webui.models.users import Users
 
 from utils.utils import (
     get_password_hash,
-    get_current_user,
+    get_verified_user,
     get_admin_user,
     create_token,
 )
@@ -84,6 +84,6 @@ async def set_banners(
 @router.get("/banners", response_model=List[BannerModel])
 async def get_banners(
     request: Request,
-    user=Depends(get_current_user),
+    user=Depends(get_verified_user),
 ):
     return request.app.state.config.BANNERS

+ 4 - 4
backend/apps/webui/routers/documents.py

@@ -14,7 +14,7 @@ from apps.webui.models.documents import (
     DocumentResponse,
 )
 
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -25,7 +25,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[DocumentResponse])
-async def get_documents(user=Depends(get_current_user)):
+async def get_documents(user=Depends(get_verified_user)):
     docs = [
         DocumentResponse(
             **{
@@ -74,7 +74,7 @@ async def create_new_doc(form_data: DocumentForm, user=Depends(get_admin_user)):
 
 
 @router.get("/doc", response_model=Optional[DocumentResponse])
-async def get_doc_by_name(name: str, user=Depends(get_current_user)):
+async def get_doc_by_name(name: str, user=Depends(get_verified_user)):
     doc = Documents.get_doc_by_name(name)
 
     if doc:
@@ -106,7 +106,7 @@ class TagDocumentForm(BaseModel):
 
 
 @router.post("/doc/tags", response_model=Optional[DocumentResponse])
-async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_current_user)):
+async def tag_doc_by_name(form_data: TagDocumentForm, user=Depends(get_verified_user)):
     doc = Documents.update_doc_content_by_name(form_data.name, {"tags": form_data.tags})
 
     if doc:

+ 3 - 3
backend/apps/webui/routers/prompts.py

@@ -8,7 +8,7 @@ import json
 
 from apps.webui.models.prompts import Prompts, PromptForm, PromptModel
 
-from utils.utils import get_current_user, get_admin_user
+from utils.utils import get_verified_user, get_admin_user
 from constants import ERROR_MESSAGES
 
 router = APIRouter()
@@ -19,7 +19,7 @@ router = APIRouter()
 
 
 @router.get("/", response_model=List[PromptModel])
-async def get_prompts(user=Depends(get_current_user)):
+async def get_prompts(user=Depends(get_verified_user)):
     return Prompts.get_prompts()
 
 
@@ -52,7 +52,7 @@ async def create_new_prompt(form_data: PromptForm, user=Depends(get_admin_user))
 
 
 @router.get("/command/{command}", response_model=Optional[PromptModel])
-async def get_prompt_by_command(command: str, user=Depends(get_current_user)):
+async def get_prompt_by_command(command: str, user=Depends(get_verified_user)):
     prompt = Prompts.get_prompt_by_command(f"/{command}")
 
     if prompt: