
Add batching

Gabriel Ecegi · 4 months ago · commit f2e2b59c18

backend/open_webui/apps/retrieval/main.py (+104 -10)

@@ -2,16 +2,14 @@
 
 import json
 import logging
-import mimetypes
 import os
 import shutil
 
 import uuid
 from datetime import datetime
-from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import List, Optional
 
-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import tiktoken
@@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
     query_doc_with_hybrid_search,
 )
 
-from open_webui.apps.webui.models.files import Files
+from open_webui.apps.webui.models.files import FileModel, Files
 from open_webui.config import (
     BRAVE_SEARCH_API_KEY,
     KAGI_SEARCH_API_KEY,
@@ -64,7 +62,6 @@ from open_webui.config import (
     CONTENT_EXTRACTION_ENGINE,
     CORS_ALLOW_ORIGIN,
     ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     ENABLE_RAG_WEB_SEARCH,
     ENV,
@@ -86,7 +83,6 @@ from open_webui.config import (
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    DEFAULT_RAG_TEMPLATE,
     RAG_TEMPLATE,
     RAG_TOP_K,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -118,10 +114,7 @@ from open_webui.env import (
     DOCKER,
 )
 from open_webui.utils.misc import (
-    calculate_sha256,
     calculate_sha256_string,
-    extract_folders_after_data_docs,
-    sanitize_filename,
 )
 from open_webui.utils.auth import get_admin_user, get_verified_user
 
@@ -1047,6 +1040,106 @@ def process_file(
             )
 
 
+class BatchProcessFilesForm(BaseModel):
+    files: List[FileModel]
+    collection_name: str
+
+
+class BatchProcessFilesResult(BaseModel):
+    file_id: str
+    status: str
+    error: Optional[str] = None
+
+
+class BatchProcessFilesResponse(BaseModel):
+    results: List[BatchProcessFilesResult]
+    errors: List[BatchProcessFilesResult]
+
+
+@app.post("/process/files/batch")
+def process_files_batch(
+    form_data: BatchProcessFilesForm,
+    user=Depends(get_verified_user),
+) -> BatchProcessFilesResponse:
+    """
+    Process a batch of files and save them to the vector database.
+    """
+    results: List[BatchProcessFilesResult] = []
+    errors: List[BatchProcessFilesResult] = []
+    collection_name = form_data.collection_name
+
+    # Prepare all documents first
+    all_docs: List[Document] = []
+    for file_request in form_data.files:
+        try:
+            # form_data.files carries FileModel records; refetch by id to validate existence
+            file = Files.get_file_by_id(file_request.id)
+            if not file:
+                log.error(f"process_files_batch: File {file_request.id} not found")
+                raise ValueError(f"File {file_request.id} not found")
+
+            text_content = (file.data or {}).get("content", "")
+
+            docs: List[Document] = [
+                Document(
+                    page_content=text_content.replace("<br/>", "\n"),
+                    metadata={
+                        **file.meta,
+                        "name": file_request.filename,
+                        "created_by": file.user_id,
+                        "file_id": file.id,
+                        "source": file_request.filename,
+                    },
+                )
+            ]
+
+            file_hash = calculate_sha256_string(text_content)
+            Files.update_file_hash_by_id(file.id, file_hash)
+            Files.update_file_data_by_id(file.id, {"content": text_content})
+
+            all_docs.extend(docs)
+            results.append(BatchProcessFilesResult(
+                file_id=file.id,
+                status="prepared"
+            ))
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error processing file {file_request.file_id}: {str(e)}")
+            errors.append(BatchProcessFilesResult(
+                file_id=file_request.file_id,
+                status="failed",
+                error=str(e)
+            ))
+
+    # Save all documents in one batch
+    if all_docs:
+        try:
+            save_docs_to_vector_db(
+                docs=all_docs,
+                collection_name=collection_name,
+                add=True
+            )
+
+            # Update all files with collection name
+            for result in results:
+                Files.update_file_metadata_by_id(
+                    result.file_id,
+                    {"collection_name": collection_name}
+                )
+                result.status = "completed"
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
+            for result in results:
+                result.status = "failed"
+                errors.append(BatchProcessFilesResult(
+                    file_id=result.file_id,
+                    status="failed",
+                    error=str(e)
+                ))
+
+    return BatchProcessFilesResponse(
+        results=results,
+        errors=errors
+    )
+
+
 class ProcessTextForm(BaseModel):
     name: str
     content: str
@@ -1509,3 +1602,4 @@ if ENV == "dev":
     @app.get("/ef/{text}")
     async def get_embeddings_text(text: str):
         return {"result": app.state.EMBEDDING_FUNCTION(text)}
+

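For context, a minimal sketch of how the new batch helper is meant to be driven in-process; the file ids and collection name below are placeholders, and the files are the FileModel records from the webui Files table, mirroring what the knowledge router does in the second diff:

from open_webui.apps.retrieval.main import (
    BatchProcessFilesForm,
    process_files_batch,
)
from open_webui.apps.webui.models.files import Files

# Placeholder ids; in practice these come from the knowledge form data.
file_ids = ["file-1", "file-2"]
files = [f for f in (Files.get_file_by_id(fid) for fid in file_ids) if f]

response = process_files_batch(
    BatchProcessFilesForm(files=files, collection_name="knowledge-123")
)

for result in response.results:
    print(result.file_id, result.status)   # "completed" once the batch is saved
for err in response.errors:
    print(err.file_id, err.error)          # per-file failures, if any
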
backend/open_webui/apps/webui/routers/knowledge.py (+78 -4)

@@ -1,5 +1,4 @@
-import json
-from typing import Optional, Union
+from typing import List, Optional
 from pydantic import BaseModel
 from fastapi import APIRouter, Depends, HTTPException, status, Request
 import logging
@@ -12,11 +11,11 @@ from open_webui.apps.webui.models.knowledge import (
 )
 from open_webui.apps.webui.models.files import Files, FileModel
 from open_webui.apps.retrieval.vector.connector import VECTOR_DB_CLIENT
-from open_webui.apps.retrieval.main import process_file, ProcessFileForm
+from open_webui.apps.retrieval.main import (
+    BatchProcessFilesForm,
+    ProcessFileForm,
+    process_file,
+    process_files_batch,
+)
 
 
 from open_webui.constants import ERROR_MESSAGES
-from open_webui.utils.auth import get_admin_user, get_verified_user
+from open_webui.utils.auth import get_verified_user
 from open_webui.utils.access_control import has_access, has_permission
 
 
@@ -508,3 +507,78 @@ async def reset_knowledge_by_id(id: str, user=Depends(get_verified_user)):
     knowledge = Knowledges.update_knowledge_data_by_id(id=id, data={"file_ids": []})
 
     return knowledge
+
+
+############################
+# AddFilesToKnowledge
+############################
+
+
+@router.post("/{id}/files/batch/add", response_model=Optional[KnowledgeFilesResponse])
+def add_files_to_knowledge_batch(
+    id: str,
+    form_data: list[KnowledgeFileIdForm],
+    user=Depends(get_verified_user),
+):
+    """
+    Add multiple files to a knowledge base
+    """
+    knowledge = Knowledges.get_knowledge_by_id(id=id)
+    if not knowledge:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.NOT_FOUND,
+        )
+
+    if knowledge.user_id != user.id and user.role != "admin":
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail=ERROR_MESSAGES.ACCESS_PROHIBITED,
+        )
+
+    # Get files content
+    print(f"files/batch/add - {len(form_data)} files")
+    files: List[FileModel] = []
+    for form in form_data:
+        file = Files.get_file_by_id(form.file_id)
+        if not file:
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail=f"File {form.file_id} not found",
+            )
+        files.append(file)
+
+    # Process files
+    result = process_files_batch(BatchProcessFilesForm(
+        files=files,
+        collection_name=id
+    ))
+
+    # Add successful files to knowledge base
+    data = knowledge.data or {}
+    existing_file_ids = data.get("file_ids", [])
+
+    # Only add files that were successfully processed
+    successful_file_ids = [r.file_id for r in result.results if r.status == "completed"]
+    for file_id in successful_file_ids:
+        if file_id not in existing_file_ids:
+            existing_file_ids.append(file_id)
+
+    data["file_ids"] = existing_file_ids
+    knowledge = Knowledges.update_knowledge_data_by_id(id=id, data=data)
+
+    # If there were any errors, include them in the response
+    if result.errors:
+        error_details = [f"{err.file_id}: {err.error}" for err in result.errors]
+        return KnowledgeFilesResponse(
+            **knowledge.model_dump(),
+            files=Files.get_files_by_ids(existing_file_ids),
+            warnings={
+                "message": "Some files failed to process",
+                "errors": error_details
+            }
+        )
+
+    return KnowledgeFilesResponse(
+        **knowledge.model_dump(),
+        files=Files.get_files_by_ids(existing_file_ids)
+    )