
feat: cancel request

Resolves #1006
Timothy J. Baek 1 year ago
parent
commit
8d34324d12
2 changed files with 126 additions and 63 deletions
  1. backend/apps/ollama/main.py  +100 -2
  2. src/routes/(app)/playground/+page.svelte  +26 -61

backend/apps/ollama/main.py  +100 -2

@@ -3,7 +3,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 
-from pydantic import BaseModel
+from pydantic import BaseModel, ConfigDict
 
 import random
 import requests
@@ -684,7 +684,7 @@ class GenerateChatCompletionForm(BaseModel):
 
 @app.post("/api/chat")
 @app.post("/api/chat/{url_idx}")
-async def generate_completion(
+async def generate_chat_completion(
     form_data: GenerateChatCompletionForm,
     url_idx: Optional[int] = None,
     user=Depends(get_current_user),
@@ -765,6 +765,104 @@ async def generate_completion(
         )
 
 
+# TODO: we should update this part once Ollama supports other types
+class OpenAIChatMessage(BaseModel):
+    role: str
+    content: str
+
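+    # extra="allow" lets fields beyond role/content pass Pydantic validation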
+    model_config = ConfigDict(extra="allow")
+
+
+class OpenAIChatCompletionForm(BaseModel):
+    model: str
+    messages: List[OpenAIChatMessage]
+
+    model_config = ConfigDict(extra="allow")
+
+
+@app.post("/v1/chat/completions")
+@app.post("/v1/chat/completions/{url_idx}")
+async def generate_openai_chat_completion(
+    form_data: OpenAIChatCompletionForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_current_user),
+):
+
+    if url_idx is None:
+        if form_data.model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail=ERROR_MESSAGES.MODEL_NOT_FOUND(form_data.model),
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    r = None
+
+    def get_request():
+        nonlocal form_data
+        nonlocal r
+
+        request_id = str(uuid.uuid4())
+        try:
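+            # register the request; cancelling removes the id from REQUEST_POOL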
+            REQUEST_POOL.append(request_id)
+
+            def stream_content():
+                try:
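+                    # emit the request_id first so the client can cancel this stream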
+                    if form_data.stream:
+                        yield json.dumps(
+                            {"request_id": request_id, "done": False}
+                        ) + "\n"
+
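+                    # relay upstream chunks only while the id is still in the pool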
+                    for chunk in r.iter_content(chunk_size=8192):
+                        if request_id in REQUEST_POOL:
+                            yield chunk
+                        else:
+                            print("User: canceled request")
+                            break
+                finally:
+                    if hasattr(r, "close"):
+                        r.close()
+                        if request_id in REQUEST_POOL:
+                            REQUEST_POOL.remove(request_id)
+
+            r = requests.request(
+                method="POST",
+                url=f"{url}/v1/chat/completions",
+                data=form_data.model_dump_json(exclude_none=True),
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            return StreamingResponse(
+                stream_content(),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request)
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except Exception:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def deprecated_proxy(path: str, request: Request, user=Depends(get_current_user)):
     url = app.state.OLLAMA_BASE_URLS[0]

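The cancel endpoint itself is not part of this diff; the stream_content loop above simply stops relaying once the request_id disappears from REQUEST_POOL. A minimal sketch of what the other half could look like, assuming REQUEST_POOL is the module-level list used above; the route path here is hypothetical:

# Hypothetical cancel endpoint (not shown in this commit): removing the id
# from REQUEST_POOL makes the streaming loop break on its next chunk.
@app.delete("/api/cancel/{request_id}")
async def cancel_request(request_id: str, user=Depends(get_current_user)):
    if request_id in REQUEST_POOL:
        REQUEST_POOL.remove(request_id)
        return {"status": True}
    raise HTTPException(status_code=404, detail="request not found")
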
src/routes/(app)/playground/+page.svelte  +26 -61

@@ -26,7 +26,7 @@
 	let selectedModelId = '';
 
 	let loading = false;
-	let currentRequestId;
+	let currentRequestId = null;
 	let stopResponseFlag = false;
 
 	let messagesContainerElement: HTMLDivElement;
@@ -92,6 +92,10 @@
 			while (true) {
 				const { value, done } = await reader.read();
 				if (done || stopResponseFlag) {
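+					// notify the backend so it drops the id from its request pool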
+					if (stopResponseFlag) {
+						await cancelChatCompletion(localStorage.token, currentRequestId);
+					}
+
 					currentRequestId = null;
 					break;
 				}
@@ -108,7 +112,11 @@
 								let data = JSON.parse(line.replace(/^data: /, ''));
 								console.log(data);
 
-								text += data.choices[0].delta.content ?? '';
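+								// a line carrying request_id identifies this stream for cancellation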
+								if ('request_id' in data) {
+									currentRequestId = data.request_id;
+								} else {
+									text += data.choices[0].delta.content ?? '';
+								}
 							}
 						}
 					}
@@ -146,16 +154,6 @@
 				: `${OLLAMA_API_BASE_URL}/v1`
 		);
 
-		// const [res, controller] = await generateChatCompletion(localStorage.token, {
-		// 	model: selectedModelId,
-		// 	messages: [
-		// 		{
-		// 			role: 'assistant',
-		// 			content: text
-		// 		}
-		// 	]
-		// });
-
 		let responseMessage;
 		if (messages.at(-1)?.role === 'assistant') {
 			responseMessage = messages.at(-1);
@@ -180,6 +178,11 @@
 			while (true) {
 				const { value, done } = await reader.read();
 				if (done || stopResponseFlag) {
+					if (stopResponseFlag) {
+						await cancelChatCompletion(localStorage.token, currentRequestId);
+					}
+
+					currentRequestId = null;
 					break;
 				}
 
@@ -196,17 +199,21 @@
 								let data = JSON.parse(line.replace(/^data: /, ''));
 								console.log(data);
 
-								if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
-									continue;
+								if ('request_id' in data) {
+									currentRequestId = data.request_id;
 								} else {
-									textareaElement.style.height = textareaElement.scrollHeight + 'px';
+									if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
+										continue;
+									} else {
+										textareaElement.style.height = textareaElement.scrollHeight + 'px';
 
-									responseMessage.content += data.choices[0].delta.content ?? '';
-									messages = messages;
+										responseMessage.content += data.choices[0].delta.content ?? '';
+										messages = messages;
 
-									textareaElement.style.height = textareaElement.scrollHeight + 'px';
+										textareaElement.style.height = textareaElement.scrollHeight + 'px';
 
-									await tick();
+										await tick();
+									}
 								}
 							}
 						}
@@ -217,48 +224,6 @@
 
 				scrollToBottom();
 			}
-
-			// while (true) {
-			// 	const { value, done } = await reader.read();
-			// 	if (done || stopResponseFlag) {
-			// 		if (stopResponseFlag) {
-			// 			await cancelChatCompletion(localStorage.token, currentRequestId);
-			// 		}
-
-			// 		currentRequestId = null;
-			// 		break;
-			// 	}
-
-			// 	try {
-			// 		let lines = value.split('\n');
-
-			// 		for (const line of lines) {
-			// 			if (line !== '') {
-			// 				console.log(line);
-			// 				let data = JSON.parse(line);
-
-			// 				if ('detail' in data) {
-			// 					throw data;
-			// 				}
-
-			// 				if ('id' in data) {
-			// 					console.log(data);
-			// 					currentRequestId = data.id;
-			// 				} else {
-			// 					if (data.done == false) {
-			// 						text += data.message.content;
-			// 					} else {
-			// 						console.log('done');
-			// 					}
-			// 				}
-			// 			}
-			// 		}
-			// 	} catch (error) {
-			// 		console.log(error);
-			// 	}
-
-			// 	scrollToBottom();
-			// }
 		}
 	};
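
End to end, the handshake works like this: when stream is true, the backend's first stream line is a JSON object carrying only the request_id; every later line is the upstream OpenAI-style completion payload. A minimal Python client sketch, not part of the commit; the base URL, token header, model name, and cancel path are assumptions (the cancel path matches the hypothetical endpoint sketched above):

import json
import requests

BASE = "http://localhost:8080/ollama"  # assumed mount point of this backend
HEADERS = {"Authorization": "Bearer <token>"}  # placeholder token

with requests.post(
    f"{BASE}/v1/chat/completions",
    json={
        "model": "llama2",  # any model known to app.state.MODELS
        "stream": True,
        "messages": [{"role": "user", "content": "Hello"}],
    },
    headers=HEADERS,
    stream=True,
) as r:
    r.raise_for_status()
    request_id = None
    for line in r.iter_lines():
        if not line:
            continue
        if request_id is None:
            # first line of the stream: {"request_id": "...", "done": false}
            request_id = json.loads(line)["request_id"]
            continue
        print(line.decode())  # OpenAI-style "data: ..." chunks
        # to abort mid-stream, call the (hypothetical) cancel endpoint:
        # requests.delete(f"{BASE}/api/cancel/{request_id}", headers=HEADERS)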