
Merge pull request #1272 from anuraagdjain/feat/cancel-model-download

feat: cancel model download
Timothy Jaeryang Baek · 1 year ago
commit a1faa307b1

+ 1 - 1
.gitignore

@@ -166,7 +166,7 @@ cython_debug/
 #  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 #  and can be added to the global gitignore or merged into this file.  For a more nuclear
 #  option (not recommended) you can uncomment the following to ignore the entire idea folder.
-#.idea/
+.idea/
 
 # Logs
 logs

+ 18 - 2
backend/apps/ollama/main.py

@@ -250,11 +250,26 @@ async def pull_model(
     def get_request():
         nonlocal url
         nonlocal r
+
+        request_id = str(uuid.uuid4())
         try:
+            REQUEST_POOL.append(request_id)
 
             def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
+                try:
+                    yield json.dumps({"id": request_id, "done": False}) + "\n"
+
+                    for chunk in r.iter_content(chunk_size=8192):
+                        if request_id in REQUEST_POOL:
+                            yield chunk
+                        else:
+                            print("User: canceled request")
+                            break
+                finally:
+                    if hasattr(r, "close"):
+                        r.close()
+                        if request_id in REQUEST_POOL:
+                            REQUEST_POOL.remove(request_id)
 
             r = requests.request(
                 method="POST",
@@ -275,6 +290,7 @@ async def pull_model(
 
     try:
         return await run_in_threadpool(get_request)
+
     except Exception as e:
         print(e)
         error_detail = "Open WebUI: Server Connection Error"

+ 1 - 1
src/lib/apis/ollama/index.ts

@@ -271,7 +271,7 @@ export const generateChatCompletion = async (token: string = '', body: object) =
 	return [res, controller];
 };
 
-export const cancelChatCompletion = async (token: string = '', requestId: string) => {
+export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
 	let error = null;
 
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
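The rename from cancelChatCompletion to cancelOllamaRequest reflects that the endpoint now cancels any pooled Ollama request, model pulls included, not only chat completions. The backend route behind /cancel/${requestId} is added outside the hunks shown here; it is expected to remove the id from REQUEST_POOL, which makes the streaming loop above stop at its next chunk. Call-site sketch, assuming requestId is the id captured from the first streamed line and reader is the client-side stream reader:

	await reader.cancel(); // stop consuming locally first
	await cancelOllamaRequest(localStorage.token, requestId);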

+ 90 - 14
src/lib/components/chat/Settings/Models.svelte

@@ -9,6 +9,7 @@
 		getOllamaUrls,
 		getOllamaVersion,
 		pullModel,
+		cancelOllamaRequest,
 		uploadModel
 	} from '$lib/apis/ollama';
 	import { WEBUI_API_BASE_URL, WEBUI_BASE_URL } from '$lib/constants';
@@ -163,7 +164,7 @@
 				// Remove the downloaded model
 				delete modelDownloadStatus[modelName];
 
-				console.log(data);
+				modelDownloadStatus = { ...modelDownloadStatus };
 
 				if (!data.success) {
 					toast.error(data.error);
@@ -372,12 +373,24 @@
 					for (const line of lines) {
 						if (line !== '') {
 							let data = JSON.parse(line);
+							console.log(data);
 							if (data.error) {
 								throw data.error;
 							}
 							if (data.detail) {
 								throw data.detail;
 							}
+
+							if (data.id) {
+								modelDownloadStatus[opts.modelName] = {
+									...modelDownloadStatus[opts.modelName],
+									requestId: data.id,
+									reader,
+									done: false
+								};
+								console.log(data);
+							}
+
 							if (data.status) {
 								if (data.digest) {
 									let downloadProgress = 0;
@@ -387,11 +400,17 @@
 										downloadProgress = 100;
 									}
 									modelDownloadStatus[opts.modelName] = {
+										...modelDownloadStatus[opts.modelName],
 										pullProgress: downloadProgress,
 										digest: data.digest
 									};
 								} else {
 									toast.success(data.status);
+
+									modelDownloadStatus[opts.modelName] = {
+										...modelDownloadStatus[opts.modelName],
+										done: data.status === 'success'
+									};
 								}
 							}
 						}
@@ -404,7 +423,14 @@
 					opts.callback({ success: false, error, modelName: opts.modelName });
 				}
 			}
-			opts.callback({ success: true, modelName: opts.modelName });
+
+			console.log(modelDownloadStatus[opts.modelName]);
+
+			if (modelDownloadStatus[opts.modelName].done) {
+				opts.callback({ success: true, modelName: opts.modelName });
+			} else {
+				opts.callback({ success: false, error: 'Download canceled', modelName: opts.modelName });
+			}
 		}
 	};
 
@@ -474,6 +500,18 @@
 		ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => false);
 		liteLLMModelInfo = await getLiteLLMModelInfo(localStorage.token);
 	});
+
+	const cancelModelPullHandler = async (model: string) => {
+		const { reader, requestId } = modelDownloadStatus[model];
+		if (reader) {
+			await reader.cancel();
+
+			await cancelOllamaRequest(localStorage.token, requestId);
+			delete modelDownloadStatus[model];
+			await deleteModel(localStorage.token, model);
+			toast.success(`${model} download has been canceled`);
+		}
+	};
 </script>
 
 <div class="flex flex-col h-full justify-between text-sm">
@@ -604,20 +642,58 @@
 
 						{#if Object.keys(modelDownloadStatus).length > 0}
 							{#each Object.keys(modelDownloadStatus) as model}
-								<div class="flex flex-col">
-									<div class="font-medium mb-1">{model}</div>
-									<div class="">
-										<div
-											class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
-											style="width: {Math.max(15, modelDownloadStatus[model].pullProgress ?? 0)}%"
-										>
-											{modelDownloadStatus[model].pullProgress ?? 0}%
-										</div>
-										<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
-											{modelDownloadStatus[model].digest}
+								{#if 'pullProgress' in modelDownloadStatus[model]}
+									<div class="flex flex-col">
+										<div class="font-medium mb-1">{model}</div>
+										<div class="">
+											<div class="flex flex-row justify-between space-x-4 pr-2">
+												<div class=" flex-1">
+													<div
+														class="dark:bg-gray-600 bg-gray-500 text-xs font-medium text-gray-100 text-center p-0.5 leading-none rounded-full"
+														style="width: {Math.max(
+															15,
+															modelDownloadStatus[model].pullProgress ?? 0
+														)}%"
+													>
+														{modelDownloadStatus[model].pullProgress ?? 0}%
+													</div>
+												</div>
+
+												<Tooltip content="Cancel">
+													<button
+														class="text-gray-800 dark:text-gray-100"
+														on:click={() => {
+															cancelModelPullHandler(model);
+														}}
+													>
+														<svg
+															class="w-4 h-4 text-gray-800 dark:text-white"
+															aria-hidden="true"
+															xmlns="http://www.w3.org/2000/svg"
+															width="24"
+															height="24"
+															fill="currentColor"
+															viewBox="0 0 24 24"
+														>
+															<path
+																stroke="currentColor"
+																stroke-linecap="round"
+																stroke-linejoin="round"
+																stroke-width="2"
+																d="M6 18 17.94 6M18 18 6.06 6"
+															/>
+														</svg>
+													</button>
+												</Tooltip>
+											</div>
+											{#if 'digest' in modelDownloadStatus[model]}
+												<div class="mt-1 text-xs dark:text-gray-500" style="font-size: 0.5rem;">
+													{modelDownloadStatus[model].digest}
+												</div>
+											{/if}
 										</div>
 									</div>
-								</div>
+								{/if}
 							{/each}
 						{/if}
 					</div>
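Taken together, the Models.svelte hunks grow each modelDownloadStatus entry from a bare progress record into a cancelable handle. An illustrative TypeScript shape (the app keeps these entries untyped; the type name is hypothetical):

	type ModelDownloadEntry = {
		requestId?: string; // id from the first streamed line, passed to cancelOllamaRequest
		reader?: ReadableStreamDefaultReader<string>; // stream reader, canceled client-side first
		done?: boolean; // set when Ollama reports status 'success'
		pullProgress?: number; // 0-100, derived from completed/total
		digest?: string; // digest of the layer currently downloading
	};

The done flag is what lets the post-loop check in the pull handler distinguish a completed download from a canceled one.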

+ 3 - 3
src/routes/(app)/+page.svelte

@@ -19,7 +19,7 @@
 	} from '$lib/stores';
 	import { copyToClipboard, splitStream } from '$lib/utils';
 
-	import { generateChatCompletion, cancelChatCompletion, generateTitle } from '$lib/apis/ollama';
+	import { generateChatCompletion, cancelOllamaRequest, generateTitle } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -104,7 +104,7 @@
 
 	const initNewChat = async () => {
 		if (currentRequestId !== null) {
-			await cancelChatCompletion(localStorage.token, currentRequestId);
+			await cancelOllamaRequest(localStorage.token, currentRequestId);
 			currentRequestId = null;
 		}
 		window.history.replaceState(history.state, '', `/`);
@@ -372,7 +372,7 @@
 
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;

+ 3 - 3
src/routes/(app)/c/[id]/+page.svelte

@@ -19,7 +19,7 @@
 	} from '$lib/stores';
 	import { copyToClipboard, splitStream, convertMessagesToHistory } from '$lib/utils';
 
-	import { generateChatCompletion, generateTitle, cancelChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion, generateTitle, cancelOllamaRequest } from '$lib/apis/ollama';
 	import {
 		addTagById,
 		createNewChat,
@@ -382,7 +382,7 @@
 
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;
@@ -843,7 +843,7 @@
 			shareEnabled={messages.length > 0}
 			initNewChat={async () => {
 				if (currentRequestId !== null) {
-					await cancelChatCompletion(localStorage.token, currentRequestId);
+					await cancelOllamaRequest(localStorage.token, currentRequestId);
 					currentRequestId = null;
 				}
 

+ 4 - 4
src/routes/(app)/playground/+page.svelte

@@ -13,7 +13,7 @@
 	} from '$lib/constants';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 
-	import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama';
+	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	import { splitStream } from '$lib/utils';
@@ -52,7 +52,7 @@
 
 	// const cancelHandler = async () => {
 	// 	if (currentRequestId) {
-	// 		const res = await cancelChatCompletion(localStorage.token, currentRequestId);
+	// 		const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
 	// 		currentRequestId = null;
 	// 		loading = false;
 	// 	}
@@ -95,7 +95,7 @@
 				const { value, done } = await reader.read();
 				if (done || stopResponseFlag) {
 					if (stopResponseFlag) {
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;
@@ -181,7 +181,7 @@
 				const { value, done } = await reader.read();
 				if (done || stopResponseFlag) {
 					if (stopResponseFlag) {
-						await cancelChatCompletion(localStorage.token, currentRequestId);
+						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					}
 
 					currentRequestId = null;