Browse Source

feat: OpenAI-compatible API support

Timothy J. Baek 1 year ago
parent
commit
17c66fde0f

+ 34 - 13
backend/apps/openai/main.py

@@ -1,6 +1,6 @@
 from fastapi import FastAPI, Request, Response, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
+from fastapi.responses import StreamingResponse, JSONResponse
 
 import requests
 import json
@@ -69,18 +69,18 @@ async def update_openai_key(form_data: KeyUpdateForm, user=Depends(get_current_u
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
     target_url = f"{app.state.OPENAI_API_BASE_URL}/{path}"
-
-    body = await request.body()
-    headers = dict(request.headers)
+    print(target_url, app.state.OPENAI_API_KEY)
 
     if user.role not in ["user", "admin"]:
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
+    if app.state.OPENAI_API_KEY == "":
+        raise HTTPException(status_code=401, detail=ERROR_MESSAGES.API_KEY_NOT_FOUND)
 
-    headers.pop("Host", None)
-    headers.pop("Authorization", None)
-    headers.pop("Origin", None)
-    headers.pop("Referer", None)
+    body = await request.body()
+    # headers = dict(request.headers)
+    # print(headers)
 
+    headers = {}
     headers["Authorization"] = f"Bearer {app.state.OPENAI_API_KEY}"
 
     try:
@@ -94,11 +94,32 @@ async def proxy(path: str, request: Request, user=Depends(get_current_user)):
 
         r.raise_for_status()
 
-        return StreamingResponse(
-            r.iter_content(chunk_size=8192),
-            status_code=r.status_code,
-            headers=dict(r.headers),
-        )
+        # Check if response is SSE
+        if "text/event-stream" in r.headers.get("Content-Type", ""):
+            return StreamingResponse(
+                r.iter_content(chunk_size=8192),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        else:
+            # For non-SSE, read the response and return it
+            # response_data = (
+            #     r.json()
+            #     if r.headers.get("Content-Type", "")
+            #     == "application/json"
+            #     else r.text
+            # )
+
+            response_data = r.json()
+
+            print(type(response_data))
+
+            if "openai" in app.state.OPENAI_API_BASE_URL and path == "models":
+                response_data["data"] = list(
+                    filter(lambda model: "gpt" in model["id"], response_data["data"])
+                )
+
+            return response_data
     except Exception as e:
         print(e)
         error_detail = "Ollama WebUI: Server Connection Error"

+ 4 - 1
backend/config.py

@@ -33,7 +33,10 @@ if ENV == "prod":
 ####################################
 
 OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
-OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "https://api.openai.com/v1")
+OPENAI_API_BASE_URL = os.environ.get("OPENAI_API_BASE_URL", "")
+
+if OPENAI_API_BASE_URL == "":
+    OPENAI_API_BASE_URL = "https://api.openai.com/v1"
 
 ####################################
 # WEBUI_VERSION

+ 1 - 0
backend/constants.py

@@ -33,4 +33,5 @@ class ERROR_MESSAGES(str, Enum):
     )
     NOT_FOUND = "We could not find what you're looking for :/"
     USER_NOT_FOUND = "We could not find what you're looking for :/"
+    API_KEY_NOT_FOUND = "Oops! It looks like there's a hiccup. The API key is missing. Please make sure to provide a valid API key to access this feature."
     MALICIOUS = "Unusual activities detected, please try again in a few minutes."

+ 3 - 1
backend/main.py

@@ -6,6 +6,8 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.exceptions import HTTPException as StarletteHTTPException
 
 from apps.ollama.main import app as ollama_app
+from apps.openai.main import app as openai_app
+
 from apps.web.main import app as webui_app
 
 import time
@@ -46,7 +48,7 @@ async def check_url(request: Request, call_next):
 
 
 app.mount("/api/v1", webui_app)
-# app.mount("/ollama/api", WSGIMiddleware(ollama_app))
 app.mount("/ollama/api", ollama_app)
+app.mount("/openai/api", openai_app)
 
 app.mount("/", SPAStaticFiles(directory="../build", html=True), name="spa-static-files")

+ 3 - 9
example.env

@@ -1,12 +1,6 @@
-# If you're serving both the frontend and backend (Recommended)
-# Set the public API base URL for seamless communication
-PUBLIC_API_BASE_URL='/ollama/api'
-
-# If you're serving only the frontend (Not recommended and not fully supported)
-# Comment above and Uncomment below
-# You can use the default value or specify a custom path, e.g., '/api'
-# PUBLIC_API_BASE_URL='http://{location.hostname}:11434/api'
-
 # Ollama URL for the backend to connect
 # The path '/ollama/api' will be redirected to the specified backend URL
 OLLAMA_API_BASE_URL='http://localhost:11434/api'
+
+OPENAI_API_BASE_URL=''
+OPENAI_API_KEY=''

+ 173 - 1
src/lib/apis/openai/index.ts

@@ -1,4 +1,176 @@
-export const getOpenAIModels = async (
+import { OPENAI_API_BASE_URL } from '$lib/constants';
+
+export const getOpenAIUrl = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/url`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_BASE_URL;
+};
+
+export const updateOpenAIUrl = async (token: string = '', url: string) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/url/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			url: url
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_BASE_URL;
+};
+
+export const getOpenAIKey = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/key`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
+};
+
+export const updateOpenAIKey = async (token: string = '', key: string) => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/key/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		},
+		body: JSON.stringify({
+			key: key
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			if ('detail' in err) {
+				error = err.detail;
+			} else {
+				error = 'Server connection failed';
+			}
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.OPENAI_API_KEY;
+};
+
+export const getOpenAIModels = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${OPENAI_API_BASE_URL}/models`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			...(token && { authorization: `Bearer ${token}` })
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			console.log(err);
+			error = `OpenAI: ${err?.error?.message ?? 'Network Problem'}`;
+			return [];
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	const models = Array.isArray(res) ? res : res?.data ?? null;
+
+	return models
+		? models
+				.map((model) => ({ name: model.id, external: true }))
+				.sort((a, b) => {
+					return a.name.localeCompare(b.name);
+				})
+		: models;
+};
+
+export const getOpenAIModelsDirect = async (
 	base_url: string = 'https://api.openai.com/v1',
 	api_key: string = ''
 ) => {

+ 29 - 44
src/lib/components/chat/SettingsModal.svelte

@@ -24,6 +24,13 @@
 	import { updateUserPassword } from '$lib/apis/auths';
 	import { goto } from '$app/navigation';
 	import Page from '../../../routes/(app)/+page.svelte';
+	import {
+		getOpenAIKey,
+		getOpenAIModels,
+		getOpenAIUrl,
+		updateOpenAIKey,
+		updateOpenAIUrl
+	} from '$lib/apis/openai';
 
 	export let show = false;
 
@@ -153,6 +160,13 @@
 		}
 	};
 
+	const updateOpenAIHandler = async () => {
+		OPENAI_API_BASE_URL = await updateOpenAIUrl(localStorage.token, OPENAI_API_BASE_URL);
+		OPENAI_API_KEY = await updateOpenAIKey(localStorage.token, OPENAI_API_KEY);
+
+		await models.set(await getModels());
+	};
+
 	const toggleTheme = async () => {
 		if (theme === 'dark') {
 			theme = 'light';
@@ -484,7 +498,7 @@
 	};
 
 	const getModels = async (type = 'all') => {
-		let models = [];
+		const models = [];
 		models.push(
 			...(await getOllamaModels(localStorage.token).catch((error) => {
 				toast.error(error);
@@ -493,43 +507,13 @@
 		);
 
 		// If OpenAI API Key exists
-		if (type === 'all' && $settings.OPENAI_API_KEY) {
-			const OPENAI_API_BASE_URL = $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1';
+		if (type === 'all' && OPENAI_API_KEY) {
+			const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
+				console.log(error);
+				return null;
+			});
 
-			// Validate OPENAI_API_KEY
-			const openaiModelRes = await fetch(`${OPENAI_API_BASE_URL}/models`, {
-				method: 'GET',
-				headers: {
-					'Content-Type': 'application/json',
-					Authorization: `Bearer ${$settings.OPENAI_API_KEY}`
-				}
-			})
-				.then(async (res) => {
-					if (!res.ok) throw await res.json();
-					return res.json();
-				})
-				.catch((error) => {
-					console.log(error);
-					toast.error(`OpenAI: ${error?.error?.message ?? 'Network Problem'}`);
-					return null;
-				});
-
-			const openAIModels = Array.isArray(openaiModelRes)
-				? openaiModelRes
-				: openaiModelRes?.data ?? null;
-
-			models.push(
-				...(openAIModels
-					? [
-							{ name: 'hr' },
-							...openAIModels
-								.map((model) => ({ name: model.id, external: true }))
-								.filter((model) =>
-									OPENAI_API_BASE_URL.includes('openai') ? model.name.includes('gpt') : true
-								)
-					  ]
-					: [])
-			);
+			models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
 		}
 
 		return models;
@@ -564,6 +548,8 @@
 		console.log('settings', $user.role === 'admin');
 		if ($user.role === 'admin') {
 			API_BASE_URL = await getOllamaAPIUrl(localStorage.token);
+			OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token);
+			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
 		}
 
 		let settings = JSON.parse(localStorage.getItem('settings') ?? '{}');
@@ -584,9 +570,6 @@
 		options = { ...options, ...settings.options };
 		options.stop = (settings?.options?.stop ?? []).join(',');
 
-		OPENAI_API_KEY = settings.OPENAI_API_KEY ?? '';
-		OPENAI_API_BASE_URL = settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1';
-
 		titleAutoGenerate = settings.titleAutoGenerate ?? true;
 		speechAutoSend = settings.speechAutoSend ?? false;
 		responseAutoCopy = settings.responseAutoCopy ?? false;
@@ -1415,10 +1398,12 @@
 					<form
 						class="flex flex-col h-full justify-between space-y-3 text-sm"
 						on:submit|preventDefault={() => {
-							saveSettings({
-								OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
-								OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
-							});
+							updateOpenAIHandler();
+
+							// saveSettings({
+							// 	OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
+							// 	OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
+							// });
 							show = false;
 						}}
 					>

+ 3 - 4
src/lib/constants.ts

@@ -1,11 +1,10 @@
 import { dev } from '$app/environment';
 
-export const OLLAMA_API_BASE_URL = dev
-	? `http://${location.hostname}:8080/ollama/api`
-	: '/ollama/api';
-
 export const WEBUI_BASE_URL = dev ? `http://${location.hostname}:8080` : ``;
+
 export const WEBUI_API_BASE_URL = `${WEBUI_BASE_URL}/api/v1`;
+export const OLLAMA_API_BASE_URL = `${WEBUI_BASE_URL}/ollama/api`;
+export const OPENAI_API_BASE_URL = `${WEBUI_BASE_URL}/openai/api`;
 
 export const WEB_UI_VERSION = 'v1.0.0-alpha-static';
 

+ 10 - 12
src/routes/(app)/+layout.svelte

@@ -37,19 +37,17 @@
 				return [];
 			}))
 		);
-		// If OpenAI API Key exists
-		if ($settings.OPENAI_API_KEY) {
-			const openAIModels = await getOpenAIModels(
-				$settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
-				$settings.OPENAI_API_KEY
-			).catch((error) => {
-				console.log(error);
-				toast.error(error);
-				return null;
-			});
 
-			models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
-		}
+		// $settings.OPENAI_API_BASE_URL ?? 'https://api.openai.com/v1',
+		// 		$settings.OPENAI_API_KEY
+
+		const openAIModels = await getOpenAIModels(localStorage.token).catch((error) => {
+			console.log(error);
+			return null;
+		});
+
+		models.push(...(openAIModels ? [{ name: 'hr' }, ...openAIModels] : []));
+
 		return models;
 	};