Browse Source

fix: ollama streaming cancellation using aiohttp

Jun Siang Cheah 11 months ago
parent
commit
4dd51badfe

+ 53 - 346
backend/apps/ollama/main.py

@@ -29,6 +29,8 @@ import time
 from urllib.parse import urlparse
 from urllib.parse import urlparse
 from typing import Optional, List, Union
 from typing import Optional, List, Union
 
 
+from starlette.background import BackgroundTask
+
 from apps.webui.models.models import Models
 from apps.webui.models.models import Models
 from apps.webui.models.users import Users
 from apps.webui.models.users import Users
 from constants import ERROR_MESSAGES
 from constants import ERROR_MESSAGES
@@ -75,9 +77,6 @@ app.state.config.OLLAMA_BASE_URLS = OLLAMA_BASE_URLS
 app.state.MODELS = {}
 app.state.MODELS = {}
 
 
 
 
-REQUEST_POOL = []
-
-
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # TODO: Implement a more intelligent load balancing mechanism for distributing requests among multiple backend instances.
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # Current implementation uses a simple round-robin approach (random.choice). Consider incorporating algorithms like weighted round-robin,
 # least connections, or least response time for better resource utilization and performance optimization.
 # least connections, or least response time for better resource utilization and performance optimization.
@@ -132,16 +131,6 @@ async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin
     return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
     return {"OLLAMA_BASE_URLS": app.state.config.OLLAMA_BASE_URLS}
 
 
 
 
-@app.get("/cancel/{request_id}")
-async def cancel_ollama_request(request_id: str, user=Depends(get_current_user)):
-    if user:
-        if request_id in REQUEST_POOL:
-            REQUEST_POOL.remove(request_id)
-        return True
-    else:
-        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
-
-
 async def fetch_url(url):
 async def fetch_url(url):
     timeout = aiohttp.ClientTimeout(total=5)
     timeout = aiohttp.ClientTimeout(total=5)
     try:
     try:
@@ -154,6 +143,45 @@ async def fetch_url(url):
         return None
         return None
 
 
 
 
+async def cleanup_response(
+    response: Optional[aiohttp.ClientResponse],
+    session: Optional[aiohttp.ClientSession],
+):
+    if response:
+        response.close()
+    if session:
+        await session.close()
+
+
+async def post_streaming_url(url, payload):
+    r = None
+    try:
+        session = aiohttp.ClientSession()
+        r = await session.post(url, data=payload)
+        r.raise_for_status()
+
+        return StreamingResponse(
+            r.content,
+            status_code=r.status,
+            headers=dict(r.headers),
+            background=BackgroundTask(cleanup_response, response=r, session=session),
+        )
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = await r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status if r else 500,
+            detail=error_detail,
+        )
+
+
 def merge_models_lists(model_lists):
 def merge_models_lists(model_lists):
     merged_models = {}
     merged_models = {}
 
 
@@ -313,65 +341,7 @@ async def pull_model(
     # Admin should be able to pull models from any source
     # Admin should be able to pull models from any source
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
     payload = {**form_data.model_dump(exclude_none=True), "insecure": True}
 
 
-    def get_request():
-        nonlocal url
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/pull",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/pull", json.dumps(payload))
 
 
 
 
 class PushModelForm(BaseModel):
 class PushModelForm(BaseModel):
@@ -399,50 +369,9 @@ async def push_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.debug(f"url: {url}")
     log.debug(f"url: {url}")
 
 
-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/push",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/push", form_data.model_dump_json(exclude_none=True).encode()
+    )
 
 
 
 
 class CreateModelForm(BaseModel):
 class CreateModelForm(BaseModel):
@@ -461,53 +390,9 @@ async def create_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    r = None
-
-    def get_request():
-        nonlocal url
-        nonlocal r
-        try:
-
-            def stream_content():
-                for chunk in r.iter_content(chunk_size=8192):
-                    yield chunk
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/create",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            log.debug(f"r: {r}")
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        log.exception(e)
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/create", form_data.model_dump_json(exclude_none=True).encode()
+    )
 
 
 
 
 class CopyModelForm(BaseModel):
 class CopyModelForm(BaseModel):
@@ -797,66 +682,9 @@ async def generate_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    r = None
-
-    def get_request():
-        nonlocal form_data
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if form_data.stream:
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/generate",
-                data=form_data.model_dump_json(exclude_none=True).encode(),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(
+        f"{url}/api/generate", form_data.model_dump_json(exclude_none=True).encode()
+    )
 
 
 
 
 class ChatMessage(BaseModel):
 class ChatMessage(BaseModel):
@@ -981,67 +809,7 @@ async def generate_chat_completion(
 
 
     print(payload)
     print(payload)
 
 
-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream", True):
-                        yield json.dumps({"id": request_id, "done": False}) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/api/chat",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            log.exception(e)
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/api/chat", json.dumps(payload))
 
 
 
 
 # TODO: we should update this part once Ollama supports other types
 # TODO: we should update this part once Ollama supports other types
@@ -1132,68 +900,7 @@ async def generate_openai_chat_completion(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    r = None
-
-    def get_request():
-        nonlocal payload
-        nonlocal r
-
-        request_id = str(uuid.uuid4())
-        try:
-            REQUEST_POOL.append(request_id)
-
-            def stream_content():
-                try:
-                    if payload.get("stream"):
-                        yield json.dumps(
-                            {"request_id": request_id, "done": False}
-                        ) + "\n"
-
-                    for chunk in r.iter_content(chunk_size=8192):
-                        if request_id in REQUEST_POOL:
-                            yield chunk
-                        else:
-                            log.warning("User: canceled request")
-                            break
-                finally:
-                    if hasattr(r, "close"):
-                        r.close()
-                        if request_id in REQUEST_POOL:
-                            REQUEST_POOL.remove(request_id)
-
-            r = requests.request(
-                method="POST",
-                url=f"{url}/v1/chat/completions",
-                data=json.dumps(payload),
-                stream=True,
-            )
-
-            r.raise_for_status()
-
-            return StreamingResponse(
-                stream_content(),
-                status_code=r.status_code,
-                headers=dict(r.headers),
-            )
-        except Exception as e:
-            raise e
-
-    try:
-        return await run_in_threadpool(get_request)
-    except Exception as e:
-        error_detail = "Open WebUI: Server Connection Error"
-        if r is not None:
-            try:
-                res = r.json()
-                if "error" in res:
-                    error_detail = f"Ollama: {res['error']}"
-            except:
-                error_detail = f"Ollama: {e}"
-
-        raise HTTPException(
-            status_code=r.status_code if r else 500,
-            detail=error_detail,
-        )
+    return await post_streaming_url(f"{url}/v1/chat/completions", json.dumps(payload))
 
 
 
 
 @app.get("/v1/models")
 @app.get("/v1/models")

+ 3 - 22
src/lib/apis/ollama/index.ts

@@ -369,27 +369,6 @@ export const generateChatCompletion = async (token: string = '', body: object) =
 	return [res, controller];
 	return [res, controller];
 };
 };
 
 
-export const cancelOllamaRequest = async (token: string = '', requestId: string) => {
-	let error = null;
-
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/cancel/${requestId}`, {
-		method: 'GET',
-		headers: {
-			'Content-Type': 'text/event-stream',
-			Authorization: `Bearer ${token}`
-		}
-	}).catch((err) => {
-		error = err;
-		return null;
-	});
-
-	if (error) {
-		throw error;
-	}
-
-	return res;
-};
-
 export const createModel = async (token: string, tagName: string, content: string) => {
 export const createModel = async (token: string, tagName: string, content: string) => {
 	let error = null;
 	let error = null;
 
 
@@ -461,8 +440,10 @@ export const deleteModel = async (token: string, tagName: string, urlIdx: string
 
 
 export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
 export const pullModel = async (token: string, tagName: string, urlIdx: string | null = null) => {
 	let error = null;
 	let error = null;
+	const controller = new AbortController();
 
 
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull${urlIdx !== null ? `/${urlIdx}` : ''}`, {
+		signal: controller.signal,
 		method: 'POST',
 		method: 'POST',
 		headers: {
 		headers: {
 			Accept: 'application/json',
 			Accept: 'application/json',
@@ -485,7 +466,7 @@ export const pullModel = async (token: string, tagName: string, urlIdx: string |
 	if (error) {
 	if (error) {
 		throw error;
 		throw error;
 	}
 	}
-	return res;
+	return [res, controller];
 };
 };
 
 
 export const downloadModel = async (
 export const downloadModel = async (

+ 50 - 63
src/lib/components/chat/Chat.svelte

@@ -26,7 +26,7 @@
 		splitStream
 		splitStream
 	} from '$lib/utils';
 	} from '$lib/utils';
 
 
-	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion } from '$lib/apis/ollama';
 	import {
 	import {
 		addTagById,
 		addTagById,
 		createNewChat,
 		createNewChat,
@@ -65,7 +65,6 @@
 	let autoScroll = true;
 	let autoScroll = true;
 	let processing = '';
 	let processing = '';
 	let messagesContainerElement: HTMLDivElement;
 	let messagesContainerElement: HTMLDivElement;
-	let currentRequestId = null;
 
 
 	let showModelSelector = true;
 	let showModelSelector = true;
 
 
@@ -130,10 +129,6 @@
 	//////////////////////////
 	//////////////////////////
 
 
 	const initNewChat = async () => {
 	const initNewChat = async () => {
-		if (currentRequestId !== null) {
-			await cancelOllamaRequest(localStorage.token, currentRequestId);
-			currentRequestId = null;
-		}
 		window.history.replaceState(history.state, '', `/`);
 		window.history.replaceState(history.state, '', `/`);
 		await chatId.set('');
 		await chatId.set('');
 
 
@@ -616,7 +611,6 @@
 
 
 					if (stopResponseFlag) {
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
 						controller.abort('User: Stop Response');
-						await cancelOllamaRequest(localStorage.token, currentRequestId);
 					} else {
 					} else {
 						const messages = createMessagesList(responseMessageId);
 						const messages = createMessagesList(responseMessageId);
 						const res = await chatCompleted(localStorage.token, {
 						const res = await chatCompleted(localStorage.token, {
@@ -647,8 +641,6 @@
 						}
 						}
 					}
 					}
 
 
-					currentRequestId = null;
-
 					break;
 					break;
 				}
 				}
 
 
@@ -669,63 +661,58 @@
 								throw data;
 								throw data;
 							}
 							}
 
 
-							if ('id' in data) {
-								console.log(data);
-								currentRequestId = data.id;
-							} else {
-								if (data.done == false) {
-									if (responseMessage.content == '' && data.message.content == '\n') {
-										continue;
-									} else {
-										responseMessage.content += data.message.content;
-										messages = messages;
-									}
+							if (data.done == false) {
+								if (responseMessage.content == '' && data.message.content == '\n') {
+									continue;
 								} else {
 								} else {
-									responseMessage.done = true;
-
-									if (responseMessage.content == '') {
-										responseMessage.error = {
-											code: 400,
-											content: `Oops! No text generated from Ollama, Please try again.`
-										};
-									}
-
-									responseMessage.context = data.context ?? null;
-									responseMessage.info = {
-										total_duration: data.total_duration,
-										load_duration: data.load_duration,
-										sample_count: data.sample_count,
-										sample_duration: data.sample_duration,
-										prompt_eval_count: data.prompt_eval_count,
-										prompt_eval_duration: data.prompt_eval_duration,
-										eval_count: data.eval_count,
-										eval_duration: data.eval_duration
-									};
+									responseMessage.content += data.message.content;
 									messages = messages;
 									messages = messages;
+								}
+							} else {
+								responseMessage.done = true;
 
 
-									if ($settings.notificationEnabled && !document.hasFocus()) {
-										const notification = new Notification(
-											selectedModelfile
-												? `${
-														selectedModelfile.title.charAt(0).toUpperCase() +
-														selectedModelfile.title.slice(1)
-												  }`
-												: `${model}`,
-											{
-												body: responseMessage.content,
-												icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
-											}
-										);
-									}
-
-									if ($settings.responseAutoCopy) {
-										copyToClipboard(responseMessage.content);
-									}
-
-									if ($settings.responseAutoPlayback) {
-										await tick();
-										document.getElementById(`speak-button-${responseMessage.id}`)?.click();
-									}
+								if (responseMessage.content == '') {
+									responseMessage.error = {
+										code: 400,
+										content: `Oops! No text generated from Ollama, Please try again.`
+									};
+								}
+
+								responseMessage.context = data.context ?? null;
+								responseMessage.info = {
+									total_duration: data.total_duration,
+									load_duration: data.load_duration,
+									sample_count: data.sample_count,
+									sample_duration: data.sample_duration,
+									prompt_eval_count: data.prompt_eval_count,
+									prompt_eval_duration: data.prompt_eval_duration,
+									eval_count: data.eval_count,
+									eval_duration: data.eval_duration
+								};
+								messages = messages;
+
+								if ($settings.notificationEnabled && !document.hasFocus()) {
+									const notification = new Notification(
+										selectedModelfile
+											? `${
+													selectedModelfile.title.charAt(0).toUpperCase() +
+													selectedModelfile.title.slice(1)
+											  }`
+											: `${model}`,
+										{
+											body: responseMessage.content,
+											icon: selectedModelfile?.imageUrl ?? `${WEBUI_BASE_URL}/static/favicon.png`
+										}
+									);
+								}
+
+								if ($settings.responseAutoCopy) {
+									copyToClipboard(responseMessage.content);
+								}
+
+								if ($settings.responseAutoPlayback) {
+									await tick();
+									document.getElementById(`speak-button-${responseMessage.id}`)?.click();
 								}
 								}
 							}
 							}
 						}
 						}

+ 21 - 21
src/lib/components/chat/ModelSelector/Selector.svelte

@@ -8,7 +8,7 @@
 	import Check from '$lib/components/icons/Check.svelte';
 	import Check from '$lib/components/icons/Check.svelte';
 	import Search from '$lib/components/icons/Search.svelte';
 	import Search from '$lib/components/icons/Search.svelte';
 
 
-	import { cancelOllamaRequest, deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
+	import { deleteModel, getOllamaVersion, pullModel } from '$lib/apis/ollama';
 
 
 	import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
 	import { user, MODEL_DOWNLOAD_POOL, models, mobile } from '$lib/stores';
 	import { toast } from 'svelte-sonner';
 	import { toast } from 'svelte-sonner';
@@ -72,10 +72,12 @@
 			return;
 			return;
 		}
 		}
 
 
-		const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-			toast.error(error);
-			return null;
-		});
+		const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+			(error) => {
+				toast.error(error);
+				return null;
+			}
+		);
 
 
 		if (res) {
 		if (res) {
 			const reader = res.body
 			const reader = res.body
@@ -83,6 +85,16 @@
 				.pipeThrough(splitStream('\n'))
 				.pipeThrough(splitStream('\n'))
 				.getReader();
 				.getReader();
 
 
+			MODEL_DOWNLOAD_POOL.set({
+				...$MODEL_DOWNLOAD_POOL,
+				[sanitizedModelTag]: {
+					...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
+					abortController: controller,
+					reader,
+					done: false
+				}
+			});
+
 			while (true) {
 			while (true) {
 				try {
 				try {
 					const { value, done } = await reader.read();
 					const { value, done } = await reader.read();
@@ -101,19 +113,6 @@
 								throw data.detail;
 								throw data.detail;
 							}
 							}
 
 
-							if (data.id) {
-								MODEL_DOWNLOAD_POOL.set({
-									...$MODEL_DOWNLOAD_POOL,
-									[sanitizedModelTag]: {
-										...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
-										requestId: data.id,
-										reader,
-										done: false
-									}
-								});
-								console.log(data);
-							}
-
 							if (data.status) {
 							if (data.status) {
 								if (data.digest) {
 								if (data.digest) {
 									let downloadProgress = 0;
 									let downloadProgress = 0;
@@ -181,11 +180,12 @@
 	});
 	});
 
 
 	const cancelModelPullHandler = async (model: string) => {
 	const cancelModelPullHandler = async (model: string) => {
-		const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+		const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+		if (abortController) {
+			abortController.abort();
+		}
 		if (reader) {
 		if (reader) {
 			await reader.cancel();
 			await reader.cancel();
-
-			await cancelOllamaRequest(localStorage.token, requestId);
 			delete $MODEL_DOWNLOAD_POOL[model];
 			delete $MODEL_DOWNLOAD_POOL[model];
 			MODEL_DOWNLOAD_POOL.set({
 			MODEL_DOWNLOAD_POOL.set({
 				...$MODEL_DOWNLOAD_POOL
 				...$MODEL_DOWNLOAD_POOL

+ 28 - 27
src/lib/components/chat/Settings/Models.svelte

@@ -8,7 +8,6 @@
 		getOllamaUrls,
 		getOllamaUrls,
 		getOllamaVersion,
 		getOllamaVersion,
 		pullModel,
 		pullModel,
-		cancelOllamaRequest,
 		uploadModel
 		uploadModel
 	} from '$lib/apis/ollama';
 	} from '$lib/apis/ollama';
 
 
@@ -67,12 +66,14 @@
 			console.log(model);
 			console.log(model);
 
 
 			updateModelId = model.id;
 			updateModelId = model.id;
-			const res = await pullModel(localStorage.token, model.id, selectedOllamaUrlIdx).catch(
-				(error) => {
-					toast.error(error);
-					return null;
-				}
-			);
+			const [res, controller] = await pullModel(
+				localStorage.token,
+				model.id,
+				selectedOllamaUrlIdx
+			).catch((error) => {
+				toast.error(error);
+				return null;
+			});
 
 
 			if (res) {
 			if (res) {
 				const reader = res.body
 				const reader = res.body
@@ -141,10 +142,12 @@
 			return;
 			return;
 		}
 		}
 
 
-		const res = await pullModel(localStorage.token, sanitizedModelTag, '0').catch((error) => {
-			toast.error(error);
-			return null;
-		});
+		const [res, controller] = await pullModel(localStorage.token, sanitizedModelTag, '0').catch(
+			(error) => {
+				toast.error(error);
+				return null;
+			}
+		);
 
 
 		if (res) {
 		if (res) {
 			const reader = res.body
 			const reader = res.body
@@ -152,6 +155,16 @@
 				.pipeThrough(splitStream('\n'))
 				.pipeThrough(splitStream('\n'))
 				.getReader();
 				.getReader();
 
 
+			MODEL_DOWNLOAD_POOL.set({
+				...$MODEL_DOWNLOAD_POOL,
+				[sanitizedModelTag]: {
+					...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
+					abortController: controller,
+					reader,
+					done: false
+				}
+			});
+
 			while (true) {
 			while (true) {
 				try {
 				try {
 					const { value, done } = await reader.read();
 					const { value, done } = await reader.read();
@@ -170,19 +183,6 @@
 								throw data.detail;
 								throw data.detail;
 							}
 							}
 
 
-							if (data.id) {
-								MODEL_DOWNLOAD_POOL.set({
-									...$MODEL_DOWNLOAD_POOL,
-									[sanitizedModelTag]: {
-										...$MODEL_DOWNLOAD_POOL[sanitizedModelTag],
-										requestId: data.id,
-										reader,
-										done: false
-									}
-								});
-								console.log(data);
-							}
-
 							if (data.status) {
 							if (data.status) {
 								if (data.digest) {
 								if (data.digest) {
 									let downloadProgress = 0;
 									let downloadProgress = 0;
@@ -416,11 +416,12 @@
 	};
 	};
 
 
 	const cancelModelPullHandler = async (model: string) => {
 	const cancelModelPullHandler = async (model: string) => {
-		const { reader, requestId } = $MODEL_DOWNLOAD_POOL[model];
+		const { reader, abortController } = $MODEL_DOWNLOAD_POOL[model];
+		if (abortController) {
+			abortController.abort();
+		}
 		if (reader) {
 		if (reader) {
 			await reader.cancel();
 			await reader.cancel();
-
-			await cancelOllamaRequest(localStorage.token, requestId);
 			delete $MODEL_DOWNLOAD_POOL[model];
 			delete $MODEL_DOWNLOAD_POOL[model];
 			MODEL_DOWNLOAD_POOL.set({
 			MODEL_DOWNLOAD_POOL.set({
 				...$MODEL_DOWNLOAD_POOL
 				...$MODEL_DOWNLOAD_POOL

+ 1 - 13
src/lib/components/workspace/Playground.svelte

@@ -8,7 +8,7 @@
 	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
 	import { OLLAMA_API_BASE_URL, OPENAI_API_BASE_URL, WEBUI_API_BASE_URL } from '$lib/constants';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 
 
-	import { cancelOllamaRequest, generateChatCompletion } from '$lib/apis/ollama';
+	import { generateChatCompletion } from '$lib/apis/ollama';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 
 	import { splitStream } from '$lib/utils';
 	import { splitStream } from '$lib/utils';
@@ -24,7 +24,6 @@
 	let selectedModelId = '';
 	let selectedModelId = '';
 
 
 	let loading = false;
 	let loading = false;
-	let currentRequestId = null;
 	let stopResponseFlag = false;
 	let stopResponseFlag = false;
 
 
 	let messagesContainerElement: HTMLDivElement;
 	let messagesContainerElement: HTMLDivElement;
@@ -46,14 +45,6 @@
 		}
 		}
 	};
 	};
 
 
-	// const cancelHandler = async () => {
-	// 	if (currentRequestId) {
-	// 		const res = await cancelOllamaRequest(localStorage.token, currentRequestId);
-	// 		currentRequestId = null;
-	// 		loading = false;
-	// 	}
-	// };
-
 	const stopResponse = () => {
 	const stopResponse = () => {
 		stopResponseFlag = true;
 		stopResponseFlag = true;
 		console.log('stopResponse');
 		console.log('stopResponse');
@@ -171,8 +162,6 @@
 					if (stopResponseFlag) {
 					if (stopResponseFlag) {
 						controller.abort('User: Stop Response');
 						controller.abort('User: Stop Response');
 					}
 					}
-
-					currentRequestId = null;
 					break;
 					break;
 				}
 				}
 
 
@@ -229,7 +218,6 @@
 
 
 			loading = false;
 			loading = false;
 			stopResponseFlag = false;
 			stopResponseFlag = false;
-			currentRequestId = null;
 		}
 		}
 	};
 	};