瀏覽代碼

refactor: cancel download

Timothy J. Baek 1 年之前
父節點
當前提交
244f34c24e

+ 1 - 1
src/lib/apis/ollama/index.ts

@@ -271,7 +271,7 @@ export const generateChatCompletion = async (token: string = '', body: object) =
 	return [res, controller];
 };
 
-export const cancelChatCompletion = async (token: string = '', requestId: string) => {
+export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
 	let error = null;
 
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {

+ 82 - 42
src/lib/components/chat/Settings/Models.svelte

@@ -7,7 +7,8 @@
 		deleteModel,
 		getOllamaUrls,
 		getOllamaVersion,
-		pullModel
+		pullModel,
+		cancelOllamaRequest
 	} from '$lib/apis/ollama';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
 	import { WEBUI_NAME, models, user } from '$lib/stores';
@@ -364,12 +365,24 @@
 					for (const line of lines) {
 						if (line !== '') {
 							let data = JSON.parse(line);
+							console.log(data);
 							if (data.error) {
 								throw data.error;
 							}
 							if (data.detail) {
 								throw data.detail;
 							}
+
+							if (data.id) {
+								modelDownloadStatus[opts.modelName] = {
+									...modelDownloadStatus[opts.modelName],
+									requestId: data.id,
+									reader,
+									done: false
+								};
+								console.log(data);
+							}
+
 							if (data.status) {
 								if (data.digest) {
 									let downloadProgress = 0;
@@ -379,12 +392,17 @@
 										downloadProgress = 100;
 									}
 									modelDownloadStatus[opts.modelName] = {
-										reader,
+										...modelDownloadStatus[opts.modelName],
 										pullProgress: downloadProgress,
 										digest: data.digest
 									};
 								} else {
 									toast.success(data.status);
+
+									modelDownloadStatus[opts.modelName] = {
+										...modelDownloadStatus[opts.modelName],
+										done: data.status === 'success'
+									};
 								}
 							}
 						}
@@ -397,7 +415,14 @@
 					opts.callback({ success: false, error, modelName: opts.modelName });
 				}
 			}
-			opts.callback({ success: true, modelName: opts.modelName });
+
+			console.log(modelDownloadStatus[opts.modelName]);
+
+			if (modelDownloadStatus[opts.modelName].done) {
+				opts.callback({ success: true, modelName: opts.modelName });
+			} else {
+				opts.callback({ success: false, error: 'Download canceled', modelName: opts.modelName });
+			}
 		}
 	};
 
@@ -467,10 +492,13 @@
 		ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 	});
-	const deleteModelPull = async (model: string) => {
-		const { reader } = modelDownloadStatus[model];
+
+	const cancelModelPullHandler = async (model: string) => {
+		const { reader, requestId } = modelDownloadStatus[model];
 		if (reader) {
 			await reader.cancel();
+
+			await cancelOllamaRequest(localStorage.token, requestId);
 			delete modelDownloadStatus[model];
 			await deleteModel(localStorage.token, model);
 			toast.success(`${model} download has been canceled`);
@@ -606,46 +634,58 @@
 
 						{#if Object.keys(modelDownloadStatus).length > 0}
 							{#each Object.keys(modelDownloadStatus) as model}
-								<div class="flex flex-col">
-									<div class="font-medium mb-1">{model}</div>
-									<div class="">
-										<div class="flex flex-row space-x-4 pr-2">
-											<div
-												class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
-												style="width: {Math.max(15, modelDownloadStatus[model].pullProgress ?? 0)}%"
-											>
-												{modelDownloadStatus[model].pullProgress ?? 0}%
+								{#if 'pullProgress' in modelDownloadStatus[model]}
+									<div class="flex flex-col">
+										<div class="font-medium mb-1">{model}</div>
+										<div class="">
+											<div class="flex flex-row justify-between space-x-4 pr-2">
+												<div class=" flex-1">
+													<div
+														class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
+														style="width: {Math.max(
+															15,
+															modelDownloadStatus[model].pullProgress ?? 0
+														)}%"
+													>
+														{modelDownloadStatus[model].pullProgress ?? 0}%
+													</div>
+												</div>
+
+												<Tooltip content="Cancel">
+													<button
+														class="text-gray-800 dark:text-gray-100"
+														on:click={() => {
+															cancelModelPullHandler(model);
+														}}
+													>
+														<svg
+															class="w-4 h-4 text-gray-800 dark:text-white"
+															aria-hidden="true"
+															xmlns="http://www.w3.org/2000/svg"
+															width="24"
+															height="24"
+															fill="currentColor"
+															viewBox="0 0 24 24"
+														>
+															<path
+																stroke="currentColor"
+																stroke-linecap="round"
+																stroke-linejoin="round"
+																stroke-width="2"
+																d="M6 18 17.94 6M18 18 6.06 6"
+															/>
+														</svg>
+													</button>
+												</Tooltip>
 											</div>
-											<button
-												class="text-gray-800 dark:text-gray-100"
-												on:click={() => {
-													deleteModelPull(model);
-												}}
-											>
-												<svg
-													class="w-4 h-4 text-gray-800 dark:text-white"
-													aria-hidden="true"
-													xmlns="http://www.w3.org/2000/svg"
-													width="24"
-													height="24"
-													fill="currentColor"
-													viewBox="0 0 24 24"
-												>
-													<path
-														stroke="currentColor"
-														stroke-linecap="round"
-														stroke-linejoin="round"
-														stroke-width="2"
-														d="M6 18 17.94 6M18 18 6.06 6"
-													/>
-												</svg>
-											</button>
-										</div>
-										<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
-											{modelDownloadStatus[model].digest}
+											{#if 'digest' in modelDownloadStatus[model]}
+												<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
+													{modelDownloadStatus[model].digest}
+												</div>
+											{/if}
 										</div>
 									</div>
-								</div>
+								{/if}
 							{/each}
 						{/if}
 					</div>

+ 3 - 3
src/routes/(app)/+page.svelte

@@ -19,7 +19,7 @@
 	} from '$lib/stores';
 	import { copyToClipboard, splitStream } from '$lib/utils';
 
-	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
+	import { generateChatCompletion, cancelOllamaRequest, generateTitle } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -104,7 +104,7 @@
 
 	const initNewChat = async () => {
 		if (currentRequestId !== null) {
-			await cancelChatCompletion(localStorage.token, currentRequestId);
+			await cancelOllamaRequest(localStorage.token, currentRequestId);
 			currentRequestId = null;
 		}
 		window.history.replaceState(history.state, '', `/`);
@@ -372,7 +372,7 @@
 
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;

+ 3 - 3
src/routes/(app)/c/[id]/+page.svelte

@@ -19,7 +19,7 @@
 	} from '$lib/stores';
 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
 
-	import { generateChatCompletion, generateTitle, cancelChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion, generateTitle, cancelOllamaRequest } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -382,7 +382,7 @@
 
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;
@@ -843,7 +843,7 @@
 			shareEnabled={messages.length > 0}
 			initNewChat={async () => {
 				if (currentRequestId !== null) {
-					await cancelChatCompletion(localStorage.token, currentRequestId);
+					await cancelOllamaRequest(localStorage.token, currentRequestId);
 					currentRequestId = null;
 				}
 

+ 4 - 4
src/routes/(app)/playground/+page.svelte

@@ -13,7 +13,7 @@
 	} from '$lib/constants';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 
-	import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama';
+	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	import { splitStream } from '$lib/utils';
@@ -52,7 +52,7 @@
 
 	// const cancelHandler = async () => {
 	// 	if (currentRequestId) {
-	// 		const res = await cancelChatCompletion(localStorage.token, currentRequestId);
+	// 		const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
 	// 		currentRequestId = null;
 	// 		loading = false;
 	// 	}
@@ -95,7 +95,7 @@
 				const { value, done } = await reader.read();
 				if (done || stopResponseFlag) {
 					if (stopResponseFlag) {
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;
@@ -181,7 +181,7 @@
 				const { value, done } = await reader.read();
 				if (done || stopResponseFlag) {
 					if (stopResponseFlag) {
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;