
Merge pull request #9641 from SentinalMax/bugfix/GGUF-upload-issue

Fix: GGUF model upload instability
Timothy Jaeryang Baek 2 months ago
commit 6862081d6b

2 changed files with 81 additions and 51 deletions

  1. backend/open_webui/routers/ollama.py (+76 -47)
  2. backend/open_webui/utils/misc.py (+5 -4)

backend/open_webui/routers/ollama.py (+76 -47)

@@ -11,10 +11,8 @@ import re
 import time
 from typing import Optional, Union
 from urllib.parse import urlparse
-
 import aiohttp
 from aiocache import cached
-
 import requests
 
 from fastapi import (
@@ -990,6 +988,8 @@ async def generate_chat_completion(
         )
 
     payload = {**form_data.model_dump(exclude_none=True)}
+    if "metadata" in payload:
+        del payload["metadata"]
 
     model_id = payload["model"]
     model_info = Models.get_model_by_id(model_id)
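
Review note on the hunk above: the two added lines drop the WebUI-internal "metadata" key before the payload is forwarded, presumably because Ollama's chat endpoint does not accept it. An equivalent, slightly more compact sketch:

# Hypothetical alternative to the added if/del pair: pop() is a
# no-op when the key is absent, so no membership test is needed.
payload.pop("metadata", None)
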
@@ -1408,9 +1408,10 @@ async def download_model(
         return None
 
 
+# TODO: Progress bar does not reflect size & duration of upload.
 @router.post("/models/upload")
 @router.post("/models/upload/{url_idx}")
-def upload_model(
+async def upload_model(
     request: Request,
     file: UploadFile = File(...),
     url_idx: Optional[int] = None,
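
The handler is now async and streams its progress as server-sent events (see the StreamingResponse at the end of the next hunk). A minimal client sketch for consuming that stream; the /ollama mount path, host/port, token, and file name are illustrative assumptions, not taken from this diff:

import json

import requests

with open("model.gguf", "rb") as f:
    resp = requests.post(
        "http://localhost:8080/ollama/models/upload",
        headers={"Authorization": "Bearer <token>"},
        files={"file": ("model.gguf", f)},
        stream=True,  # read the SSE body incrementally
    )
    for line in resp.iter_lines():
        if line.startswith(b"data: "):
            event = json.loads(line[len(b"data: "):])
            print(event)  # progress events, then {"done": true, ...} or {"error": ...}
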
@@ -1419,62 +1420,90 @@ def upload_model(
     if url_idx is None:
         url_idx = 0
     ollama_url = request.app.state.config.OLLAMA_BASE_URLS[url_idx]
-
-    file_path = f"{UPLOAD_DIR}/{file.filename}"
-
-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
+    os.makedirs(UPLOAD_DIR, exist_ok=True)
+    file_path = os.path.join(UPLOAD_DIR, file.filename)
+
+    # --- P1: save the uploaded file to disk in chunks ---
+    chunk_size = 1024 * 1024 * 2  # 2 MB chunks
+    with open(file_path, "wb") as out_f:
+        while True:
+            chunk = file.file.read(chunk_size)
+            if not chunk:
+                break
+            out_f.write(chunk)
+
+    async def file_process_stream():
         nonlocal ollama_url
         total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
+        log.info(f"Total Model Size: {str(total_size)}") # DEBUG
+
+        # --- P2: SSE progress + calculate sha256 hash ---
+        file_hash = calculate_sha256(file_path, chunk_size)
+        log.info(f"Model Hash: {str(file_hash)}") # DEBUG
         try:
             with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
+                bytes_read = 0
+                while chunk := f.read(chunk_size):
+                    bytes_read += len(chunk)
+                    progress = round(bytes_read / total_size * 100, 2)
+                    data_msg = {
                         "progress": progress,
                         "total": total_size,
-                        "completed": total,
+                        "completed": bytes_read,
                     }
-                    yield f"data: {json.dumps(res)}\n\n"
+                    yield f"data: {json.dumps(data_msg)}\n\n"
 
-                if done:
-                    f.seek(0)
-                    hashed = calculate_sha256(f)
-                    f.seek(0)
+            # --- P3: Upload to ollama /api/blobs ---
+            with open(file_path, "rb") as f:
+                url = f"{ollama_url}/api/blobs/sha256:{file_hash}"
+                response = requests.post(url, data=f)
+
+            if response.ok:
+                log.info(f"Uploaded to /api/blobs") # DEBUG
+                # Remove local file
+                os.remove(file_path)
+
+                # Create model in ollama
+                model_name, _ = os.path.splitext(file.filename)
+                log.info(f"Creating model: {model_name}")
+
+                create_payload = {
+                    "model": model_name,
+                    # Map the original file name to the uploaded blob's digest
+                    "files": {
+                        file.filename: f"sha256:{file_hash}"
+                    },
+                }
+                log.info(f"Model Payload: {create_payload}") # DEBUG
+
+                # Call Ollama /api/create
+                # https://github.com/ollama/ollama/blob/main/docs/api.md#create-a-model
+                create_resp = requests.post(
+                    url=f"{ollama_url}/api/create",
+                    headers={"Content-Type": "application/json"},
+                    data=json.dumps(create_payload),
+                )
 
-                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=f)
+                if create_resp.ok:
+                    log.info(f"API SUCCESS!") # DEBUG
+                    done_msg = {
+                        "done": True,
+                        "blob": f"sha256:{file_hash}",
+                        "name": file.filename,
+                        "model_created": model_name,
+                    }
+                    yield f"data: {json.dumps(done_msg)}\n\n"
+                else:
+                    raise Exception(
+                        f"Failed to create model in Ollama. {create_resp.text}"
+                    )
 
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file.filename,
-                        }
-                        os.remove(file_path)
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise Exception(
-                            "Ollama: Could not create blob, Please try again."
-                        )
+            else:
+                raise Exception("Ollama: Could not create blob, Please try again.")
 
         except Exception as e:
             res = {"error": str(e)}
             yield f"data: {json.dumps(res)}\n\n"
 
-    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
+    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
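
Taken together, the rewritten stream runs in three phases: save the upload to disk (P1), hash it and emit SSE progress (P2), then push the blob and create the model (P3). For reference, the two Ollama calls can be reproduced standalone; a minimal sketch, assuming a local Ollama at its default port (endpoints per the API doc linked above; the file name is illustrative):

import hashlib
import os

import requests

OLLAMA = "http://localhost:11434"  # assumed default; adjust as needed
path = "model.gguf"                # illustrative

# Hash the file in chunks, mirroring the updated calculate_sha256() helper
sha256 = hashlib.sha256()
with open(path, "rb") as f:
    while chunk := f.read(2 * 1024 * 1024):
        sha256.update(chunk)
digest = sha256.hexdigest()

# P3a: push the raw GGUF bytes as a blob keyed by its digest
with open(path, "rb") as f:
    requests.post(f"{OLLAMA}/api/blobs/sha256:{digest}", data=f).raise_for_status()

# P3b: create a model that references the uploaded blob by file name
requests.post(
    f"{OLLAMA}/api/create",
    json={
        "model": os.path.splitext(os.path.basename(path))[0],
        "files": {os.path.basename(path): f"sha256:{digest}"},
    },
).raise_for_status()
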

backend/open_webui/utils/misc.py (+5 -4)

@@ -244,11 +244,12 @@ def get_gravatar_url(email):
     return f"https://www.gravatar.com/avatar/{hash_hex}?d=mp"
 
 
-def calculate_sha256(file):
+def calculate_sha256(file_path, chunk_size):
+    """Compute the SHA-256 hash of a file by reading it in chunks."""
     sha256 = hashlib.sha256()
-    # Read the file in chunks to efficiently handle large files
-    for chunk in iter(lambda: file.read(8192), b""):
-        sha256.update(chunk)
+    with open(file_path, "rb") as f:
+        while chunk := f.read(chunk_size):
+            sha256.update(chunk)
     return sha256.hexdigest()
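
The helper's new signature matches its call site in the upload handler; a quick usage sketch (the path is illustrative, and the 2 MB chunk size mirrors the endpoint):

# Hypothetical call site for the updated helper.
digest = calculate_sha256("models/model.gguf", 1024 * 1024 * 2)
print(digest)  # 64-character lowercase hex digest
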