Browse Source

feat: modelfile backend & ollama version 0.0.0 whitelisted

Timothy J. Baek 1 year ago
parent
commit
032d7c7440

+ 10 - 2
backend/apps/web/models/modelfiles.py

@@ -42,6 +42,14 @@ class ModelfileForm(BaseModel):
     modelfile: dict
 
 
+class ModelfileTagNameForm(BaseModel):
+    tag_name: str
+
+
+class ModelfileUpdateForm(ModelfileForm, ModelfileTagNameForm):
+    pass
+
+
 class ModelfileResponse(BaseModel):
     tag_name: str
     user_id: str
@@ -57,11 +65,11 @@ class ModelfilesTable:
     def insert_new_modelfile(
         self, user_id: str, form_data: ModelfileForm
     ) -> Optional[ModelfileModel]:
-        if "title" in form_data.modelfile:
+        if "tagName" in form_data.modelfile:
             modelfile = ModelfileModel(
                 **{
                     "user_id": user_id,
-                    "tag_name": form_data.modelfile["title"],
+                    "tag_name": form_data.modelfile["tagName"],
                     "modelfile": json.dumps(form_data.modelfile),
                     "timestamp": int(time.time()),
                 }

+ 16 - 10
backend/apps/web/routers/modelfiles.py

@@ -11,6 +11,8 @@ from apps.web.models.users import Users
 from apps.web.models.modelfiles import (
     Modelfiles,
     ModelfileForm,
+    ModelfileTagNameForm,
+    ModelfileUpdateForm,
     ModelfileResponse,
 )
 
@@ -77,13 +79,15 @@ async def create_new_modelfile(form_data: ModelfileForm, cred=Depends(bearer_sch
 ############################
 
 
-@router.get("/{tag_name}", response_model=Optional[ModelfileResponse])
-async def get_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
+@router.post("/", response_model=Optional[ModelfileResponse])
+async def get_modelfile_by_tag_name(
+    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+):
     token = cred.credentials
     user = Users.get_user_by_token(token)
 
     if user:
-        modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name)
+        modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
 
         if modelfile:
             return ModelfileResponse(
@@ -109,16 +113,16 @@ async def get_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
 ############################
 
 
-@router.post("/{tag_name}", response_model=Optional[ModelfileResponse])
+@router.post("/update", response_model=Optional[ModelfileResponse])
 async def update_modelfile_by_tag_name(
-    tag_name: str, form_data: ModelfileForm, cred=Depends(bearer_scheme)
+    form_data: ModelfileUpdateForm, cred=Depends(bearer_scheme)
 ):
     token = cred.credentials
     user = Users.get_user_by_token(token)
 
     if user:
         if user.role == "admin":
-            modelfile = Modelfiles.get_modelfile_by_tag_name(tag_name)
+            modelfile = Modelfiles.get_modelfile_by_tag_name(form_data.tag_name)
             if modelfile:
                 updated_modelfile = {
                     **json.loads(modelfile.modelfile),
@@ -126,7 +130,7 @@ async def update_modelfile_by_tag_name(
                 }
 
                 modelfile = Modelfiles.update_modelfile_by_tag_name(
-                    tag_name, updated_modelfile
+                    form_data.tag_name, updated_modelfile
                 )
 
                 return ModelfileResponse(
@@ -157,14 +161,16 @@ async def update_modelfile_by_tag_name(
 ############################
 
 
-@router.delete("/{tag_name}", response_model=bool)
-async def delete_modelfile_by_tag_name(tag_name: str, cred=Depends(bearer_scheme)):
+@router.delete("/delete", response_model=bool)
+async def delete_modelfile_by_tag_name(
+    form_data: ModelfileTagNameForm, cred=Depends(bearer_scheme)
+):
     token = cred.credentials
     user = Users.get_user_by_token(token)
 
     if user:
         if user.role == "admin":
-            result = Modelfiles.delete_modelfile_by_tag_name(tag_name)
+            result = Modelfiles.delete_modelfile_by_tag_name(form_data.tag_name)
             return result
         else:
             raise HTTPException(

+ 173 - 0
src/lib/apis/modelfiles/index.ts

@@ -0,0 +1,173 @@
+import { WEBUI_API_BASE_URL } from '$lib/constants';
+
+export const createNewModelfile = async (token: string, modelfile: object) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/create`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			modelfile: modelfile
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.catch((err) => {
+			error = err;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const getModelfiles = async (token: string = '') => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, {
+		method: 'GET',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		}
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.map((modelfile) => modelfile.modelfile);
+};
+
+export const getModelfileByTagName = async (token: string, tagName: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			tag_name: tagName
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res.modelfile;
+};
+
+export const updateModelfileByTagName = async (
+	token: string,
+	tagName: string,
+	modelfile: object
+) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/update`, {
+		method: 'POST',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			tag_name: tagName,
+			modelfile: modelfile
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const deleteModelfileByTagName = async (token: string, tagName: string) => {
+	let error = null;
+
+	const res = await fetch(`${WEBUI_API_BASE_URL}/modelfiles/delete`, {
+		method: 'DELETE',
+		headers: {
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
+			authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			tag_name: tagName
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			return json;
+		})
+		.catch((err) => {
+			error = err;
+
+			console.log(err);
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 68 - 0
src/lib/apis/ollama/index.ts

@@ -134,3 +134,71 @@ export const generateChatCompletion = async (
 
 	return res;
 };
+
+export const createModel = async (
+	base_url: string = OLLAMA_API_BASE_URL,
+	token: string,
+	tagName: string,
+	content: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${base_url}/create`, {
+		method: 'POST',
+		headers: {
+			'Content-Type': 'text/event-stream',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			name: tagName,
+			modelfile: content
+		})
+	}).catch((err) => {
+		error = err;
+		return null;
+	});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};
+
+export const deleteModel = async (
+	base_url: string = OLLAMA_API_BASE_URL,
+	token: string,
+	tagName: string
+) => {
+	let error = null;
+
+	const res = await fetch(`${base_url}/delete`, {
+		method: 'DELETE',
+		headers: {
+			'Content-Type': 'text/event-stream',
+			Authorization: `Bearer ${token}`
+		},
+		body: JSON.stringify({
+			name: tagName
+		})
+	})
+		.then(async (res) => {
+			if (!res.ok) throw await res.json();
+			return res.json();
+		})
+		.then((json) => {
+			console.log(json);
+			return true;
+		})
+		.catch((err) => {
+			console.log(err);
+			error = err.error;
+			return null;
+		});
+
+	if (error) {
+		throw error;
+	}
+
+	return res;
+};

+ 7 - 7
src/lib/utils/index.ts

@@ -102,11 +102,11 @@ export const copyToClipboard = (text) => {
 };
 
 export const checkVersion = (required, current) => {
-	return (
-		current.localeCompare(required, undefined, {
-			numeric: true,
-			sensitivity: 'case',
-			caseFirst: 'upper'
-		}) < 0
-	);
+	return current === '0.0.0'
+		? true
+		: current.localeCompare(required, undefined, {
+				numeric: true,
+				sensitivity: 'case',
+				caseFirst: 'upper'
+		  }) < 0;
 };

+ 9 - 3
src/routes/(app)/+layout.svelte

@@ -8,6 +8,8 @@
 	const { saveAs } = fileSaver;
 
 	import { getOllamaModels, getOllamaVersion } from '$lib/apis/ollama';
+	import { getModelfiles } from '$lib/apis/modelfiles';
+
 	import { getOpenAIModels } from '$lib/apis/openai';
 
 	import { user, showSettings, settings, models, modelfiles } from '$lib/stores';
@@ -95,11 +97,14 @@
 
 			console.log();
 			await settings.set(JSON.parse(localStorage.getItem('settings') ?? '{}'));
-			await models.set(await getModels());
+			// await models.set(await getModels());
+			// JSON.parse(localStorage.getItem('modelfiles') ?? '[]')
+			await modelfiles.set(await getModelfiles(localStorage.token));
+			console.log($modelfiles);
 
-			await modelfiles.set(JSON.parse(localStorage.getItem('modelfiles') ?? '[]'));
 			modelfiles.subscribe(async () => {
 				// should fetch models
+				await models.set(await getModels());
 			});
 
 			await setOllamaVersion();
@@ -176,7 +181,8 @@
 								<button
 									class="relative z-20 flex px-5 py-2 rounded-full bg-white border border-gray-100 dark:border-none hover:bg-gray-100 transition font-medium text-sm"
 									on:click={async () => {
-										await setOllamaVersion();
+										location.href = '/';
+										// await setOllamaVersion();
 									}}
 								>
 									Check Again

+ 16 - 30
src/routes/(app)/modelfiles/+page.svelte

@@ -4,43 +4,29 @@
 	import toast from 'svelte-french-toast';
 
 	import { OLLAMA_API_BASE_URL } from '$lib/constants';
+	import { deleteModel } from '$lib/apis/ollama';
+	import { deleteModelfileByTagName, getModelfiles } from '$lib/apis/modelfiles';
 
 	const deleteModelHandler = async (tagName) => {
 		let success = null;
-		const res = await fetch(`${$settings?.API_BASE_URL ?? OLLAMA_API_BASE_URL}/delete`, {
-			method: 'DELETE',
-			headers: {
-				'Content-Type': 'text/event-stream',
-				...($settings.authHeader && { Authorization: $settings.authHeader }),
-				...($user && { Authorization: `Bearer ${localStorage.token}` })
-			},
-			body: JSON.stringify({
-				name: tagName
-			})
-		})
-			.then(async (res) => {
-				if (!res.ok) throw await res.json();
-				return res.json();
-			})
-			.then((json) => {
-				console.log(json);
-				toast.success(`Deleted ${tagName}`);
-				success = true;
-				return json;
-			})
-			.catch((err) => {
-				console.log(err);
-				toast.error(err.error);
-				return null;
-			});
+
+		success = await deleteModel(
+			$settings?.API_BASE_URL ?? OLLAMA_API_BASE_URL,
+			localStorage.token,
+			tagName
+		);
+
+		if (success) {
+			toast.success(`Deleted ${tagName}`);
+		}
 
 		return success;
 	};
 
-	const deleteModelfilebyTagName = async (tagName) => {
+	const deleteModelfile = async (tagName) => {
 		await deleteModelHandler(tagName);
-		await modelfiles.set($modelfiles.filter((modelfile) => modelfile.tagName != tagName));
-		localStorage.setItem('modelfiles', JSON.stringify($modelfiles));
+		await deleteModelfileByTagName(localStorage.token, tagName);
+		await modelfiles.set(await getModelfiles(localStorage.token));
 	};
 
 	const shareModelfile = async (modelfile) => {
@@ -167,7 +153,7 @@
 							class="self-center w-fit text-sm px-2 py-2 border dark:border-gray-600 rounded-xl"
 							type="button"
 							on:click={() => {
-								deleteModelfilebyTagName(modelfile.tagName);
+								deleteModelfile(modelfile.tagName);
 							}}
 						>
 							<svg

+ 16 - 16
src/routes/(app)/modelfiles/create/+page.svelte

@@ -8,6 +8,8 @@
 	import Advanced from '$lib/components/chat/Settings/Advanced.svelte';
 	import { splitStream } from '$lib/utils';
 	import { onMount, tick } from 'svelte';
+	import { createModel } from '$lib/apis/ollama';
+	import { createNewModelfile, getModelfiles } from '$lib/apis/modelfiles';
 
 	let loading = false;
 
@@ -93,11 +95,14 @@ SYSTEM """${system}"""`.replace(/^\s*\n/gm, '');
 	};
 
 	const saveModelfile = async (modelfile) => {
-		await modelfiles.set([
-			...$modelfiles.filter((m) => m.tagName !== modelfile.tagName),
-			modelfile
-		]);
-		localStorage.setItem('modelfiles', JSON.stringify($modelfiles));
+		// await modelfiles.set([
+		// 	...$modelfiles.filter((m) => m.tagName !== modelfile.tagName),
+		// 	modelfile
+		// ]);
+		// localStorage.setItem('modelfiles', JSON.stringify($modelfiles));
+
+		await createNewModelfile(localStorage.token, modelfile);
+		await modelfiles.set(await getModelfiles(localStorage.token));
 	};
 
 	const submitHandler = async () => {
@@ -128,17 +133,12 @@ SYSTEM """${system}"""`.replace(/^\s*\n/gm, '');
 			Object.keys(categories).filter((category) => categories[category]).length > 0 &&
 			!$models.includes(tagName)
 		) {
-			const res = await fetch(`${$settings?.API_BASE_URL ?? OLLAMA_API_BASE_URL}/create`, {
-				method: 'POST',
-				headers: {
-					'Content-Type': 'text/event-stream',
-					...($user && { Authorization: `Bearer ${localStorage.token}` })
-				},
-				body: JSON.stringify({
-					name: tagName,
-					modelfile: content
-				})
-			});
+			const res = await createModel(
+				$settings?.API_BASE_URL ?? OLLAMA_API_BASE_URL,
+				localStorage.token,
+				tagName,
+				content
+			);
 
 			if (res) {
 				const reader = res.body

+ 31 - 28
src/routes/(app)/modelfiles/edit/+page.svelte

@@ -2,14 +2,20 @@
 	import { v4 as uuidv4 } from 'uuid';
 	import { toast } from 'svelte-french-toast';
 	import { goto } from '$app/navigation';
-	import { OLLAMA_API_BASE_URL } from '$lib/constants';
-	import { settings, db, user, config, modelfiles } from '$lib/stores';
 
-	import Advanced from '$lib/components/chat/Settings/Advanced.svelte';
-	import { splitStream } from '$lib/utils';
 	import { onMount } from 'svelte';
 	import { page } from '$app/stores';
 
+	import { settings, db, user, config, modelfiles } from '$lib/stores';
+
+	import { OLLAMA_API_BASE_URL } from '$lib/constants';
+	import { splitStream } from '$lib/utils';
+
+	import { createModel } from '$lib/apis/ollama';
+	import { getModelfiles, updateModelfileByTagName } from '$lib/apis/modelfiles';
+
+	import Advanced from '$lib/components/chat/Settings/Advanced.svelte';
+
 	let loading = false;
 
 	let filesInputElement;
@@ -78,17 +84,20 @@
 		}
 	});
 
-	const saveModelfile = async (modelfile) => {
-		await modelfiles.set(
-			$modelfiles.map((e) => {
-				if (e.tagName === modelfile.tagName) {
-					return modelfile;
-				} else {
-					return e;
-				}
-			})
-		);
-		localStorage.setItem('modelfiles', JSON.stringify($modelfiles));
+	const updateModelfile = async (modelfile) => {
+		// await modelfiles.set(
+		// 	$modelfiles.map((e) => {
+		// 		if (e.tagName === modelfile.tagName) {
+		// 			return modelfile;
+		// 		} else {
+		// 			return e;
+		// 		}
+		// 	})
+		// );
+		// localStorage.setItem('modelfiles', JSON.stringify($modelfiles));
+
+		await updateModelfileByTagName(localStorage.token, modelfile.tagName, modelfile);
+		await modelfiles.set(await getModelfiles(localStorage.token));
 	};
 
 	const updateHandler = async () => {
@@ -106,18 +115,12 @@
 			content !== '' &&
 			Object.keys(categories).filter((category) => categories[category]).length > 0
 		) {
-			const res = await fetch(`${$settings?.API_BASE_URL ?? OLLAMA_API_BASE_URL}/create`, {
-				method: 'POST',
-				headers: {
-					'Content-Type': 'text/event-stream',
-					...($settings.authHeader && { Authorization: $settings.authHeader }),
-					...($user && { Authorization: `Bearer ${localStorage.token}` })
-				},
-				body: JSON.stringify({
-					name: tagName,
-					modelfile: content
-				})
-			});
+			const res = await createModel(
+				$settings?.API_BASE_URL ?? OLLAMA_API_BASE_URL,
+				localStorage.token,
+				tagName,
+				content
+			);
 
 			if (res) {
 				const reader = res.body
@@ -178,7 +181,7 @@
 			}
 
 			if (success) {
-				await saveModelfile({
+				await updateModelfile({
 					tagName: tagName,
 					imageUrl: imageUrl,
 					title: title,