浏览代码

Merge branch 'main' into bun

Timothy Jaeryang Baek 1 年之前
父节点
当前提交
1a93191259

+ 4 - 0
Dockerfile

@@ -16,6 +16,10 @@ ARG OLLAMA_API_BASE_URL='/ollama/api'
 
 ENV ENV=prod
 ENV OLLAMA_API_BASE_URL $OLLAMA_API_BASE_URL
+
+ENV OPENAI_API_BASE_URL ""
+ENV OPENAI_API_KEY ""
+
 ENV WEBUI_JWT_SECRET_KEY "SECRET_KEY"
 
 WORKDIR /app

+ 135 - 0
backend/apps/openai/main.py

@@ -0,0 +1,135 @@
+from fastapi import FastAPI, Request, Response, HTTPException, Depends
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import StreamingResponse, JSONResponse
+
+import requests
+import json
+from pydantic import BaseModel
+
+from apps.web.models.users import Users
+from constants import ERROR_MESSAGES
+from utils.utils import decode_token, get_current_user
+from config import OPENAI_API_BASE_URL, OPENAI_API_KEY
+
+app = FastAPI()
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+app.state.OPENAI_API_BASE_URL = OPENAI_API_BASE_URL
+app.state.OPENAI_API_KEY = OPENAI_API_KEY
+
+
+class UrlUpdateForm(BaseModel):
+    url: str
+
+
+class KeyUpdateForm(BaseModel):
+    key: str
+
+
+@app.get("/url")
+async def get_openai_url(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.post("/url/update")
+async def update_openai_url(form_data: UrlUpdateForm, user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        app.state.OPENAI_API_BASE_URL = form_data.url
+        return {"OPENAI_API_BASE_URL": app.state.OPENAI_API_BASE_URL}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.get("/key")
+async def get_openai_key(user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.post("/key/update")
+async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_user)):
+    if user and user.role == "admin":
+        app.state.OPENAI_API_KEY = form_data.key
+        return {"OPENAI_API_KEY": app.state.OPENAI_API_KEY}
+    else:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+
+
+@app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
+async def proxy(path: str, request: Request, user=Depends(get_current_user)):
+    target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
+    print(target_url, app.state.OPENAI_API_KEY)
+
+    if user.role not in ["user", "admin"]:
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    if app.state.OPENAI_API_KEY == "":
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
+
+    body = await request.body()
+    # headers = dict(request.headers)
+    # print(headers)
+
+    headers = {}
+    headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
+    headers["Content-Type"] = "application/json"
+
+    try:
+        r = requests.request(
+            method=request.method,
+            url=target_url,
+            data=body,
+            headers=headers,
+            stream=True,
+        )
+
+        r.raise_for_status()
+
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        else:
+            # For non-SSE, read the response and return it
+            # response_data = (
+            #     r.json()
+            #     if r.headers.get("Content-Type", "")
+            #     == "application/json"
+            #     else r.text
+            # )
+
+            response_data = r.json()
+
+            print(type(response_data))
+
+            if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
+                response_data["data"] = list(
+                    filter(lambda model: "gpt" in model["id"], response_data["data"])
+                )
+
+            return response_data
+    except Exception as e:
+        print(e)
+        error_detail = "Ollama WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"External: {res['error']}"
+            except:
+                error_detail = f"External: {e}"
+
+        raise HTTPException(status_code=r.status_code, detail=error_detail)

+ 12 - 1
backend/config.py

@@ -26,11 +26,22 @@ if ENV == "prod":
     if OLLAMA_API_BASE_URL == "/ollama/api":
         OLLAMA_API_BASE_URL = "http://host.docker.internal:11434/api"
 
+
+####################################
+# OPENAI_API
+####################################
+
+OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
+OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
+
+if OPENAI_API_BASE_URL == "":
+    OPENAI_API_BASE_URL = "https://api.openai.com/v1"
+
 ####################################
 # WEBUI_VERSION
 ####################################
 
-WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.42")
+WEBUI_VERSION = os.environ.get("WEBUI_VERSION", "v1.0.0-alpha.50")
 
 ####################################
 # WEBUI_AUTH (Required for security)

+ 1 - 0
backend/constants.py

@@ -33,4 +33,5 @@ class ERROR_MESSAGES(str, Enum):
         "The requested action has been restricted as a security measure.")
     NOT_FOUND = "We could not find what you're looking for :/"
     USER_NOT_FOUND = "We could not find what you're looking for :/"
+    API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
     MALICIOUS = "Unusual activities detected, please try again in a few minutes."

+ 3 - 1
backend/main.py

@@ -6,6 +6,8 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 
 from apps.ollama.main import app as ollama_app
+from apps.openai.main import app as openai_app
+
 from apps.web.main import app as webui_app
 
 import time
@@ -47,8 +49,8 @@ async def check_url(request: Request, call_next):
 
 
 app.mount("/api/v1", webui_app)
-# app.mount("/ollama/api", WSGIMiddleware(ollama_app))
 app.mount("/ollama/api", ollama_app)
+app.mount("/openai/api", openai_app)
 
 app.mount("/",
           SPAStaticFiles(directory="../build", html=True),

+ 3 - 9
example.env

@@ -1,12 +1,6 @@
-# If you're serving both the frontend and backend (Recommended)
-# Set the public API base URL for seamless communication
-PUBLIC_API_BASE_URL='/ollama/api'
-
-# If you're serving only the frontend (Not recommended and not fully supported)
-# Comment above and Uncomment below
-# You can use the default value or specify a custom path, e.g., '/api'
-# PUBLIC_API_BASE_URL='http://{location.hostname}:11434/api'
-
 # Ollama URL for the backend to connect
 # The path '/ollama/api' will be redirected to the specified backend URL
 OLLAMA_API_BASE_URL='http://localhost:11434/api'
+
+OPENAI_API_BASE_URL=''
+OPENAI_API_KEY=''

+ 196 - 1
src/lib/apis/openai/index.ts

@@ -1,4 +1,176 @@
-export const getOpenAIModels = async (
+import { OPENAI_API_BASE_URL } from '$lib/constants';
+
+export const getOpenAIUrl = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/url`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_BASE_URL;
+};
+
+export const updateOpenAIUrl = async (token: string = '', url: string) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			url: url
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_BASE_URL;
+};
+
+export const getOpenAIKey = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/key`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
+};
+
+export const updateOpenAIKey = async (token: string = '', key: string) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			key: key
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
+};
+
+export const getOpenAIModels = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/models`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+			return [];
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	const models = Array.isArray(res) ? res : res?.data ?? null;
+
+	return models
+		? models
+				.map((model) => ({ name: model.id, external: true }))
+				.sort((a, b) => {
+					return a.name.localeCompare(b.name);
+				})
+		: models;
+};
+
+export const getOpenAIModelsDirect = async (
 	base_url: string = 'https://api.openai.com/v1',
 	api_key: string = ''
 ) => {
@@ -34,3 +206,26 @@ export const getOpenAIModels = async (
 			return a.name.localeCompare(b.name);
 		});
 };
+
+export const generateOpenAIChatCompletion = async (token: string = '', body: object) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/chat/completions`, {
+		method: 'POST',
+		headers: {
+			Authorization: `Bearer ${token}`,
+			'Content-Type': 'application/json'
+		},
+		body: JSON.stringify(body)
+	}).catch((err) => {
+		console.log(err);
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 1 - 1
src/lib/components/chat/Messages/Placeholder.svelte

@@ -27,7 +27,7 @@
 					>
 						{#if model in modelfiles}
 							<img
-								src={modelfiles[model]?.imageUrl}
+								src={modelfiles[model]?.imageUrl ?? '/ollama-dark.png'}
 								alt="modelfile"
 								class=" w-20 mb-2 rounded-full {models.length > 1
 									? ' border-[5px] border-white dark:border-gray-800'

+ 53 - 68
src/lib/components/chat/SettingsModal.svelte

@@ -24,6 +24,13 @@
 	import { updateUserPassword } from '$lib/apis/auths';
 	import { goto } from '$app/navigation';
 	import Page from '../../../routes/(app)/+page.svelte';
+	import {
+		getOpenAIKey,
+		getOpenAIModels,
+		getOpenAIUrl,
+		updateOpenAIKey,
+		updateOpenAIUrl
+	} from '$lib/apis/openai';
 
 	export let show = false;
 
@@ -153,6 +160,13 @@
 		}
 	};
 
+	const updateOpenAIHandler = async () => {
+		OPENAI_API_BASE_URL = await updateOpenAIUrl(localStorage.token, OPENAI_API_BASE_URL);
+		OPENAI_API_KEY = await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
+
+		await models.set(await getModels());
+	};
+
 	const toggleTheme = async () => {
 		if (theme === 'dark') {
 			theme = 'light';
@@ -484,7 +498,7 @@
 	};
 
 	const getModels = async (type = 'all') => {
-		let models = [];
+		const models = [];
 		models.push(
 			...(await getOllamaModels(localStorage.token).catch((error) => {
 				toast.error(error);
@@ -493,43 +507,13 @@
 		);
 
 		// If OpenAI API Key exists
-		if (type === 'all' && $settings.OPENAI_API_KEY) {
-			const OPENAI_API_BASE_URL = $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1';
+		if (type === 'all' && OPENAI_API_KEY) {
+			const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
+				console.log(error);
+				return null;
+			});
 
-			// Validate OPENAI_API_KEY
-			const openaiModelRes = await fetch(`${OPENAI_API_BASE_URL}/models`, {
-				method: 'GET',
-				headers: {
-					'Content-Type': 'application/json',
-					Authorization: `Bearer ${$settings.OPENAI_API_KEY}`
-				}
-			})
-				.then(async (res) => {
-					if (!res.ok) throw await res.json();
-					return res.json();
-				})
-				.catch((error) => {
-					console.log(error);
-					toast.error(`OpenAI: ${error?.error?.message ?? 'Network Problem'}`);
-					return null;
-				});
-
-			const openAIModels = Array.isArray(openaiModelRes)
-				? openaiModelRes
-				: openaiModelRes?.data ?? null;
-
-			models.push(
-				...(openAIModels
-					? [
-							{ name: 'hr' },
-							...openAIModels
-								.map((model) => ({ name: model.id, external: true }))
-								.filter((model) =>
-									OPENAI_API_BASE_URL.includes('openai') ? model.name.includes('gpt') : true
-								)
-					  ]
-					: [])
-			);
+			models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
 		}
 
 		return models;
@@ -564,6 +548,8 @@
 		console.log('settings', $user.role === 'admin');
 		if ($user.role === 'admin') {
 			API_BASE_URL = await getOllamaAPIUrl(localStorage.token);
+			OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token);
+			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
 		}
 
 		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
@@ -584,9 +570,6 @@
 		options = { ...options, ...settings.options };
 		options.stop = (settings?.options?.stop ?? []).join(',');
 
-		OPENAI_API_KEY = settings.OPENAI_API_KEY ?? '';
-		OPENAI_API_BASE_URL = settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1';
-
 		titleAutoGenerate = settings.titleAutoGenerate ?? true;
 		speechAutoSend = settings.speechAutoSend ?? false;
 		responseAutoCopy = settings.responseAutoCopy ?? false;
@@ -709,31 +692,31 @@
 						</div>
 						<div class=" self-center">Models</div>
 					</button>
-				{/if}
 
-				<button
-					class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
-					'external'
-						? 'bg-gray-200 dark:bg-gray-700'
-						: ' hover:bg-gray-300 dark:hover:bg-gray-800'}"
-					on:click={() => {
-						selectedTab = 'external';
-					}}
-				>
-					<div class=" self-center mr-2">
-						<svg
-							xmlns="http://www.w3.org/2000/svg"
-							viewBox="0 0 16 16"
-							fill="currentColor"
-							class="w-4 h-4"
-						>
-							<path
-								d="M1 9.5A3.5 3.5 0 0 0 4.5 13H12a3 3 0 0 0 .917-5.857 2.503 2.503 0 0 0-3.198-3.019 3.5 3.5 0 0 0-6.628 2.171A3.5 3.5 0 0 0 1 9.5Z"
-							/>
-						</svg>
-					</div>
-					<div class=" self-center">External</div>
-				</button>
+					<button
+						class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
+						'external'
+							? 'bg-gray-200 dark:bg-gray-700'
+							: ' hover:bg-gray-300 dark:hover:bg-gray-800'}"
+						on:click={() => {
+							selectedTab = 'external';
+						}}
+					>
+						<div class=" self-center mr-2">
+							<svg
+								xmlns="http://www.w3.org/2000/svg"
+								viewBox="0 0 16 16"
+								fill="currentColor"
+								class="w-4 h-4"
+							>
+								<path
+									d="M1 9.5A3.5 3.5 0 0 0 4.5 13H12a3 3 0 0 0 .917-5.857 2.503 2.503 0 0 0-3.198-3.019 3.5 3.5 0 0 0-6.628 2.171A3.5 3.5 0 0 0 1 9.5Z"
+								/>
+							</svg>
+						</div>
+						<div class=" self-center">External</div>
+					</button>
+				{/if}
 
 				<button
 					class="px-2.5 py-2.5 min-w-fit rounded-lg flex-1 md:flex-none flex text-right transition {selectedTab ===
@@ -1415,10 +1398,12 @@
 					<form
 						class="flex flex-col h-full justify-between space-y-3 text-sm"
 						on:submit|preventDefault={() => {
-							saveSettings({
-								OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
-								OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
-							});
+							updateOpenAIHandler();
+
+							// saveSettings({
+							// 	OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
+							// 	OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
+							// });
 							show = false;
 						}}
 					>

+ 3 - 4
src/lib/constants.ts

@@ -1,11 +1,10 @@
 import { dev } from '$app/environment';
 
-export const OLLAMA_API_BASE_URL = dev
-	? `http://${location.hostname}:8080/ollama/api`
-	: '/ollama/api';
-
 export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
+
 export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
+export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
+export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
 
 export const WEB_UI_VERSION = 'v1.0.0-alpha-static';
 

+ 10 - 12
src/routes/(app)/+layout.svelte

@@ -37,19 +37,17 @@
 				return [];
 			}))
 		);
-		// If OpenAI API Key exists
-		if ($settings.OPENAI_API_KEY) {
-			const openAIModels = await getOpenAIModels(
-				$settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
-				$settings.OPENAI_API_KEY
-			).catch((error) => {
-				console.log(error);
-				toast.error(error);
-				return null;
-			});
 
-			models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
-		}
+		// $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
+		// 		$settings.OPENAI_API_KEY
+
+		const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
+			console.log(error);
+			return null;
+		});
+
+		models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
+
 		return models;
 	};
 

+ 145 - 161
src/routes/(app)/+page.svelte

@@ -16,6 +16,7 @@
 	import ModelSelector from '$lib/components/chat/ModelSelector.svelte';
 	import Navbar from '$lib/components/layout/Navbar.svelte';
 	import { createNewChat, getChatList, updateChatById } from '$lib/apis/chats';
+	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
 
 	let stopResponseFlag = false;
 	let autoScroll = true;
@@ -321,188 +322,171 @@
 	};
 
 	const sendPromptOpenAI = async (model, userPrompt, parentId, _chatId) => {
-		if ($settings.OPENAI_API_KEY) {
-			if (models) {
-				let responseMessageId = uuidv4();
-
-				let responseMessage = {
-					parentId: parentId,
-					id: responseMessageId,
-					childrenIds: [],
-					role: 'assistant',
-					content: '',
-					model: model
-				};
-
-				history.messages[responseMessageId] = responseMessage;
-				history.currentId = responseMessageId;
-				if (parentId !== null) {
-					history.messages[parentId].childrenIds = [
-						...history.messages[parentId].childrenIds,
-						responseMessageId
-					];
-				}
+		let responseMessageId = uuidv4();
 
-				window.scrollTo({ top: document.body.scrollHeight });
-
-				const res = await fetch(
-					`${$settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1'}/chat/completions`,
-					{
-						method: 'POST',
-						headers: {
-							Authorization: `Bearer ${$settings.OPENAI_API_KEY}`,
-							'Content-Type': 'application/json'
-						},
-						body: JSON.stringify({
-							model: model,
-							stream: true,
-							messages: [
-								$settings.system
-									? {
-											role: 'system',
-											content: $settings.system
-									  }
-									: undefined,
-								...messages
-							]
-								.filter((message) => message)
-								.map((message) => ({
-									role: message.role,
-									...(message.files
-										? {
-												content: [
-													{
-														type: 'text',
-														text: message.content
-													},
-													...message.files
-														.filter((file) => file.type === 'image')
-														.map((file) => ({
-															type: 'image_url',
-															image_url: {
-																url: file.url
-															}
-														}))
-												]
-										  }
-										: { content: message.content })
-								})),
-							seed: $settings?.options?.seed ?? undefined,
-							stop: $settings?.options?.stop ?? undefined,
-							temperature: $settings?.options?.temperature ?? undefined,
-							top_p: $settings?.options?.top_p ?? undefined,
-							num_ctx: $settings?.options?.num_ctx ?? undefined,
-							frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-							max_tokens: $settings?.options?.num_predict ?? undefined
-						})
-					}
-				).catch((err) => {
-					console.log(err);
-					return null;
-				});
+		let responseMessage = {
+			parentId: parentId,
+			id: responseMessageId,
+			childrenIds: [],
+			role: 'assistant',
+			content: '',
+			model: model
+		};
 
-				if (res && res.ok) {
-					const reader = res.body
-						.pipeThrough(new TextDecoderStream())
-						.pipeThrough(splitStream('\n'))
-						.getReader();
-
-					while (true) {
-						const { value, done } = await reader.read();
-						if (done || stopResponseFlag || _chatId !== $chatId) {
-							responseMessage.done = true;
-							messages = messages;
-							break;
-						}
+		history.messages[responseMessageId] = responseMessage;
+		history.currentId = responseMessageId;
+		if (parentId !== null) {
+			history.messages[parentId].childrenIds = [
+				...history.messages[parentId].childrenIds,
+				responseMessageId
+			];
+		}
 
-						try {
-							let lines = value.split('\n');
-
-							for (const line of lines) {
-								if (line !== '') {
-									console.log(line);
-									if (line === 'data: [DONE]') {
-										responseMessage.done = true;
-										messages = messages;
-									} else {
-										let data = JSON.parse(line.replace(/^data: /, ''));
-										console.log(data);
-
-										if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
-											continue;
-										} else {
-											responseMessage.content += data.choices[0].delta.content ?? '';
-											messages = messages;
-										}
-									}
-								}
-							}
-						} catch (error) {
-							console.log(error);
-						}
+		window.scrollTo({ top: document.body.scrollHeight });
 
-						if ($settings.notificationEnabled && !document.hasFocus()) {
-							const notification = new Notification(`OpenAI ${model}`, {
-								body: responseMessage.content,
-								icon: '/favicon.png'
-							});
-						}
+		const res = await generateOpenAIChatCompletion(localStorage.token, {
+			model: model,
+			stream: true,
+			messages: [
+				$settings.system
+					? {
+							role: 'system',
+							content: $settings.system
+					  }
+					: undefined,
+				...messages
+			]
+				.filter((message) => message)
+				.map((message) => ({
+					role: message.role,
+					...(message.files
+						? {
+								content: [
+									{
+										type: 'text',
+										text: message.content
+									},
+									...message.files
+										.filter((file) => file.type === 'image')
+										.map((file) => ({
+											type: 'image_url',
+											image_url: {
+												url: file.url
+											}
+										}))
+								]
+						  }
+						: { content: message.content })
+				})),
+			seed: $settings?.options?.seed ?? undefined,
+			stop: $settings?.options?.stop ?? undefined,
+			temperature: $settings?.options?.temperature ?? undefined,
+			top_p: $settings?.options?.top_p ?? undefined,
+			num_ctx: $settings?.options?.num_ctx ?? undefined,
+			frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
+			max_tokens: $settings?.options?.num_predict ?? undefined
+		});
 
-						if ($settings.responseAutoCopy) {
-							copyToClipboard(responseMessage.content);
-						}
+		if (res && res.ok) {
+			const reader = res.body
+				.pipeThrough(new TextDecoderStream())
+				.pipeThrough(splitStream('\n'))
+				.getReader();
 
-						if (autoScroll) {
-							window.scrollTo({ top: document.body.scrollHeight });
-						}
-					}
+			while (true) {
+				const { value, done } = await reader.read();
+				if (done || stopResponseFlag || _chatId !== $chatId) {
+					responseMessage.done = true;
+					messages = messages;
+					break;
+				}
 
-					if ($chatId == _chatId) {
-						chat = await updateChatById(localStorage.token, _chatId, {
-							messages: messages,
-							history: history
-						});
-						await chats.set(await getChatList(localStorage.token));
-					}
-				} else {
-					if (res !== null) {
-						const error = await res.json();
-						console.log(error);
-						if ('detail' in error) {
-							toast.error(error.detail);
-							responseMessage.content = error.detail;
-						} else {
-							if ('message' in error.error) {
-								toast.error(error.error.message);
-								responseMessage.content = error.error.message;
+				try {
+					let lines = value.split('\n');
+
+					for (const line of lines) {
+						if (line !== '') {
+							console.log(line);
+							if (line === 'data: [DONE]') {
+								responseMessage.done = true;
+								messages = messages;
 							} else {
-								toast.error(error.error);
-								responseMessage.content = error.error;
+								let data = JSON.parse(line.replace(/^data: /, ''));
+								console.log(data);
+
+								if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
+									continue;
+								} else {
+									responseMessage.content += data.choices[0].delta.content ?? '';
+									messages = messages;
+								}
 							}
 						}
-					} else {
-						toast.error(`Uh-oh! There was an issue connecting to ${model}.`);
-						responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
 					}
+				} catch (error) {
+					console.log(error);
+				}
 
-					responseMessage.error = true;
-					responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
-					responseMessage.done = true;
-					messages = messages;
+				if ($settings.notificationEnabled && !document.hasFocus()) {
+					const notification = new Notification(`OpenAI ${model}`, {
+						body: responseMessage.content,
+						icon: '/favicon.png'
+					});
 				}
 
-				stopResponseFlag = false;
-				await tick();
+				if ($settings.responseAutoCopy) {
+					copyToClipboard(responseMessage.content);
+				}
 
 				if (autoScroll) {
 					window.scrollTo({ top: document.body.scrollHeight });
 				}
+			}
 
-				if (messages.length == 2) {
-					window.history.replaceState(history.state, '', `/c/${_chatId}`);
-					await setChatTitle(_chatId, userPrompt);
+			if ($chatId == _chatId) {
+				chat = await updateChatById(localStorage.token, _chatId, {
+					messages: messages,
+					history: history
+				});
+				await chats.set(await getChatList(localStorage.token));
+			}
+		} else {
+			if (res !== null) {
+				const error = await res.json();
+				console.log(error);
+				if ('detail' in error) {
+					toast.error(error.detail);
+					responseMessage.content = error.detail;
+				} else {
+					if ('message' in error.error) {
+						toast.error(error.error.message);
+						responseMessage.content = error.error.message;
+					} else {
+						toast.error(error.error);
+						responseMessage.content = error.error;
+					}
 				}
+			} else {
+				toast.error(`Uh-oh! There was an issue connecting to ${model}.`);
+				responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
 			}
+
+			responseMessage.error = true;
+			responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
+			responseMessage.done = true;
+			messages = messages;
+		}
+
+		stopResponseFlag = false;
+		await tick();
+
+		if (autoScroll) {
+			window.scrollTo({ top: document.body.scrollHeight });
+		}
+
+		if (messages.length == 2) {
+			window.history.replaceState(history.state, '', `/c/${_chatId}`);
+			await setChatTitle(_chatId, userPrompt);
 		}
 	};
 

+ 146 - 161
src/routes/(app)/c/[id]/+page.svelte

@@ -9,6 +9,8 @@
 	import { models, modelfiles, user, settings, chats, chatId } from '$lib/stores';
 
 	import { generateChatCompletion, generateTitle } from '$lib/apis/ollama';
+	import { generateOpenAIChatCompletion } from '$lib/apis/openai';
+
 	import { copyToClipboard, splitStream } from '$lib/utils';
 
 	import MessageInput from '$lib/components/chat/MessageInput.svelte';
@@ -338,188 +340,171 @@
 	};
 
 	const sendPromptOpenAI = async (model, userPrompt, parentId, _chatId) => {
-		if ($settings.OPENAI_API_KEY) {
-			if (models) {
-				let responseMessageId = uuidv4();
-
-				let responseMessage = {
-					parentId: parentId,
-					id: responseMessageId,
-					childrenIds: [],
-					role: 'assistant',
-					content: '',
-					model: model
-				};
-
-				history.messages[responseMessageId] = responseMessage;
-				history.currentId = responseMessageId;
-				if (parentId !== null) {
-					history.messages[parentId].childrenIds = [
-						...history.messages[parentId].childrenIds,
-						responseMessageId
-					];
-				}
+		let responseMessageId = uuidv4();
 
-				window.scrollTo({ top: document.body.scrollHeight });
-
-				const res = await fetch(
-					`${$settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1'}/chat/completions`,
-					{
-						method: 'POST',
-						headers: {
-							Authorization: `Bearer ${$settings.OPENAI_API_KEY}`,
-							'Content-Type': 'application/json'
-						},
-						body: JSON.stringify({
-							model: model,
-							stream: true,
-							messages: [
-								$settings.system
-									? {
-											role: 'system',
-											content: $settings.system
-									  }
-									: undefined,
-								...messages
-							]
-								.filter((message) => message)
-								.map((message) => ({
-									role: message.role,
-									...(message.files
-										? {
-												content: [
-													{
-														type: 'text',
-														text: message.content
-													},
-													...message.files
-														.filter((file) => file.type === 'image')
-														.map((file) => ({
-															type: 'image_url',
-															image_url: {
-																url: file.url
-															}
-														}))
-												]
-										  }
-										: { content: message.content })
-								})),
-							seed: $settings?.options?.seed ?? undefined,
-							stop: $settings?.options?.stop ?? undefined,
-							temperature: $settings?.options?.temperature ?? undefined,
-							top_p: $settings?.options?.top_p ?? undefined,
-							num_ctx: $settings?.options?.num_ctx ?? undefined,
-							frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
-							max_tokens: $settings?.options?.num_predict ?? undefined
-						})
-					}
-				).catch((err) => {
-					console.log(err);
-					return null;
-				});
+		let responseMessage = {
+			parentId: parentId,
+			id: responseMessageId,
+			childrenIds: [],
+			role: 'assistant',
+			content: '',
+			model: model
+		};
 
-				if (res && res.ok) {
-					const reader = res.body
-						.pipeThrough(new TextDecoderStream())
-						.pipeThrough(splitStream('\n'))
-						.getReader();
-
-					while (true) {
-						const { value, done } = await reader.read();
-						if (done || stopResponseFlag || _chatId !== $chatId) {
-							responseMessage.done = true;
-							messages = messages;
-							break;
-						}
+		history.messages[responseMessageId] = responseMessage;
+		history.currentId = responseMessageId;
+		if (parentId !== null) {
+			history.messages[parentId].childrenIds = [
+				...history.messages[parentId].childrenIds,
+				responseMessageId
+			];
+		}
 
-						try {
-							let lines = value.split('\n');
-
-							for (const line of lines) {
-								if (line !== '') {
-									console.log(line);
-									if (line === 'data: [DONE]') {
-										responseMessage.done = true;
-										messages = messages;
-									} else {
-										let data = JSON.parse(line.replace(/^data: /, ''));
-										console.log(data);
-
-										if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
-											continue;
-										} else {
-											responseMessage.content += data.choices[0].delta.content ?? '';
-											messages = messages;
-										}
-									}
-								}
-							}
-						} catch (error) {
-							console.log(error);
-						}
+		window.scrollTo({ top: document.body.scrollHeight });
 
-						if ($settings.notificationEnabled && !document.hasFocus()) {
-							const notification = new Notification(`OpenAI ${model}`, {
-								body: responseMessage.content,
-								icon: '/favicon.png'
-							});
-						}
+		const res = await generateOpenAIChatCompletion(localStorage.token, {
+			model: model,
+			stream: true,
+			messages: [
+				$settings.system
+					? {
+							role: 'system',
+							content: $settings.system
+					  }
+					: undefined,
+				...messages
+			]
+				.filter((message) => message)
+				.map((message) => ({
+					role: message.role,
+					...(message.files
+						? {
+								content: [
+									{
+										type: 'text',
+										text: message.content
+									},
+									...message.files
+										.filter((file) => file.type === 'image')
+										.map((file) => ({
+											type: 'image_url',
+											image_url: {
+												url: file.url
+											}
+										}))
+								]
+						  }
+						: { content: message.content })
+				})),
+			seed: $settings?.options?.seed ?? undefined,
+			stop: $settings?.options?.stop ?? undefined,
+			temperature: $settings?.options?.temperature ?? undefined,
+			top_p: $settings?.options?.top_p ?? undefined,
+			num_ctx: $settings?.options?.num_ctx ?? undefined,
+			frequency_penalty: $settings?.options?.repeat_penalty ?? undefined,
+			max_tokens: $settings?.options?.num_predict ?? undefined
+		});
 
-						if ($settings.responseAutoCopy) {
-							copyToClipboard(responseMessage.content);
-						}
+		if (res && res.ok) {
+			const reader = res.body
+				.pipeThrough(new TextDecoderStream())
+				.pipeThrough(splitStream('\n'))
+				.getReader();
 
-						if (autoScroll) {
-							window.scrollTo({ top: document.body.scrollHeight });
-						}
-					}
+			while (true) {
+				const { value, done } = await reader.read();
+				if (done || stopResponseFlag || _chatId !== $chatId) {
+					responseMessage.done = true;
+					messages = messages;
+					break;
+				}
 
-					if ($chatId == _chatId) {
-						chat = await updateChatById(localStorage.token, _chatId, {
-							messages: messages,
-							history: history
-						});
-						await chats.set(await getChatList(localStorage.token));
-					}
-				} else {
-					if (res !== null) {
-						const error = await res.json();
-						console.log(error);
-						if ('detail' in error) {
-							toast.error(error.detail);
-							responseMessage.content = error.detail;
-						} else {
-							if ('message' in error.error) {
-								toast.error(error.error.message);
-								responseMessage.content = error.error.message;
+				try {
+					let lines = value.split('\n');
+
+					for (const line of lines) {
+						if (line !== '') {
+							console.log(line);
+							if (line === 'data: [DONE]') {
+								responseMessage.done = true;
+								messages = messages;
 							} else {
-								toast.error(error.error);
-								responseMessage.content = error.error;
+								let data = JSON.parse(line.replace(/^data: /, ''));
+								console.log(data);
+
+								if (responseMessage.content == '' && data.choices[0].delta.content == '\n') {
+									continue;
+								} else {
+									responseMessage.content += data.choices[0].delta.content ?? '';
+									messages = messages;
+								}
 							}
 						}
-					} else {
-						toast.error(`Uh-oh! There was an issue connecting to ${model}.`);
-						responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
 					}
+				} catch (error) {
+					console.log(error);
+				}
 
-					responseMessage.error = true;
-					responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
-					responseMessage.done = true;
-					messages = messages;
+				if ($settings.notificationEnabled && !document.hasFocus()) {
+					const notification = new Notification(`OpenAI ${model}`, {
+						body: responseMessage.content,
+						icon: '/favicon.png'
+					});
 				}
 
-				stopResponseFlag = false;
-				await tick();
+				if ($settings.responseAutoCopy) {
+					copyToClipboard(responseMessage.content);
+				}
 
 				if (autoScroll) {
 					window.scrollTo({ top: document.body.scrollHeight });
 				}
+			}
 
-				if (messages.length == 2) {
-					window.history.replaceState(history.state, '', `/c/${_chatId}`);
-					await setChatTitle(_chatId, userPrompt);
+			if ($chatId == _chatId) {
+				chat = await updateChatById(localStorage.token, _chatId, {
+					messages: messages,
+					history: history
+				});
+				await chats.set(await getChatList(localStorage.token));
+			}
+		} else {
+			if (res !== null) {
+				const error = await res.json();
+				console.log(error);
+				if ('detail' in error) {
+					toast.error(error.detail);
+					responseMessage.content = error.detail;
+				} else {
+					if ('message' in error.error) {
+						toast.error(error.error.message);
+						responseMessage.content = error.error.message;
+					} else {
+						toast.error(error.error);
+						responseMessage.content = error.error;
+					}
 				}
+			} else {
+				toast.error(`Uh-oh! There was an issue connecting to ${model}.`);
+				responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
 			}
+
+			responseMessage.error = true;
+			responseMessage.content = `Uh-oh! There was an issue connecting to ${model}.`;
+			responseMessage.done = true;
+			messages = messages;
+		}
+
+		stopResponseFlag = false;
+		await tick();
+
+		if (autoScroll) {
+			window.scrollTo({ top: document.body.scrollHeight });
+		}
+
+		if (messages.length == 2) {
+			window.history.replaceState(history.state, '', `/c/${_chatId}`);
+			await setChatTitle(_chatId, userPrompt);
 		}
 	};