@@ -2,16 +2,14 @@

 import json
 import logging
-import mimetypes
 import os
 import shutil

 import uuid
 from datetime import datetime
-from pathlib import Path
-from typing import Iterator, Optional, Sequence, Union
+from typing import List, Optional

-from fastapi import Depends, FastAPI, File, Form, HTTPException, UploadFile, status
+from fastapi import Depends, FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import tiktoken
@@ -52,7 +50,7 @@ from open_webui.apps.retrieval.utils import (
     query_doc_with_hybrid_search,
 )

-from open_webui.apps.webui.models.files import Files
+from open_webui.apps.webui.models.files import FileModel, Files
 from open_webui.config import (
     BRAVE_SEARCH_API_KEY,
     KAGI_SEARCH_API_KEY,
@@ -64,7 +62,6 @@ from open_webui.config import (
     CONTENT_EXTRACTION_ENGINE,
     CORS_ALLOW_ORIGIN,
     ENABLE_RAG_HYBRID_SEARCH,
-    ENABLE_RAG_LOCAL_WEB_FETCH,
     ENABLE_RAG_WEB_LOADER_SSL_VERIFICATION,
     ENABLE_RAG_WEB_SEARCH,
     ENV,
@@ -86,7 +83,6 @@ from open_webui.config import (
     RAG_RERANKING_MODEL,
     RAG_RERANKING_MODEL_AUTO_UPDATE,
     RAG_RERANKING_MODEL_TRUST_REMOTE_CODE,
-    DEFAULT_RAG_TEMPLATE,
     RAG_TEMPLATE,
     RAG_TOP_K,
     RAG_WEB_SEARCH_CONCURRENT_REQUESTS,
@@ -118,10 +114,7 @@ from open_webui.env import (
     DOCKER,
 )
 from open_webui.utils.misc import (
-    calculate_sha256,
     calculate_sha256_string,
-    extract_folders_after_data_docs,
-    sanitize_filename,
 )
 from open_webui.utils.auth import get_admin_user, get_verified_user

@@ -1047,6 +1040,106 @@ def process_file(
         )


+class BatchProcessFilesForm(BaseModel):
+    files: List[FileModel]
+    collection_name: str
+
+class BatchProcessFilesResult(BaseModel):
+    file_id: str
+    status: str
+    error: Optional[str] = None
+
+class BatchProcessFilesResponse(BaseModel):
+    results: List[BatchProcessFilesResult]
+    errors: List[BatchProcessFilesResult]
+
+@app.post("/process/files/batch")
+def process_files_batch(
+    form_data: BatchProcessFilesForm,
+    user=Depends(get_verified_user),
+) -> BatchProcessFilesResponse:
+    """
+    Process a batch of files and save them to the vector database.
+    """
+    results: List[BatchProcessFilesResult] = []
+    errors: List[BatchProcessFilesResult] = []
+    collection_name = form_data.collection_name
+
+    # Prepare all documents first
+    all_docs: List[Document] = []
+    for file_request in form_data.files:
+        try:
+            file = Files.get_file_by_id(file_request.id)
+            if not file:
+                log.error(f"process_files_batch: File {file_request.id} not found")
+                raise ValueError(f"File {file_request.id} not found")
+
+            text_content = (file.data or {}).get("content", "")
+
+            docs: List[Document] = [
+                Document(
+                    page_content=text_content.replace("<br/>", "\n"),
+                    metadata={
+                        **file.meta,
+                        "name": file_request.filename,
+                        "created_by": file.user_id,
+                        "file_id": file.id,
+                        "source": file_request.filename,
+                    },
+                )
+            ]
+
+            hash = calculate_sha256_string(text_content)
+            Files.update_file_hash_by_id(file.id, hash)
+            Files.update_file_data_by_id(file.id, {"content": text_content})
+
+            all_docs.extend(docs)
+            results.append(BatchProcessFilesResult(
+                file_id=file.id,
+                status="prepared"
+            ))
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error processing file {file_request.id}: {str(e)}")
+            errors.append(BatchProcessFilesResult(
+                file_id=file_request.id,
+                status="failed",
+                error=str(e)
+            ))
+
+    # Save all documents in one batch
+    if all_docs:
+        try:
+            save_docs_to_vector_db(
+                docs=all_docs,
+                collection_name=collection_name,
+                add=True
+            )
+
+            # Update all files with collection name
+            for result in results:
+                Files.update_file_metadata_by_id(
+                    result.file_id,
+                    {"collection_name": collection_name}
+                )
+                result.status = "completed"
+
+        except Exception as e:
+            log.error(f"process_files_batch: Error saving documents to vector DB: {str(e)}")
+            for result in results:
+                result.status = "failed"
+                errors.append(BatchProcessFilesResult(
+                    file_id=result.file_id,
+                    status="failed",
+                    error=str(e)
+                ))
+
+    return BatchProcessFilesResponse(
+        results=results,
+        errors=errors
+    )
+
+
 class ProcessTextForm(BaseModel):
     name: str
     content: str
@@ -1509,3 +1602,4 @@ if ENV == "dev":
     @app.get("/ef/{text}")
     async def get_embeddings_text(text: str):
         return {"result": app.state.EMBEDDING_FUNCTION(text)}
+
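
# --- Usage sketch (illustrative, not part of the patch above) ---
# A minimal example of how a client might call the new POST /process/files/batch
# endpoint introduced in this diff. Assumptions: the retrieval app is reachable
# directly at BASE_URL (a deployment may add a mount prefix), TOKEN is a valid
# bearer token for a verified user, and the file was already created via the
# files API. The payload mirrors BatchProcessFilesForm: a list of serialized
# FileModel records plus a target collection name; all field values below are
# placeholders.
import requests

BASE_URL = "http://localhost:8080"  # assumption: local deployment, no mount prefix
TOKEN = "sk-..."                    # placeholder bearer token

payload = {
    "collection_name": "my-collection",  # example collection name
    "files": [
        {
            "id": "file-123",              # FileModel.id of an uploaded file
            "user_id": "user-abc",
            "hash": None,
            "filename": "notes.txt",
            "data": {"content": "Example extracted text."},
            "meta": {"content_type": "text/plain"},
            "created_at": 0,
            "updated_at": 0,
        }
    ],
}

resp = requests.post(
    f"{BASE_URL}/process/files/batch",
    headers={"Authorization": f"Bearer {TOKEN}"},
    json=payload,
)
resp.raise_for_status()

# The response mirrors BatchProcessFilesResponse, e.g.:
# {"results": [{"file_id": "file-123", "status": "completed", "error": null}], "errors": []}
print(resp.json())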