소스 검색

refac: ollama gguf upload

Timothy J. Baek 1 년 전
부모
커밋
af4caec4f5
4개의 변경된 파일269개의 추가작업 그리고 177개의 파일을 삭제
  1. 183 3
      backend/apps/ollama/main.py
  2. 0 149
      backend/apps/web/routers/utils.py
  3. 65 0
      src/lib/apis/ollama/index.ts
  4. 21 25
      src/lib/components/chat/Settings/Models.svelte

+ 183 - 3
backend/apps/ollama/main.py

@@ -1,23 +1,39 @@
-from fastapi import FastAPI, Request, Response, HTTPException, Depends, status
+from fastapi import (
+    FastAPI,
+    Request,
+    Response,
+    HTTPException,
+    Depends,
+    status,
+    UploadFile,
+    File,
+    BackgroundTasks,
+)
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 
 from pydantic import BaseModel, ConfigDict
 
+import os
 import random
 import requests
 import json
 import uuid
 import aiohttp
 import asyncio
+import aiofiles
+from urllib.parse import urlparse
+from typing import Optional, List, Union
+
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST
+from utils.misc import calculate_sha256
 
-from typing import Optional, List, Union
+
+from config import OLLAMA_BASE_URLS, MODEL_FILTER_ENABLED, MODEL_FILTER_LIST, UPLOAD_DIR
 
 
 app = FastAPI()
@@ -897,6 +913,170 @@ async def generate_openai_chat_completion(
         )
 
 
+class UrlForm(BaseModel):
+    url: str
+
+
+class UploadBlobForm(BaseModel):
+    filename: str
+
+
+def parse_huggingface_url(hf_url):
+    try:
+        # Parse the URL
+        parsed_url = urlparse(hf_url)
+
+        # Get the path and split it into components
+        path_components = parsed_url.path.split("/")
+
+        # Extract the desired output
+        user_repo = "/".join(path_components[1:3])
+        model_file = path_components[-1]
+
+        return model_file
+    except ValueError:
+        return None
+
+
+async def download_file_stream(
+    ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
+):
+    done = False
+
+    if os.path.exists(file_path):
+        current_size = os.path.getsize(file_path)
+    else:
+        current_size = 0
+
+    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
+
+    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
+
+    async with aiohttp.ClientSession(timeout=timeout) as session:
+        async with session.get(file_url, headers=headers) as response:
+            total_size = int(response.headers.get("content-length", 0)) + current_size
+
+            with open(file_path, "ab+") as file:
+                async for data in response.content.iter_chunked(chunk_size):
+                    current_size += len(data)
+                    file.write(data)
+
+                    done = current_size == total_size
+                    progress = round((current_size / total_size) * 100, 2)
+                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
+
+                if done:
+                    file.seek(0)
+                    hashed = calculate_sha256(file)
+                    file.seek(0)
+
+                    url = f"{ollama_url}/api/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=file)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file_name,
+                        }
+                        os.remove(file_path)
+
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise "Ollama: Could not create blob, Please try again."
+
+
+# url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
+@app.post("/models/download")
+@app.post("/models/download/{url_idx}")
+async def download_model(
+    form_data: UrlForm,
+    url_idx: Optional[int] = None,
+):
+
+    if url_idx == None:
+        url_idx = 0
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    file_name = parse_huggingface_url(form_data.url)
+
+    if file_name:
+        file_path = f"{UPLOAD_DIR}/{file_name}"
+
+        return StreamingResponse(
+            download_file_stream(url, form_data.url, file_path, file_name)
+        )
+    else:
+        return None
+
+
+@app.post("/models/upload")
+@app.post("/models/upload/{url_idx}")
+def upload_model(file: UploadFile = File(...), url_idx: Optional[int] = None):
+    if url_idx == None:
+        url_idx = 0
+    ollama_url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    file_path = f"{UPLOAD_DIR}/{file.filename}"
+
+    # Save file in chunks
+    with open(file_path, "wb+") as f:
+        for chunk in file.file:
+            f.write(chunk)
+
+    def file_process_stream():
+        nonlocal ollama_url
+        total_size = os.path.getsize(file_path)
+        chunk_size = 1024 * 1024
+        try:
+            with open(file_path, "rb") as f:
+                total = 0
+                done = False
+
+                while not done:
+                    chunk = f.read(chunk_size)
+                    if not chunk:
+                        done = True
+                        continue
+
+                    total += len(chunk)
+                    progress = round((total / total_size) * 100, 2)
+
+                    res = {
+                        "progress": progress,
+                        "total": total_size,
+                        "completed": total,
+                    }
+                    yield f"data: {json.dumps(res)}\n\n"
+
+                if done:
+                    f.seek(0)
+                    hashed = calculate_sha256(f)
+                    f.seek(0)
+
+                    url = f"{ollama_url}/blobs/sha256:{hashed}"
+                    response = requests.post(url, data=f)
+
+                    if response.ok:
+                        res = {
+                            "done": done,
+                            "blob": f"sha256:{hashed}",
+                            "name": file.filename,
+                        }
+                        os.remove(file_path)
+                        yield f"data: {json.dumps(res)}\n\n"
+                    else:
+                        raise Exception(
+                            "Ollama: Could not create blob, Please try again."
+                        )
+
+        except Exception as e:
+            res = {"error": str(e)}
+            yield f"data: {json.dumps(res)}\n\n"
+
+    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
     url = app.state.OLLAMA_BASE_URLS[0]

+ 0 - 149
backend/apps/web/routers/utils.py

@@ -21,155 +21,6 @@ from constants import ERROR_MESSAGES
 router = APIRouter()
 
 
-class UploadBlobForm(BaseModel):
-    filename: str
-
-
-from urllib.parse import urlparse
-
-
-def parse_huggingface_url(hf_url):
-    try:
-        # Parse the URL
-        parsed_url = urlparse(hf_url)
-
-        # Get the path and split it into components
-        path_components = parsed_url.path.split("/")
-
-        # Extract the desired output
-        user_repo = "/".join(path_components[1:3])
-        model_file = path_components[-1]
-
-        return model_file
-    except ValueError:
-        return None
-
-
-async def download_file_stream(url, file_path, file_name, chunk_size=1024 * 1024):
-    done = False
-
-    if os.path.exists(file_path):
-        current_size = os.path.getsize(file_path)
-    else:
-        current_size = 0
-
-    headers = {"Range": f"bytes={current_size}-"} if current_size > 0 else {}
-
-    timeout = aiohttp.ClientTimeout(total=600)  # Set the timeout
-
-    async with aiohttp.ClientSession(timeout=timeout) as session:
-        async with session.get(url, headers=headers) as response:
-            total_size = int(response.headers.get("content-length", 0)) + current_size
-
-            with open(file_path, "ab+") as file:
-                async for data in response.content.iter_chunked(chunk_size):
-                    current_size += len(data)
-                    file.write(data)
-
-                    done = current_size == total_size
-                    progress = round((current_size / total_size) * 100, 2)
-                    yield f'data: {{"progress": {progress}, "completed": {current_size}, "total": {total_size}}}\n\n'
-
-                if done:
-                    file.seek(0)
-                    hashed = calculate_sha256(file)
-                    file.seek(0)
-
-                    url = f"{OLLAMA_BASE_URLS[0]}/api/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=file)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file_name,
-                        }
-                        os.remove(file_path)
-
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise "Ollama: Could not create blob, Please try again."
-
-
-@router.get("/download")
-async def download(
-    url: str,
-):
-    # url = "https://huggingface.co/TheBloke/stablelm-zephyr-3b-GGUF/resolve/main/stablelm-zephyr-3b.Q2_K.gguf"
-    file_name = parse_huggingface_url(url)
-
-    if file_name:
-        file_path = f"{UPLOAD_DIR}/{file_name}"
-
-        return StreamingResponse(
-            download_file_stream(url, file_path, file_name),
-            media_type="text/event-stream",
-        )
-    else:
-        return None
-
-
-@router.post("/upload")
-def upload(file: UploadFile = File(...)):
-    file_path = f"{UPLOAD_DIR}/{file.filename}"
-
-    # Save file in chunks
-    with open(file_path, "wb+") as f:
-        for chunk in file.file:
-            f.write(chunk)
-
-    def file_process_stream():
-        total_size = os.path.getsize(file_path)
-        chunk_size = 1024 * 1024
-        try:
-            with open(file_path, "rb") as f:
-                total = 0
-                done = False
-
-                while not done:
-                    chunk = f.read(chunk_size)
-                    if not chunk:
-                        done = True
-                        continue
-
-                    total += len(chunk)
-                    progress = round((total / total_size) * 100, 2)
-
-                    res = {
-                        "progress": progress,
-                        "total": total_size,
-                        "completed": total,
-                    }
-                    yield f"data: {json.dumps(res)}\n\n"
-
-                if done:
-                    f.seek(0)
-                    hashed = calculate_sha256(f)
-                    f.seek(0)
-
-                    url = f"{OLLAMA_BASE_URLS[0]}/blobs/sha256:{hashed}"
-                    response = requests.post(url, data=f)
-
-                    if response.ok:
-                        res = {
-                            "done": done,
-                            "blob": f"sha256:{hashed}",
-                            "name": file.filename,
-                        }
-                        os.remove(file_path)
-                        yield f"data: {json.dumps(res)}\n\n"
-                    else:
-                        raise Exception(
-                            "Ollama: Could not create blob, Please try again."
-                        )
-
-        except Exception as e:
-            res = {"error": str(e)}
-            yield f"data: {json.dumps(res)}\n\n"
-
-    return StreamingResponse(file_process_stream(), media_type="text/event-stream")
-
-
 @router.get("/gravatar")
 async def get_gravatar(
     email: str,

+ 65 - 0
src/lib/apis/ollama/index.ts

@@ -390,6 +390,71 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
 	return res;
 };
 
+export const downloadModel = async (
+	token: string,
+	download_url: string,
+	urlIdx: string | null = null
+) => {
+	let error = null;
+
+	const res = await fetch(
+		`${OLLAMA_API_BASE_URL}/models/download${urlIdx !== null ? `/${urlIdx}` : ''}`,
+		{
+			method: 'POST',
+			headers: {
+				Authorization: `Bearer ${token}`
+			},
+			body: JSON.stringify({
+				url: download_url
+			})
+		}
+	).catch((err) => {
+		console.log(err);
+		error = err;
+
+		if ('detail' in err) {
+			error = err.detail;
+		}
+
+		return null;
+	});
+	if (error) {
+		throw error;
+	}
+	return res;
+};
+
+export const uploadModel = async (token: string, file: File, urlIdx: string | null = null) => {
+	let error = null;
+
+	const formData = new FormData();
+	formData.append('file', file);
+
+	const res = await fetch(
+		`${OLLAMA_API_BASE_URL}/models/upload${urlIdx !== null ? `/${urlIdx}` : ''}`,
+		{
+			method: 'POST',
+			headers: {
+				Authorization: `Bearer ${token}`
+			},
+			body: formData
+		}
+	).catch((err) => {
+		console.log(err);
+		error = err;
+
+		if ('detail' in err) {
+			error = err.detail;
+		}
+
+		return null;
+	});
+	if (error) {
+		throw error;
+	}
+	return res;
+};
+
 // export const pullModel = async (token: string, tagName: string) => {
 // 	return await fetch(`${OLLAMA_API_BASE_URL}/pull`, {
 // 		method: 'POST',

+ 21 - 25
src/lib/components/chat/Settings/Models.svelte

@@ -5,9 +5,11 @@
 	import {
 		createModel,
 		deleteModel,
+		downloadModel,
 		getOllamaUrls,
 		getOllamaVersion,
-		pullModel
+		pullModel,
+		uploadModel
 	} from '$lib/apis/ollama';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 	import { WEBUI_NAME, models, user } from '$lib/stores';
@@ -60,7 +62,7 @@
 	let pullProgress = null;
 
 	let modelUploadMode = 'file';
-	let modelInputFile = '';
+	let modelInputFile: File[] | null = null;
 	let modelFileUrl = '';
 	let modelFileContent = `TEMPLATE """{{ .System }}\nUSER: {{ .Prompt }}\nASSISTANT: """\nPARAMETER num_ctx 4096\nPARAMETER stop "</s>"\nPARAMETER stop "USER:"\nPARAMETER stop "ASSISTANT:"`;
 	let modelFileDigest = '';
@@ -191,30 +193,23 @@
 		let name = '';
 
 		if (modelUploadMode === 'file') {
-			const file = modelInputFile[0];
-			const formData = new FormData();
-			formData.append('file', file);
-
-			fileResponse = await fetch(`${WEBUI_API_BASE_URL}/utils/upload`, {
-				method: 'POST',
-				headers: {
-					...($user && { Authorization: `Bearer ${localStorage.token}` })
-				},
-				body: formData
-			}).catch((error) => {
-				console.log(error);
-				return null;
-			});
+			const file = modelInputFile ? modelInputFile[0] : null;
+
+			if (file) {
+				fileResponse = uploadModel(localStorage.token, file, selectedOllamaUrlIdx).catch(
+					(error) => {
+						toast.error(error);
+						return null;
+					}
+				);
+			}
 		} else {
-			fileResponse = await fetch(`${WEBUI_API_BASE_URL}/utils/download?url=${modelFileUrl}`, {
-				method: 'GET',
-				headers: {
-					...($user && { Authorization: `Bearer ${localStorage.token}` })
+			fileResponse = downloadModel(localStorage.token, modelFileUrl, selectedOllamaUrlIdx).catch(
+				(error) => {
+					toast.error(error);
+					return null;
 				}
-			}).catch((error) => {
-				console.log(error);
-				return null;
-			});
+			);
 		}
 
 		if (fileResponse && fileResponse.ok) {
@@ -318,7 +313,8 @@
 		}
 
 		modelFileUrl = '';
-		modelInputFile = '';
+		modelUploadInputElement.value = '';
+		modelInputFile = null;
 		modelTransferring = false;
 		uploadProgress = null;