瀏覽代碼

Add batching

Gabriel Ecegi 4 月之前
父節點
當前提交
f2e2b59c18
共有 2 個文件被更改,包括 182 次插入和 14 次删除
  1. 104 10
      backend/open_webui/apps/retrieval/main.py
  2. 78 4
      backend/open_webui/apps/webui/routers/knowledge.py

+ 104 - 10
backend/open_webui/apps/retrieval/main.py

@@ -2,16 +2,14 @@
 
 
 import json
 import json
 import logging
 import logging
-import mimetypes
 import os
 import os
 import shutil
 import shutil
 
 
 import uuid
 import uuid
 from datetime import datetime
 from datetime import datetime
-from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import List, Optional
 
 
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from pydantic import BaseModel
 import tiktoken
 import tiktoken
@@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
     query_doc_with_hybrid_search,
     query_doc_with_hybrid_search,
 )
 )
 
 
-from open_webui.apps.webui.models.files import Files
+from open_webui.apps.webui.models.files import FileModel, Files
 from open_webui.config import (
 from open_webui.config import (
     BRAVE_SEARCH_API_KEY,
     BRAVE_SEARCH_API_KEY,
     KAGI_SEARCH_API_KEY,
     KAGI_SEARCH_API_KEY,
@@ -64,7 +62,6 @@ from open_webui.config import (
     CONTENT_EXTRACTION_ENGINE,
     CONTENT_EXTRACTION_ENGINE,
     CORS_ALLOW_ORIGIN,
     CORS_ALLOW_ORIGIN,
     ENABLE_RAG_HYBRID_SEARCH,
     ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     ENABLE_RAG_WEB_SEARCH,
     ENABLE_RAG_WEB_SEARCH,
     ENV,
     ENV,
@@ -86,7 +83,6 @@ from open_webui.config import (
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    DEFAULT_RAG_TEMPLATE,
     RAG_TEMPLATE,
     RAG_TEMPLATE,
     RAG_TOP_K,
     RAG_TOP_K,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -118,10 +114,7 @@ from open_webui.env import (
     DOCKER,
     DOCKER,
 )
 )
 from open_webui.utils.misc import (
 from open_webui.utils.misc import (
-    calculate_sha256,
     calculate_sha256_string,
     calculate_sha256_string,
-    extract_folders_after_data_docs,
-    sanitize_filename,
 )
 )
 from open_webui.utils.auth import get_admin_user, get_verified_user
 from open_webui.utils.auth import get_admin_user, get_verified_user
 
 
@@ -1047,6 +1040,106 @@ def process_file(
             )
             )
 
 
 
 
class BatchProcessFilesForm(BaseModel):
    """Request payload for batch processing: the files to index and the
    target vector-DB collection they should be saved into."""

    files: List[FileModel]
    collection_name: str


class BatchProcessFilesResult(BaseModel):
    """Per-file outcome of a batch run.

    ``status`` is one of "prepared", "completed" or "failed" (required — it
    has no default); ``error`` carries the failure message when failed.
    """

    file_id: str
    status: str
    error: Optional[str] = None


class BatchProcessFilesResponse(BaseModel):
    """Aggregate outcome of a batch run: successful results plus the
    per-file errors collected along the way."""

    results: List[BatchProcessFilesResult]
    errors: List[BatchProcessFilesResult]
@app.post("/process/files/batch")
def process_files_batch(
    form_data: BatchProcessFilesForm,
    user=Depends(get_verified_user),
) -> BatchProcessFilesResponse:
    """
    Process a batch of files and save them to the vector database.

    Phase 1 prepares one Document per file (syncing each file's stored hash
    and content as a side effect); phase 2 writes all prepared documents to
    the collection in a single ``save_docs_to_vector_db`` call and marks the
    corresponding files "completed".

    A file that fails either phase is reported in ``errors`` with status
    "failed" and never aborts the rest of the batch.
    """
    results: List[BatchProcessFilesResult] = []
    errors: List[BatchProcessFilesResult] = []
    collection_name = form_data.collection_name

    # Phase 1: prepare all documents first so the vector DB is written once.
    all_docs: List[Document] = []
    for file_request in form_data.files:
        try:
            # NOTE(review): form_data.files is declared List[FileModel], yet
            # this reads `file_id`, `content` and `filename` off each entry —
            # confirm FileModel actually exposes those attributes (rather
            # than e.g. `id` / `data["content"]`).
            file = Files.get_file_by_id(file_request.file_id)
            if not file:
                log.error(f"process_files_batch: File {file_request.file_id} not found")
                raise ValueError(f"File {file_request.file_id} not found")

            text_content = file_request.content

            docs: List[Document] = [
                Document(
                    page_content=text_content.replace("<br/>", "\n"),
                    metadata={
                        **file.meta,
                        "name": file_request.filename,
                        "created_by": file.user_id,
                        "file_id": file.id,
                        "source": file_request.filename,
                    },
                )
            ]

            # Keep the stored hash/content in sync with what gets embedded.
            content_hash = calculate_sha256_string(text_content)
            Files.update_file_hash_by_id(file.id, content_hash)
            Files.update_file_data_by_id(file.id, {"content": text_content})

            all_docs.extend(docs)
            results.append(
                BatchProcessFilesResult(file_id=file.id, status="prepared")
            )

        except Exception as e:
            log.error(f"process_files_batch: Error processing file {file_request.file_id}: {str(e)}")
            errors.append(
                BatchProcessFilesResult(
                    file_id=file_request.file_id,
                    status="failed",
                    error=str(e),
                )
            )

    # Phase 2: save all prepared documents in one batch.
    if all_docs:
        try:
            save_docs_to_vector_db(
                docs=all_docs,
                collection_name=collection_name,
                add=True,
            )

            # Tag every successfully saved file with its collection name.
            for result in results:
                Files.update_file_metadata_by_id(
                    result.file_id,
                    {"collection_name": collection_name},
                )
                result.status = "completed"

        except Exception as e:
            log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
            for result in results:
                result.status = "failed"
                # `status` is a required field on BatchProcessFilesResult;
                # omitting it here (as the code previously did) raises a
                # pydantic ValidationError inside the error handler itself.
                errors.append(
                    BatchProcessFilesResult(
                        file_id=result.file_id,
                        status="failed",
                        error=str(e),
                    )
                )

    return BatchProcessFilesResponse(results=results, errors=errors)
+
+
 class ProcessTextForm(BaseModel):
 class ProcessTextForm(BaseModel):
     name: str
     name: str
     content: str
     content: str
@@ -1509,3 +1602,4 @@ if ENV == "dev":
     @app.get("/ef/{text}")
     @app.get("/ef/{text}")
     async def get_embeddings_text(text: str):
     async def get_embeddings_text(text: str):
         return {"result": app.state.EMBEDDING_FUNCTION(text)}
         return {"result": app.state.EMBEDDING_FUNCTION(text)}
+

+ 78 - 4
backend/open_webui/apps/webui/routers/knowledge.py

@@ -1,5 +1,4 @@
-import json
-from typing import Optional, Union
+from typing import List, Optional
 from pydantic import BaseModel
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 import logging
 import logging
@@ -12,11 +11,11 @@ from open_webui.apps.webui.models.knowledge import (
 )
 )
 from open_webui.apps.webui.models.files import Files, FileModel
 from open_webui.apps.webui.models.files import Files, FileModel
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.apps.retrieval.main import BatchProcessFilesForm, process_file, ProcessFileForm, process_files_batch
 
 
 
 
 from open_webui.constants import ERROR_MESSAGES
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 from open_webui.utils.access_control import has_access, has_permission
 
 
 
 
@@ -508,3 +507,78 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
 
 
     return knowledge
     return knowledge
+
+
+############################
+# AddFilesToKnowledge
+############################
+
@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
def add_files_to_knowledge_batch(
    id: str,
    form_data: list[KnowledgeFileIdForm],
    user=Depends(get_verified_user),
):
    """
    Add multiple files to a knowledge base in a single request.

    Validates the knowledge base and the caller's access, resolves every
    requested file (failing fast with 400 if any id is unknown), processes
    the whole set as one batch against the knowledge base's collection, then
    appends the successfully processed file ids to the knowledge data.

    Returns the updated knowledge with its file list; when some files failed
    to process, a ``warnings`` payload describing the failures is included
    instead of failing the whole request.
    """
    knowledge = Knowledges.get_knowledge_by_id(id=id)
    if not knowledge:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.NOT_FOUND,
        )

    # Only the owner or an admin may modify the knowledge base.
    if knowledge.user_id != user.id and user.role != "admin":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
        )

    # Resolve all files up front; reject the request if any id is unknown.
    print(f"files/batch/add - {len(form_data)} files")
    files: List[FileModel] = []
    for form in form_data:
        file = Files.get_file_by_id(form.file_id)
        if not file:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"File {form.file_id} not found",
            )
        files.append(file)

    # Process the batch against this knowledge base's collection.
    # Forward the authenticated user explicitly: this is a direct Python
    # call, not a FastAPI-dispatched one, so the `Depends(...)` default
    # would otherwise leak a Depends marker object into the function.
    result = process_files_batch(
        BatchProcessFilesForm(files=files, collection_name=id),
        user=user,
    )

    # Record only successfully processed files, avoiding duplicates.
    data = knowledge.data or {}
    existing_file_ids = data.get("file_ids", [])

    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
    for file_id in successful_file_ids:
        if file_id not in existing_file_ids:
            existing_file_ids.append(file_id)

    data["file_ids"] = existing_file_ids
    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)

    # Surface partial failures to the caller without failing the request.
    if result.errors:
        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
        return KnowledgeFilesResponse(
            **knowledge.model_dump(),
            files=Files.get_files_by_ids(existing_file_ids),
            warnings={
                "message": "Some files failed to process",
                "errors": error_details,
            },
        )

    return KnowledgeFilesResponse(
        **knowledge.model_dump(),
        files=Files.get_files_by_ids(existing_file_ids),
    )