Selaa lähdekoodia

feat: litellm frontend integration

Timothy J. Baek 1 vuosi sitten
vanhempi
commit
9b6dca3d7f

+ 42 - 0
src/lib/apis/litellm/index.ts

@@ -0,0 +1,42 @@
+import { LITELLM_API_BASE_URL } from '$lib/constants';
+
+export const getLiteLLMModels = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${LITELLM_API_BASE_URL}/v1/models`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+			return [];
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	const models = Array.isArray(res) ? res : res?.data ?? null;
+
+	return models
+		? models
+				.map((model) => ({
+					id: model.id,
+					name: model.name ?? model.id,
+					external: true,
+					source: 'litellm'
+				}))
+				.sort((a, b) => {
+					return a.name.localeCompare(b.name);
+				})
+		: models;
+};

+ 5 - 3
src/lib/apis/ollama/index.ts

@@ -128,9 +128,11 @@ export const getOllamaModels = async (token: string = '') => {
 		throw error;
 	}
 
-	return (res?.models ?? []).sort((a, b) => {
-		return a.name.localeCompare(b.name);
-	});
+	return (res?.models ?? [])
+		.map((model) => ({ id: model.model, name: model.name ?? model.model, ...model }))
+		.sort((a, b) => {
+			return a.name.localeCompare(b.name);
+		});
 };
 
 // TODO: migrate to backend

+ 8 - 4
src/lib/apis/openai/index.ts

@@ -163,7 +163,7 @@ export const getOpenAIModels = async (token: string = '') => {
 
 	return models
 		? models
-				.map((model) => ({ name: model.id, external: true }))
+				.map((model) => ({ id: model.id, name: model.name ?? model.id, external: true }))
 				.sort((a, b) => {
 					return a.name.localeCompare(b.name);
 				})
@@ -200,17 +200,21 @@ export const getOpenAIModelsDirect = async (
 	const models = Array.isArray(res) ? res : res?.data ?? null;
 
 	return models
-		.map((model) => ({ name: model.id, external: true }))
+		.map((model) => ({ id: model.id, name: model.name ?? model.id, external: true }))
 		.filter((model) => (base_url.includes('openai') ? model.name.includes('gpt') : true))
 		.sort((a, b) => {
 			return a.name.localeCompare(b.name);
 		});
 };
 
-export const generateOpenAIChatCompletion = async (token: string = '', body: object) => {
+export const generateOpenAIChatCompletion = async (
+	token: string = '',
+	body: object,
+	url: string = OPENAI_API_BASE_URL
+) => {
 	let error = null;
 
-	const res = await fetch(`${OPENAI_API_BASE_URL}/chat/completions`, {
+	const res = await fetch(`${url}/chat/completions`, {
 		method: 'POST',
 		headers: {
 			Authorization: `Bearer ${token}`,

+ 2 - 2
src/lib/components/chat/ModelSelector.svelte

@@ -25,7 +25,7 @@
 
 	$: if (selectedModels.length > 0 && $models.length > 0) {
 		selectedModels = selectedModels.map((model) =>
-			$models.map((m) => m.name).includes(model) ? model : ''
+			$models.map((m) => m.id).includes(model) ? model : ''
 		);
 	}
 </script>
@@ -45,7 +45,7 @@
 					{#if model.name === 'hr'}
 						<hr />
 					{:else}
-						<option value={model.name} class="text-gray-700 text-lg"
+						<option value={model.id} class="text-gray-700 text-lg"
 							>{model.name +
 								`${model.size ? ` (${(model.size / 1024 ** 3).toFixed(1)}GB)` : ''}`}</option
 						>

+ 9 - 0
src/lib/components/chat/SettingsModal.svelte

@@ -4,6 +4,7 @@
 
 	import { getOllamaModels } from '$lib/apis/ollama';
 	import { getOpenAIModels } from '$lib/apis/openai';
+	import { getLiteLLMModels } from '$lib/apis/litellm';
 
 	import Modal from '../common/Modal.svelte';
 	import Account from './Settings/Account.svelte';
@@ -41,7 +42,15 @@
 				console.log(error);
 				return null;
 			});
+
 			models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
+
+			const liteLLMModels = await getLiteLLMModels(localStorage.token).catch((error) => {
+				console.log(error);
+				return null;
+			});
+
+			models.push(...(liteLLMModels ? [{ name: 'hr' }, ...liteLLMModels] : []));
 		}
 
 		return models;

+ 2 - 0
src/lib/constants.ts

@@ -5,6 +5,8 @@ export const WEBUI_NAME = 'Open WebUI';
 export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
 
 export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
+
+export const LITELLM_API_BASE_URL = `${WEBUI_BASE_URL}/litellm/api`;
 export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
 export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
 export const AUDIO_API_BASE_URL = `${WEBUI_BASE_URL}/audio/api/v1`;

+ 8 - 0
src/routes/(app)/+layout.svelte

@@ -12,6 +12,7 @@
 	import { getPrompts } from '$lib/apis/prompts';
 
 	import { getOpenAIModels } from '$lib/apis/openai';
+	import { getLiteLLMModels } from '$lib/apis/litellm';
 
 	import {
 		user,
@@ -59,6 +60,13 @@
 
 		models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
 
+		const liteLLMModels = await getLiteLLMModels(localStorage.token).catch((error) => {
+			console.log(error);
+			return null;
+		});
+
+		models.push(...(liteLLMModels ? [{ name: 'hr' }, ...liteLLMModels] : []));
+
 		return models;
 	};
 

+ 58 - 53
src/routes/(app)/+page.svelte

@@ -36,6 +36,7 @@
 	import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
 	import { RAGTemplate } from '$lib/utils/rag';
+	import { LITELLM_API_BASE_URL, OPENAI_API_BASE_URL } from '$lib/constants';
 
 	let stopResponseFlag = false;
 	let autoScroll = true;
@@ -277,9 +278,8 @@
 		}
 
 		await Promise.all(
-			selectedModels.map(async (model) => {
-				console.log(model);
-				const modelTag = $models.filter((m) => m.name === model).at(0);
+			selectedModels.map(async (modelId) => {
+				const model = $models.filter((m) => m.id === modelId).at(0);
 
 				// Create response message
 				let responseMessageId = uuidv4();
@@ -289,7 +289,7 @@
 					childrenIds: [],
 					role: 'assistant',
 					content: '',
-					model: model,
+					model: model.id,
 					timestamp: Math.floor(Date.now() / 1000) // Unix epoch
 				};
 
@@ -305,12 +305,12 @@
 					];
 				}
 
-				if (modelTag?.external) {
+				if (model?.external) {
 					await sendPromptOpenAI(model, prompt, responseMessageId, _chatId);
-				} else if (modelTag) {
+				} else if (model) {
 					await sendPromptOllama(model, prompt, responseMessageId, _chatId);
 				} else {
-					toast.error(`Model ${model} not found`);
+					toast.error(`Model ${model.id} not found`);
 				}
 			})
 		);
@@ -319,6 +319,7 @@
 	};
 
 	const sendPromptOllama = async (model, userPrompt, responseMessageId, _chatId) => {
+		model = model.id;
 		const responseMessage = history.messages[responseMessageId];
 
 		// Wait until history/message have been updated
@@ -530,54 +531,58 @@
 		const responseMessage = history.messages[responseMessageId];
 		scrollToBottom();
 
-		const res = await generateOpenAIChatCompletion(localStorage.token, {
-			model: model,
-			stream: true,
-			messages: [
-				$settings.system
-					? {
-							role: 'system',
-							content: $settings.system
-					  }
-					: undefined,
-				...messages
-			]
-				.filter((message) => message)
-				.map((message, idx, arr) => ({
-					role: message.role,
-					...(message.files?.filter((file) => file.type === 'image').length > 0 ?? false
+		const res = await generateOpenAIChatCompletion(
+			localStorage.token,
+			{
+				model: model.id,
+				stream: true,
+				messages: [
+					$settings.system
 						? {
-								content: [
-									{
-										type: 'text',
-										text:
-											arr.length - 1 !== idx
-												? message.content
-												: message?.raContent ?? message.content
-									},
-									...message.files
-										.filter((file) => file.type === 'image')
-										.map((file) => ({
-											type: 'image_url',
-											image_url: {
-												url: file.url
-											}
-										}))
-								]
+								role: 'system',
+								content: $settings.system
 						  }
-						: {
-								content:
-									arr.length - 1 !== idx ? message.content : message?.raContent ?? message.content
-						  })
-				})),
-			seed: $settings?.options?.seed ?? undefined,
-			stop: $settings?.options?.stop ?? undefined,
-			temperature: $settings?.options?.temperature ?? undefined,
-			top_p: $settings?.options?.top_p ?? undefined,
-			num_ctx: $settings?.options?.num_ctx ?? undefined,
-			frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-			max_tokens: $settings?.options?.num_predict ?? undefined
-		});
+						: undefined,
+					...messages
+				]
+					.filter((message) => message)
+					.map((message, idx, arr) => ({
+						role: message.role,
+						...(message.files?.filter((file) => file.type === 'image').length > 0 ?? false
+							? {
+									content: [
+										{
+											type: 'text',
+											text:
+												arr.length - 1 !== idx
+													? message.content
+													: message?.raContent ?? message.content
+										},
+										...message.files
+											.filter((file) => file.type === 'image')
+											.map((file) => ({
+												type: 'image_url',
+												image_url: {
+													url: file.url
+												}
+											}))
+									]
+							  }
+							: {
+									content:
+										arr.length - 1 !== idx ? message.content : message?.raContent ?? message.content
+							  })
+					})),
+				seed: $settings?.options?.seed ?? undefined,
+				stop: $settings?.options?.stop ?? undefined,
+				temperature: $settings?.options?.temperature ?? undefined,
+				top_p: $settings?.options?.top_p ?? undefined,
+				num_ctx: $settings?.options?.num_ctx ?? undefined,
+				frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
+				max_tokens: $settings?.options?.num_predict ?? undefined
+			},
+			model.source === 'litellm' ? `${LITELLM_API_BASE_URL}/v1` : `${OPENAI_API_BASE_URL}`
+		);
 
 		if (res && res.ok) {
 			const reader = res.body