瀏覽代碼

feat: parallel model downloads

Anuraag Jain 1 年之前
父節點
當前提交
ea721feea9
共有 3 個文件被更改,包括 79 次插入13 次删除
  1. 11 0
      package-lock.json
  2. 1 0
      package.json
  3. 67 13
      src/lib/components/chat/SettingsModal.svelte

+ 11 - 0
package-lock.json

@@ -9,6 +9,7 @@
 			"version": "0.0.1",
 			"dependencies": {
 				"@sveltejs/adapter-node": "^1.3.1",
+				"async": "^3.2.5",
 				"file-saver": "^2.0.5",
 				"highlight.js": "^11.9.0",
 				"idb": "^7.1.1",
@@ -1208,6 +1209,11 @@
 				"node": ">=8"
 			}
 		},
+		"node_modules/async": {
+			"version": "3.2.5",
+			"resolved": "https://registry.npmjs.org/async/-/async-3.2.5.tgz",
+			"integrity": "sha512-baNZyqaaLhyLVKm/DlvdW051MSgO6b8eVfIezl9E5PqWxFgzLm/wQntEW4zOytVburDEr0JlALEpdOFwvErLsg=="
+		},
 		"node_modules/autoprefixer": {
 			"version": "10.4.16",
 			"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz",
@@ -4645,6 +4651,11 @@
 			"integrity": "sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==",
 			"dev": true
 		},
+		"async": {
+			"version": "3.2.5",
+			"resolved": "https://registry.npmjs.org/async/-/async-3.2.5.tgz",
+			"integrity": "sha512-baNZyqaaLhyLVKm/DlvdW051MSgO6b8eVfIezl9E5PqWxFgzLm/wQntEW4zOytVburDEr0JlALEpdOFwvErLsg=="
+		},
 		"autoprefixer": {
 			"version": "10.4.16",
 			"resolved": "https://registry.npmjs.org/autoprefixer/-/autoprefixer-10.4.16.tgz",

+ 1 - 0
package.json

@@ -39,6 +39,7 @@
 	"type": "module",
 	"dependencies": {
 		"@sveltejs/adapter-node": "^1.3.1",
+		"async": "^3.2.5",
 		"file-saver": "^2.0.5",
 		"highlight.js": "^11.9.0",
 		"idb": "^7.1.1",

+ 67 - 13
src/lib/components/chat/SettingsModal.svelte

@@ -6,6 +6,7 @@
 	import { onMount } from 'svelte';
 	import { config, models, settings, user, chats } from '$lib/stores';
 	import { splitStream, getGravatarURL } from '$lib/utils';
+	import queue from 'async/queue';
 
 	import { getOllamaVersion } from '$lib/apis/ollama';
 	import { createNewChat, deleteAllChats, getAllChats, getChatList } from '$lib/apis/chats';
@@ -38,6 +39,8 @@
 	let theme = 'dark';
 	let notificationEnabled = false;
 	let system = '';
+	const modelDownloadQueue = queue((task:{modelName: string}, cb) => pullModelHandlerProcessor({modelName: task.modelName, callback: cb}), 3);
+	let modelDownloadStatus: Record<string, any> = {};
 
 	// Advanced
 	let requestFormat = '';
@@ -224,8 +227,9 @@
 		authEnabled = !authEnabled;
 	};
 
-	const pullModelHandler = async () => {
-		modelTransferring = true;
+	const pullModelHandlerProcessor = async (opts:{modelName:string, callback: Function}) => {
+		console.log('Pull model name', opts.modelName);
+		
 		const res = await fetch(`${API_BASE_URL}/pull`, {
 			method: 'POST',
 			headers: {
@@ -234,7 +238,7 @@
 				...($user && { Authorization: `Bearer ${localStorage.token}` })
 			},
 			body: JSON.stringify({
-				name: modelTag
+				name: opts.modelName
 			})
 		});
 
@@ -265,11 +269,9 @@
 						}
 						if (data.status) {
 							if (!data.digest) {
-								toast.success(data.status);
-
 								if (data.status === 'success') {
 									const notification = new Notification(`Ollama`, {
-										body: `Model '${modelTag}' has been successfully downloaded.`,
+										body: `Model '${opts.modelName}' has been successfully downloaded.`,
 										icon: '/favicon.png'
 									});
 								}
@@ -280,21 +282,48 @@
 								} else {
 									pullProgress = 100;
 								}
+								modelDownloadStatus[opts.modelName] = {pullProgress};
 							}
 						}
 					}
 				}
 			} catch (error) {
-				console.log(error);
-				toast.error(error);
+				console.error(error);
+				opts.callback({success:false, error, modelName: opts.modelName});
 			}
 		}
+		opts.callback({success: true, modelName: opts.modelName});
+	};
+
+	const pullModelHandler = async() => {
+		if(modelDownloadStatus[modelTag]){
+			toast.error("Model already in queue for downloading.");
+			return;
+		}
+		if(Object.keys(modelDownloadStatus).length === 3){
+			toast.error('Maximum of 3 models can be downloading simultaneously. Please try again later');
+			return;
+		}
+		modelTransferring = true;
+
+		modelDownloadQueue.push({modelName: modelTag},async (data:{modelName: string; success: boolean; error?: Error}) => {
+			const {modelName} = data;
+			// Remove the downloaded model
+			delete modelDownloadStatus[modelName];
+
+			if(!data.success){
+				toast.error(`There was some issue in downloading the model ${modelName}`);
+				return;
+			}
+		
+			toast.success(`Model ${modelName} was successfully downloaded`);
+			models.set(await getModels());
+		});
 
 		modelTag = '';
-		modelTransferring = false;
+		modelTransferring = false;	
+	}
 
-		models.set(await getModels());
-	};
 
 	const calculateSHA256 = async (file) => {
 		console.log(file);
@@ -1248,7 +1277,7 @@
 									>
 								</div>
 
-								{#if pullProgress !== null}
+								<!-- {#if pullProgress !== null}
 									<div class="mt-2">
 										<div class=" mb-2 text-xs">Pull Progress</div>
 										<div class="w-full rounded-full dark:bg-gray-800">
@@ -1263,8 +1292,33 @@
 											{digest}
 										</div>
 									</div>
-								{/if}
+								{/if} -->
 							</div>
+							{#if Object.keys(modelDownloadStatus).length > 0}
+							<table class="w-full text-sm text-left text-gray-500 dark:text-gray-400">
+								<thead
+									class="text-xs text-gray-700 uppercase bg-gray-50 dark:bg-gray-700 dark:text-gray-400"
+								>
+									<tr>
+										<th scope="col" class="px-6 py-3"> Model Name </th>
+										<th scope="col" class="px-6 py-3"> Download progress </th>
+									</tr>
+								</thead>
+								<tbody>
+									{#each Object.entries(modelDownloadStatus) as [modelName, payload]}
+									<tr class="bg-white border-b dark:bg-gray-800 dark:border-gray-700">
+										<td class="px-6 py-4">{modelName}</td>
+										<td class="px-6 py-4"><div
+											class="dark:bg-gray-600 text-xs font-medium text-blue-100 text-center p-0.5 leading-none rounded-full"
+											style="width: {Math.max(15, payload.pullProgress ?? 0)}%"
+										>
+											{ payload.pullProgress ?? 0}%
+										</div></td>
+									</tr>
+									{/each}
+									</tbody>
+									</table>
+									{/if}
 							<hr class=" dark:border-gray-700" />
 
 							<div>