Преглед изворни кода

feat: chat playground backend integration

Timothy J. Baek пре 1 година
родитељ
комит
901e7a33fa

+ 6 - 6
backend/apps/ollama/main.py

@@ -11,7 +11,7 @@ from pydantic import BaseModel
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
-from config import OLLAMA_API_BASE_URL, WEBUI_AUTH
+from config import OLLAMA_BASE_URL, WEBUI_AUTH
 
 app = FastAPI()
 app.add_middleware(
@@ -22,7 +22,7 @@ app.add_middleware(
     allow_headers=["*"],
 )
 
-app.state.OLLAMA_API_BASE_URL = OLLAMA_API_BASE_URL
+app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL
 
 # TARGET_SERVER_URL = OLLAMA_API_BASE_URL
 
@@ -32,7 +32,7 @@ REQUEST_POOL = []
 
 @app.get("/url")
 async def get_ollama_api_url(user=Depends(get_admin_user)):
-    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
+    return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
 
 
 class UrlUpdateForm(BaseModel):
@@ -41,8 +41,8 @@ class UrlUpdateForm(BaseModel):
 
 @app.post("/url/update")
 async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
-    app.state.OLLAMA_API_BASE_URL = form_data.url
-    return {"OLLAMA_API_BASE_URL": app.state.OLLAMA_API_BASE_URL}
+    app.state.OLLAMA_BASE_URL = form_data.url
+    return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
 
 
 @app.get("/cancel/{request_id}")
@@ -57,7 +57,7 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
 
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
-    target_url = f"{app.state.OLLAMA_API_BASE_URL}/{path}"
+    target_url = f"{app.state.OLLAMA_BASE_URL}/{path}"
 
     body = await request.body()
     headers = dict(request.headers)

+ 11 - 0
backend/config.py

@@ -211,6 +211,17 @@ if ENV == "prod":
     if OLLAMA_API_BASE_URL == "/ollama/api":
         OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
 
+
+OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", "")
+
+if OLLAMA_BASE_URL == "":
+    OLLAMA_BASE_URL = (
+        OLLAMA_API_BASE_URL[:-4]
+        if OLLAMA_API_BASE_URL.endswith("/api")
+        else OLLAMA_API_BASE_URL
+    )
+
+
 ####################################
 # OPENAI_API
 ####################################

+ 11 - 11
src/lib/apis/ollama/index.ts

@@ -29,7 +29,7 @@ export const getOllamaAPIUrl = async (token: string = '') => {
 		throw error;
 	}
 
-	return res.OLLAMA_API_BASE_URL;
+	return res.OLLAMA_BASE_URL;
 };
 
 export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
@@ -64,13 +64,13 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
 		throw error;
 	}
 
-	return res.OLLAMA_API_BASE_URL;
+	return res.OLLAMA_BASE_URL;
 };
 
 export const getOllamaVersion = async (token: string = '') => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/version`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/version`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -102,7 +102,7 @@ export const getOllamaVersion = async (token: string = '') => {
 export const getOllamaModels = async (token: string = '') => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/tags`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/tags`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -148,7 +148,7 @@ export const generateTitle = async (
 
 	console.log(template);
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'text/event-stream',
@@ -186,7 +186,7 @@ export const generatePrompt = async (token: string = '', model: string, conversa
 		conversation = '[no existing conversation]';
 	}
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'text/event-stream',
@@ -220,7 +220,7 @@ export const generatePrompt = async (token: string = '', model: string, conversa
 export const generateTextCompletion = async (token: string = '', model: string, text: string) => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/generate`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'text/event-stream',
@@ -247,7 +247,7 @@ export const generateChatCompletion = async (token: string = '', body: object) =
 	let controller = new AbortController();
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/chat`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/chat`, {
 		signal: controller.signal,
 		method: 'POST',
 		headers: {
@@ -291,7 +291,7 @@ export const cancelChatCompletion = async (token: string = '', requestId: string
 export const createModel = async (token: string, tagName: string, content: string) => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/create`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'text/event-stream',
@@ -316,7 +316,7 @@ export const createModel = async (token: string, tagName: string, content: strin
 export const deleteModel = async (token: string, tagName: string) => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/delete`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/delete`, {
 		method: 'DELETE',
 		headers: {
 			'Content-Type': 'text/event-stream',
@@ -350,7 +350,7 @@ export const deleteModel = async (token: string, tagName: string) => {
 export const pullModel = async (token: string, tagName: string) => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/pull`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull`, {
 		method: 'POST',
 		headers: {
 			'Content-Type': 'text/event-stream',

+ 2 - 2
src/lib/components/chat/Settings/Connections.svelte

@@ -114,12 +114,12 @@
 		<hr class=" dark:border-gray-700" />
 
 		<div>
-			<div class=" mb-2.5 text-sm font-medium">Ollama API URL</div>
+			<div class=" mb-2.5 text-sm font-medium">Ollama Base URL</div>
 			<div class="flex w-full">
 				<div class="flex-1 mr-2">
 					<input
 						class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-						placeholder="Enter URL (e.g. http://localhost:11434/api)"
+						placeholder="Enter URL (e.g. http://localhost:11434)"
 						bind:value={API_BASE_URL}
 					/>
 				</div>

+ 229 - 62
src/routes/(app)/playground/+page.svelte

@@ -1,14 +1,21 @@
 <script>
 	import { goto } from '$app/navigation';
 
-	import { onMount } from 'svelte';
+	import { onMount, tick } from 'svelte';
 
 	import { toast } from 'svelte-sonner';
 
-	import { WEBUI_API_BASE_URL } from '$lib/constants';
+	import {
+		LITELLM_API_BASE_URL,
+		OLLAMA_API_BASE_URL,
+		OPENAI_API_BASE_URL,
+		WEBUI_API_BASE_URL
+	} from '$lib/constants';
 	import { WEBUI_NAME, config, user, models, settings } from '$lib/stores';
 
 	import { cancelChatCompletion, generateChatCompletion } from '$lib/apis/ollama';
+	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
+
 	import { splitStream } from '$lib/utils';
 
 	let mode = 'chat';
@@ -16,18 +23,28 @@
 
 	let text = '';
 
-	let selectedModel = '';
+	let selectedModelId = '';
 
 	let loading = false;
 	let currentRequestId;
 	let stopResponseFlag = false;
 
 	let system = '';
-	let messages = [];
+	let messages = [
+		{
+			role: 'user',
+			content: ''
+		}
+	];
 
 	const scrollToBottom = () => {
-		const element = document.getElementById('text-completion-textarea');
-		element.scrollTop = element.scrollHeight;
+		// const element = document.getElementById('text-completion-textarea');
+
+		const element = document.getElementById('messages-container');
+
+		if (element) {
+			element.scrollTop = element?.scrollHeight;
+		}
 	};
 
 	// const cancelHandler = async () => {
@@ -43,67 +60,216 @@
 		console.log('stopResponse');
 	};
 
-	const submitHandler = async () => {
-		if (selectedModel) {
-			loading = true;
-
-			const [res, controller] = await generateChatCompletion(localStorage.token, {
-				model: selectedModel,
-				messages: [
-					{
-						role: 'assistant',
-						content: text
+	const textCompletionHandler = async () => {
+		const [res, controller] = await generateChatCompletion(localStorage.token, {
+			model: selectedModelId,
+			messages: [
+				{
+					role: 'assistant',
+					content: text
+				}
+			]
+		});
+
+		if (res && res.ok) {
+			const reader = res.body
+				.pipeThrough(new TextDecoderStream())
+				.pipeThrough(splitStream('\n'))
+				.getReader();
+
+			while (true) {
+				const { value, done } = await reader.read();
+				if (done || stopResponseFlag) {
+					if (stopResponseFlag) {
+						await cancelChatCompletion(localStorage.token, currentRequestId);
 					}
-				]
-			});
-
-			if (res && res.ok) {
-				const reader = res.body
-					.pipeThrough(new TextDecoderStream())
-					.pipeThrough(splitStream('\n'))
-					.getReader();
-
-				while (true) {
-					const { value, done } = await reader.read();
-					if (done || stopResponseFlag) {
-						if (stopResponseFlag) {
-							await cancelChatCompletion(localStorage.token, currentRequestId);
-						}
 
-						currentRequestId = null;
-						break;
-					}
+					currentRequestId = null;
+					break;
+				}
 
-					try {
-						let lines = value.split('\n');
+				try {
+					let lines = value.split('\n');
 
-						for (const line of lines) {
-							if (line !== '') {
-								console.log(line);
-								let data = JSON.parse(line);
+					for (const line of lines) {
+						if (line !== '') {
+							console.log(line);
+							let data = JSON.parse(line);
 
-								if ('detail' in data) {
-									throw data;
-								}
+							if ('detail' in data) {
+								throw data;
+							}
 
-								if ('id' in data) {
-									console.log(data);
-									currentRequestId = data.id;
+							if ('id' in data) {
+								console.log(data);
+								currentRequestId = data.id;
+							} else {
+								if (data.done == false) {
+									text += data.message.content;
 								} else {
-									if (data.done == false) {
-										text += data.message.content;
-									} else {
-										console.log('done');
-									}
+									console.log('done');
 								}
 							}
 						}
-					} catch (error) {
-						console.log(error);
 					}
+				} catch (error) {
+					console.log(error);
+				}
+
+				scrollToBottom();
+			}
+		}
+	};
+
+	const chatCompletionHandler = async () => {
+		const model = $models.find((model) => model.id === selectedModelId);
+
+		const res = await generateOpenAIChatCompletion(
+			localStorage.token,
+			{
+				model: model.id,
+				stream: true,
+				messages: [
+					system
+						? {
+								role: 'system',
+								content: system
+						  }
+						: undefined,
+					...messages
+				].filter((message) => message)
+			},
+			model.external
+				? model.source === 'litellm'
+					? `${LITELLM_API_BASE_URL}/v1`
+					: `${OPENAI_API_BASE_URL}`
+				: `${OLLAMA_API_BASE_URL}/v1`
+		);
+
+		// const [res, controller] = await generateChatCompletion(localStorage.token, {
+		// 	model: selectedModelId,
+		// 	messages: [
+		// 		{
+		// 			role: 'assistant',
+		// 			content: text
+		// 		}
+		// 	]
+		// });
+
+		let responseMessage;
+		if (messages.at(-1)?.role === 'assistant') {
+			responseMessage = messages.at(-1);
+		} else {
+			responseMessage = {
+				role: 'assistant',
+				content: ''
+			};
+			messages.push(responseMessage);
+			messages = messages;
+		}
+
+		await tick();
+		const textareaElement = document.getElementById(`assistant-${messages.length - 1}-textarea`);
+
+		if (res && res.ok) {
+			const reader = res.body
+				.pipeThrough(new TextDecoderStream())
+				.pipeThrough(splitStream('\n'))
+				.getReader();
+
+			while (true) {
+				const { value, done } = await reader.read();
+				if (done || stopResponseFlag) {
+					break;
+				}
+
+				try {
+					let lines = value.split('\n');
+
+					for (const line of lines) {
+						if (line !== '') {
+							console.log(line);
+							if (line === 'data: [DONE]') {
+								// responseMessage.done = true;
+								messages = messages;
+							} else {
+								let data = JSON.parse(line.replace(/^data: /, ''));
+								console.log(data);
+
+								if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
+									continue;
+								} else {
+									textareaElement.style.height = textareaElement.scrollHeight + 'px';
 
-					scrollToBottom();
+									responseMessage.content += data.choices[0].delta.content ?? '';
+									messages = messages;
+
+									textareaElement.style.height = textareaElement.scrollHeight + 'px';
+
+									await tick();
+								}
+							}
+						}
+					}
+				} catch (error) {
+					console.log(error);
 				}
+
+				scrollToBottom();
+			}
+
+			// while (true) {
+			// 	const { value, done } = await reader.read();
+			// 	if (done || stopResponseFlag) {
+			// 		if (stopResponseFlag) {
+			// 			await cancelChatCompletion(localStorage.token, currentRequestId);
+			// 		}
+
+			// 		currentRequestId = null;
+			// 		break;
+			// 	}
+
+			// 	try {
+			// 		let lines = value.split('\n');
+
+			// 		for (const line of lines) {
+			// 			if (line !== '') {
+			// 				console.log(line);
+			// 				let data = JSON.parse(line);
+
+			// 				if ('detail' in data) {
+			// 					throw data;
+			// 				}
+
+			// 				if ('id' in data) {
+			// 					console.log(data);
+			// 					currentRequestId = data.id;
+			// 				} else {
+			// 					if (data.done == false) {
+			// 						text += data.message.content;
+			// 					} else {
+			// 						console.log('done');
+			// 					}
+			// 				}
+			// 			}
+			// 		}
+			// 	} catch (error) {
+			// 		console.log(error);
+			// 	}
+
+			// 	scrollToBottom();
+			// }
+		}
+	};
+
+	const submitHandler = async () => {
+		if (selectedModelId) {
+			loading = true;
+
+			if (mode === 'complete') {
+				await textCompletionHandler();
+			} else if (mode === 'chat') {
+				await chatCompletionHandler();
 			}
 
 			loading = false;
@@ -118,11 +284,11 @@
 		}
 
 		if ($settings?.models) {
-			selectedModel = $settings?.models[0];
+			selectedModelId = $settings?.models[0];
 		} else if ($config?.default_models) {
-			selectedModel = $config?.default_models.split(',')[0];
+			selectedModelId = $config?.default_models.split(',')[0];
 		} else {
-			selectedModel = '';
+			selectedModelId = '';
 		}
 		loaded = true;
 	});
@@ -185,7 +351,7 @@
 						<select
 							id="models"
 							class="outline-none bg-transparent text-sm font-medium rounded-lg w-full placeholder-gray-400"
-							bind:value={selectedModel}
+							bind:value={selectedModelId}
 						>
 							<option class=" text-gray-800" value="" selected disabled>Select a model</option>
 
@@ -234,10 +400,11 @@
 						<div class="p-3 outline outline-1 outline-gray-200 dark:outline-gray-800 rounded-lg">
 							<div class=" text-sm font-medium">System</div>
 							<textarea
-								id="text-completion-textarea"
+								id="system-textarea"
 								class="w-full h-full bg-transparent resize-none outline-none text-sm"
 								bind:value={system}
 								placeholder="You're a helpful assistant."
+								rows="4"
 							/>
 						</div>
 					</div>
@@ -271,8 +438,8 @@
 
 											<div class="flex-1">
 												<textarea
-													id="text-completion-textarea"
-													class="w-full bg-transparent outline-none rounded-lg p-2 text-sm resize-none"
+													id="{message.role}-{idx}-textarea"
+													class="w-full bg-transparent outline-none rounded-lg p-2 text-sm resize-none overflow-hidden"
 													placeholder="Enter {message.role === 'user'
 														? 'a user'
 														: 'an assistant'} message here"
@@ -320,7 +487,7 @@
 									{/each}
 
 									<button
-										class="flex items-center gap-2"
+										class="flex items-center gap-2 px-2 py-1"
 										on:click={() => {
 											console.log(messages.at(-1));
 											messages.push({