Browse Source

feat: multiple ollama support

Timothy J. Baek 1 year ago
parent
commit
f04d60b6d9

+ 717 - 11
backend/apps/ollama/main.py

@@ -3,16 +3,23 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import StreamingResponse
 from fastapi.concurrency import run_in_threadpool
 
+from pydantic import BaseModel
+
+import random
 import requests
 import json
 import uuid
-from pydantic import BaseModel
+import aiohttp
+import asyncio
 
 from apps.web.models.users import Users
 from constants import ERROR_MESSAGES
 from utils.utils import decode_token, get_current_user, get_admin_user
 from config import OLLAMA_BASE_URL, WEBUI_AUTH
 
+from typing import Optional, List, Union
+
+
 app = FastAPI()
 app.add_middleware(
     CORSMiddleware,
@@ -23,26 +30,39 @@ app.add_middleware(
 )
 
 app.state.OLLAMA_BASE_URL = OLLAMA_BASE_URL
-
-# TARGET_SERVER_URL = OLLAMA_API_BASE_URL
+app.state.OLLAMA_BASE_URLS = [OLLAMA_BASE_URL]
+app.state.MODELS = {}
 
 
 REQUEST_POOL = []
 
 
-@app.get("/url")
-async def get_ollama_api_url(user=Depends(get_admin_user)):
-    return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
+@app.middleware("http")
+async def check_url(request: Request, call_next):
+    if len(app.state.MODELS) == 0:
+        await get_all_models()
+    else:
+        pass
+
+    response = await call_next(request)
+    return response
+
+
+@app.get("/urls")
+async def get_ollama_api_urls(user=Depends(get_admin_user)):
+    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
 
 
 class UrlUpdateForm(BaseModel):
-    url: str
+    urls: List[str]
 
 
-@app.post("/url/update")
+@app.post("/urls/update")
 async def update_ollama_api_url(form_data: UrlUpdateForm, user=Depends(get_admin_user)):
-    app.state.OLLAMA_BASE_URL = form_data.url
-    return {"OLLAMA_BASE_URL": app.state.OLLAMA_BASE_URL}
+    app.state.OLLAMA_BASE_URLS = form_data.urls
+
+    print(app.state.OLLAMA_BASE_URLS)
+    return {"OLLAMA_BASE_URLS": app.state.OLLAMA_BASE_URLS}
 
 
 @app.get("/cancel/{request_id}")
@@ -55,9 +75,695 @@ async def cancel_ollama_request(request_id: str, user=Depends(get_current_user))
         raise HTTPException(status_code=401, detail=ERROR_MESSAGES.ACCESS_PROHIBITED)
 
 
+async def fetch_url(url):
+    try:
+        async with aiohttp.ClientSession() as session:
+            async with session.get(url) as response:
+                return await response.json()
+    except Exception as e:
+        # Handle connection error here
+        print(f"Connection error: {e}")
+        return None
+
+
+def merge_models_lists(model_lists):
+    merged_models = {}
+
+    for idx, model_list in enumerate(model_lists):
+        for model in model_list:
+            digest = model["digest"]
+            if digest not in merged_models:
+                model["urls"] = [idx]
+                merged_models[digest] = model
+            else:
+                merged_models[digest]["urls"].append(idx)
+
+    return list(merged_models.values())
+
+
+# user=Depends(get_current_user)
+
+
+async def get_all_models():
+    print("get_all_models")
+    tasks = [fetch_url(f"{url}/api/tags") for url in app.state.OLLAMA_BASE_URLS]
+    responses = await asyncio.gather(*tasks)
+    responses = list(filter(lambda x: x is not None, responses))
+
+    models = {
+        "models": merge_models_lists(
+            map(lambda response: response["models"], responses)
+        )
+    }
+    app.state.MODELS = {model["model"]: model for model in models["models"]}
+
+    return models
+
+
+@app.get("/api/tags")
+@app.get("/api/tags/{url_idx}")
+async def get_ollama_tags(
+    url_idx: Optional[int] = None, user=Depends(get_current_user)
+):
+
+    if url_idx == None:
+        return await get_all_models()
+    else:
+        url = app.state.OLLAMA_BASE_URLS[url_idx]
+        try:
+            r = requests.request(method="GET", url=f"{url}/api/tags")
+            r.raise_for_status()
+
+            return r.json()
+        except Exception as e:
+            print(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"Ollama: {res['error']}"
+                except:
+                    error_detail = f"Ollama: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r else 500,
+                detail=error_detail,
+            )
+
+
+@app.get("/api/version")
+@app.get("/api/version/{url_idx}")
+async def get_ollama_versions(url_idx: Optional[int] = None):
+
+    if url_idx == None:
+
+        # returns lowest version
+        tasks = [fetch_url(f"{url}/api/version") for url in app.state.OLLAMA_BASE_URLS]
+        responses = await asyncio.gather(*tasks)
+        responses = list(filter(lambda x: x is not None, responses))
+
+        lowest_version = min(
+            responses, key=lambda x: tuple(map(int, x["version"].split(".")))
+        )
+
+        return {"version": lowest_version["version"]}
+    else:
+        url = app.state.OLLAMA_BASE_URLS[url_idx]
+        try:
+            r = requests.request(method="GET", url=f"{url}/api/version")
+            r.raise_for_status()
+
+            return r.json()
+        except Exception as e:
+            print(e)
+            error_detail = "Open WebUI: Server Connection Error"
+            if r is not None:
+                try:
+                    res = r.json()
+                    if "error" in res:
+                        error_detail = f"Ollama: {res['error']}"
+                except:
+                    error_detail = f"Ollama: {e}"
+
+            raise HTTPException(
+                status_code=r.status_code if r else 500,
+                detail=error_detail,
+            )
+
+
+class ModelNameForm(BaseModel):
+    name: str
+
+
+@app.post("/api/pull")
+@app.post("/api/pull/{url_idx}")
+async def pull_model(
+    form_data: ModelNameForm, url_idx: int = 0, user=Depends(get_admin_user)
+):
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    r = None
+
+    def get_request(url):
+        nonlocal r
+        try:
+
+            def stream_content():
+                for chunk in r.iter_content(chunk_size=8192):
+                    yield chunk
+
+            r = requests.request(
+                method="POST",
+                url=f"{url}/api/pull",
+                data=form_data.model_dump_json(exclude_none=True),
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            return StreamingResponse(
+                stream_content(),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request(url))
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+class PushModelForm(BaseModel):
+    name: str
+    insecure: Optional[bool] = None
+    stream: Optional[bool] = None
+
+
+@app.delete("/api/push")
+@app.delete("/api/push/{url_idx}")
+async def push_model(
+    form_data: PushModelForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
+    if url_idx == None:
+        if form_data.name in app.state.MODELS:
+            url_idx = app.state.MODELS[form_data.name]["urls"][0]
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="error_detail",
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    r = None
+
+    def get_request():
+        nonlocal url
+        nonlocal r
+        try:
+
+            def stream_content():
+                for chunk in r.iter_content(chunk_size=8192):
+                    yield chunk
+
+            r = requests.request(
+                method="POST",
+                url=f"{url}/api/push",
+                data=form_data.model_dump_json(exclude_none=True),
+            )
+
+            r.raise_for_status()
+
+            return StreamingResponse(
+                stream_content(),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request)
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+class CreateModelForm(BaseModel):
+    name: str
+    modelfile: Optional[str] = None
+    stream: Optional[bool] = None
+    path: Optional[str] = None
+
+
+@app.post("/api/create")
+@app.post("/api/create/{url_idx}")
+async def create_model(
+    form_data: CreateModelForm, url_idx: int = 0, user=Depends(get_admin_user)
+):
+    print(form_data)
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+    r = None
+
+    def get_request():
+        nonlocal url
+        nonlocal r
+        try:
+
+            def stream_content():
+                for chunk in r.iter_content(chunk_size=8192):
+                    yield chunk
+
+            r = requests.request(
+                method="POST",
+                url=f"{url}/api/create",
+                data=form_data.model_dump_json(exclude_none=True),
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            print(r)
+
+            return StreamingResponse(
+                stream_content(),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request)
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+class CopyModelForm(BaseModel):
+    source: str
+    destination: str
+
+
+@app.post("/api/copy")
+@app.post("/api/copy/{url_idx}")
+async def copy_model(
+    form_data: CopyModelForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
+    if url_idx == None:
+        if form_data.source in app.state.MODELS:
+            url_idx = app.state.MODELS[form_data.source]["urls"][0]
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="error_detail",
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/copy",
+            data=form_data.model_dump_json(exclude_none=True),
+        )
+        r.raise_for_status()
+
+        print(r.text)
+
+        return True
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+@app.delete("/api/delete")
+@app.delete("/api/delete/{url_idx}")
+async def delete_model(
+    form_data: ModelNameForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_admin_user),
+):
+    if url_idx == None:
+        if form_data.name in app.state.MODELS:
+            url_idx = app.state.MODELS[form_data.name]["urls"][0]
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="error_detail",
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    try:
+        r = requests.request(
+            method="DELETE",
+            url=f"{url}/api/delete",
+            data=form_data.model_dump_json(exclude_none=True),
+        )
+        r.raise_for_status()
+
+        print(r.text)
+
+        return True
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+@app.post("/api/show")
+async def show_model_info(form_data: ModelNameForm, user=Depends(get_current_user)):
+    if form_data.name not in app.state.MODELS:
+        raise HTTPException(
+            status_code=400,
+            detail="error_detail",
+        )
+
+    url_idx = random.choice(app.state.MODELS[form_data.name]["urls"])
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/show",
+            data=form_data.model_dump_json(exclude_none=True),
+        )
+        r.raise_for_status()
+
+        return r.json()
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+class GenerateEmbeddingsForm(BaseModel):
+    model: str
+    prompt: str
+    options: Optional[dict] = None
+    keep_alive: Optional[Union[int, str]] = None
+
+
+@app.post("/api/embeddings")
+@app.post("/api/embeddings/{url_idx}")
+async def generate_embeddings(
+    form_data: GenerateEmbeddingsForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_current_user),
+):
+    if url_idx == None:
+        if form_data.model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="error_detail",
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    try:
+        r = requests.request(
+            method="POST",
+            url=f"{url}/api/embeddings",
+            data=form_data.model_dump_json(exclude_none=True),
+        )
+        r.raise_for_status()
+
+        return r.json()
+    except Exception as e:
+        print(e)
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+class GenerateCompletionForm(BaseModel):
+    model: str
+    prompt: str
+    images: Optional[List[str]] = None
+    format: Optional[str] = None
+    options: Optional[dict] = None
+    system: Optional[str] = None
+    template: Optional[str] = None
+    context: Optional[str] = None
+    stream: Optional[bool] = True
+    raw: Optional[bool] = None
+    keep_alive: Optional[Union[int, str]] = None
+
+
+@app.post("/api/generate")
+@app.post("/api/generate/{url_idx}")
+async def generate_completion(
+    form_data: GenerateCompletionForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_current_user),
+):
+
+    if url_idx == None:
+        if form_data.model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="error_detail",
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    r = None
+
+    def get_request():
+        nonlocal form_data
+        nonlocal r
+
+        request_id = str(uuid.uuid4())
+        try:
+            REQUEST_POOL.append(request_id)
+
+            def stream_content():
+                try:
+                    if form_data.stream:
+                        yield json.dumps({"id": request_id, "done": False}) + "\n"
+
+                    for chunk in r.iter_content(chunk_size=8192):
+                        if request_id in REQUEST_POOL:
+                            yield chunk
+                        else:
+                            print("User: canceled request")
+                            break
+                finally:
+                    if hasattr(r, "close"):
+                        r.close()
+                        if request_id in REQUEST_POOL:
+                            REQUEST_POOL.remove(request_id)
+
+            r = requests.request(
+                method="POST",
+                url=f"{url}/api/generate",
+                data=form_data.model_dump_json(exclude_none=True),
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            return StreamingResponse(
+                stream_content(),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request)
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
+class ChatMessage(BaseModel):
+    role: str
+    content: str
+    images: Optional[List[str]] = None
+
+
+class GenerateChatCompletionForm(BaseModel):
+    model: str
+    messages: List[ChatMessage]
+    format: Optional[str] = None
+    options: Optional[dict] = None
+    template: Optional[str] = None
+    stream: Optional[bool] = True
+    keep_alive: Optional[Union[int, str]] = None
+
+
+@app.post("/api/chat")
+@app.post("/api/chat/{url_idx}")
+async def generate_completion(
+    form_data: GenerateChatCompletionForm,
+    url_idx: Optional[int] = None,
+    user=Depends(get_current_user),
+):
+
+    if url_idx == None:
+        if form_data.model in app.state.MODELS:
+            url_idx = random.choice(app.state.MODELS[form_data.model]["urls"])
+        else:
+            raise HTTPException(
+                status_code=400,
+                detail="error_detail",
+            )
+
+    url = app.state.OLLAMA_BASE_URLS[url_idx]
+
+    r = None
+
+    print(form_data.model_dump_json(exclude_none=True))
+
+    def get_request():
+        nonlocal form_data
+        nonlocal r
+
+        request_id = str(uuid.uuid4())
+        try:
+            REQUEST_POOL.append(request_id)
+
+            def stream_content():
+                try:
+                    if form_data.stream:
+                        yield json.dumps({"id": request_id, "done": False}) + "\n"
+
+                    for chunk in r.iter_content(chunk_size=8192):
+                        if request_id in REQUEST_POOL:
+                            yield chunk
+                        else:
+                            print("User: canceled request")
+                            break
+                finally:
+                    if hasattr(r, "close"):
+                        r.close()
+                        if request_id in REQUEST_POOL:
+                            REQUEST_POOL.remove(request_id)
+
+            r = requests.request(
+                method="POST",
+                url=f"{url}/api/chat",
+                data=form_data.model_dump_json(exclude_none=True),
+                stream=True,
+            )
+
+            r.raise_for_status()
+
+            return StreamingResponse(
+                stream_content(),
+                status_code=r.status_code,
+                headers=dict(r.headers),
+            )
+        except Exception as e:
+            raise e
+
+    try:
+        return await run_in_threadpool(get_request)
+    except Exception as e:
+        error_detail = "Open WebUI: Server Connection Error"
+        if r is not None:
+            try:
+                res = r.json()
+                if "error" in res:
+                    error_detail = f"Ollama: {res['error']}"
+            except:
+                error_detail = f"Ollama: {e}"
+
+        raise HTTPException(
+            status_code=r.status_code if r else 500,
+            detail=error_detail,
+        )
+
+
 @app.api_route("/{path:path}", methods=["GET", "POST", "PUT", "DELETE"])
 async def proxy(path: str, request: Request, user=Depends(get_current_user)):
-    target_url = f"{app.state.OLLAMA_BASE_URL}/{path}"
+    url = app.state.OLLAMA_BASE_URLS[0]
+    target_url = f"{url}/{path}"
 
     body = await request.body()
     headers = dict(request.headers)

+ 8 - 0
backend/main.py

@@ -125,6 +125,14 @@ async def get_app_config():
     }
 
 
+@app.get("/api/version")
+async def get_app_config():
+
+    return {
+        "version": VERSION,
+    }
+
+
 @app.get("/api/changelog")
 async def get_app_changelog():
     return CHANGELOG

+ 27 - 15
src/lib/apis/ollama/index.ts

@@ -1,9 +1,9 @@
 import { OLLAMA_API_BASE_URL } from '$lib/constants';
 
-export const getOllamaAPIUrl = async (token: string = '') => {
+export const getOllamaUrls = async (token: string = '') => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/url`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/urls`, {
 		method: 'GET',
 		headers: {
 			Accept: 'application/json',
@@ -29,13 +29,13 @@ export const getOllamaAPIUrl = async (token: string = '') => {
 		throw error;
 	}
 
-	return res.OLLAMA_BASE_URL;
+	return res.OLLAMA_BASE_URLS;
 };
 
-export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
+export const updateOllamaUrls = async (token: string = '', urls: string[]) => {
 	let error = null;
 
-	const res = await fetch(`${OLLAMA_API_BASE_URL}/url/update`, {
+	const res = await fetch(`${OLLAMA_API_BASE_URL}/urls/update`, {
 		method: 'POST',
 		headers: {
 			Accept: 'application/json',
@@ -43,7 +43,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
 			...(token && { authorization: `Bearer ${token}` })
 		},
 		body: JSON.stringify({
-			url: url
+			urls: urls
 		})
 	})
 		.then(async (res) => {
@@ -64,7 +64,7 @@ export const updateOllamaAPIUrl = async (token: string = '', url: string) => {
 		throw error;
 	}
 
-	return res.OLLAMA_BASE_URL;
+	return res.OLLAMA_BASE_URLS;
 };
 
 export const getOllamaVersion = async (token: string = '') => {
@@ -151,7 +151,8 @@ export const generateTitle = async (
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
@@ -189,7 +190,8 @@ export const generatePrompt = async (token: string = '', model: string, conversa
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
@@ -223,7 +225,8 @@ export const generateTextCompletion = async (token: string = '', model: string,
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/generate`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
@@ -251,7 +254,8 @@ export const generateChatCompletion = async (token: string = '', body: object) =
 		signal: controller.signal,
 		method: 'POST',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify(body)
@@ -294,7 +298,8 @@ export const createModel = async (token: string, tagName: string, content: strin
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/create`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
@@ -319,7 +324,8 @@ export const deleteModel = async (token: string, tagName: string) => {
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/delete`, {
 		method: 'DELETE',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({
@@ -336,7 +342,12 @@ export const deleteModel = async (token: string, tagName: string) => {
 		})
 		.catch((err) => {
 			console.log(err);
-			error = err.error;
+			error = err;
+
+			if ('detail' in err) {
+				error = err.detail;
+			}
+
 			return null;
 		});
 
@@ -353,7 +364,8 @@ export const pullModel = async (token: string, tagName: string) => {
 	const res = await fetch(`${OLLAMA_API_BASE_URL}/api/pull`, {
 		method: 'POST',
 		headers: {
-			'Content-Type': 'text/event-stream',
+			Accept: 'application/json',
+			'Content-Type': 'application/json',
 			Authorization: `Bearer ${token}`
 		},
 		body: JSON.stringify({

+ 79 - 36
src/lib/components/chat/Settings/Connections.svelte

@@ -3,14 +3,15 @@
 	import { createEventDispatcher, onMount } from 'svelte';
 	const dispatch = createEventDispatcher();
 
-	import { getOllamaAPIUrl, getOllamaVersion, updateOllamaAPIUrl } from '$lib/apis/ollama';
+	import { getOllamaUrls, getOllamaVersion, updateOllamaUrls } from '$lib/apis/ollama';
 	import { getOpenAIKey, getOpenAIUrl, updateOpenAIKey, updateOpenAIUrl } from '$lib/apis/openai';
 	import { toast } from 'svelte-sonner';
 
 	export let getModels: Function;
 
 	// External
-	let API_BASE_URL = '';
+	let OLLAMA_BASE_URL = '';
+	let OLLAMA_BASE_URLS = [''];
 
 	let OPENAI_API_KEY = '';
 	let OPENAI_API_BASE_URL = '';
@@ -25,8 +26,8 @@
 		await models.set(await getModels());
 	};
 
-	const updateOllamaAPIUrlHandler = async () => {
-		API_BASE_URL = await updateOllamaAPIUrl(localStorage.token, API_BASE_URL);
+	const updateOllamaUrlsHandler = async () => {
+		OLLAMA_BASE_URLS = await updateOllamaUrls(localStorage.token, OLLAMA_BASE_URLS);
 
 		const ollamaVersion = await getOllamaVersion(localStorage.token).catch((error) => {
 			toast.error(error);
@@ -41,7 +42,7 @@
 
 	onMount(async () => {
 		if ($user.role === 'admin') {
-			API_BASE_URL = await getOllamaAPIUrl(localStorage.token);
+			OLLAMA_BASE_URLS = await getOllamaUrls(localStorage.token);
 			OPENAI_API_BASE_URL = await getOpenAIUrl(localStorage.token);
 			OPENAI_API_KEY = await getOpenAIKey(localStorage.token);
 		}
@@ -53,11 +54,6 @@
 	on:submit|preventDefault={() => {
 		updateOpenAIHandler();
 		dispatch('save');
-
-		// saveSettings({
-		// 	OPENAI_API_KEY: OPENAI_API_KEY !== '' ? OPENAI_API_KEY : undefined,
-		// 	OPENAI_API_BASE_URL: OPENAI_API_BASE_URL !== '' ? OPENAI_API_BASE_URL : undefined
-		// });
 	}}
 >
 	<div class="  pr-1.5 overflow-y-scroll max-h-[20.5rem] space-y-3">
@@ -115,34 +111,81 @@
 
 		<div>
 			<div class=" mb-2.5 text-sm font-medium">Ollama Base URL</div>
-			<div class="flex w-full">
-				<div class="flex-1 mr-2">
-					<input
-						class="w-full rounded py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-800 outline-none"
-						placeholder="Enter URL (e.g. http://localhost:11434)"
-						bind:value={API_BASE_URL}
-					/>
+			<div class="flex w-full gap-1.5">
+				<div class="flex-1 flex flex-col gap-2">
+					{#each OLLAMA_BASE_URLS as url, idx}
+						<div class="flex gap-1.5">
+							<input
+								class="w-full rounded-lg py-2 px-4 text-sm dark:text-gray-300 dark:bg-gray-850 outline-none"
+								placeholder="Enter URL (e.g. http://localhost:11434)"
+								bind:value={url}
+							/>
+
+							<div class="self-center flex items-center">
+								{#if idx === 0}
+									<button
+										class="px-1"
+										on:click={() => {
+											OLLAMA_BASE_URLS = [...OLLAMA_BASE_URLS, ''];
+										}}
+										type="button"
+									>
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 16 16"
+											fill="currentColor"
+											class="w-4 h-4"
+										>
+											<path
+												d="M8.75 3.75a.75.75 0 0 0-1.5 0v3.5h-3.5a.75.75 0 0 0 0 1.5h3.5v3.5a.75.75 0 0 0 1.5 0v-3.5h3.5a.75.75 0 0 0 0-1.5h-3.5v-3.5Z"
+											/>
+										</svg>
+									</button>
+								{:else}
+									<button
+										class="px-1"
+										on:click={() => {
+											OLLAMA_BASE_URLS = OLLAMA_BASE_URLS.filter((url, urlIdx) => idx !== urlIdx);
+										}}
+										type="button"
+									>
+										<svg
+											xmlns="http://www.w3.org/2000/svg"
+											viewBox="0 0 16 16"
+											fill="currentColor"
+											class="w-4 h-4"
+										>
+											<path d="M3.75 7.25a.75.75 0 0 0 0 1.5h8.5a.75.75 0 0 0 0-1.5h-8.5Z" />
+										</svg>
+									</button>
+								{/if}
+							</div>
+						</div>
+					{/each}
 				</div>
-				<button
-					class="px-3 bg-gray-200 hover:bg-gray-300 dark:bg-gray-600 dark:hover:bg-gray-700 rounded transition"
-					on:click={() => {
-						updateOllamaAPIUrlHandler();
-					}}
-					type="button"
-				>
-					<svg
-						xmlns="http://www.w3.org/2000/svg"
-						viewBox="0 0 20 20"
-						fill="currentColor"
-						class="w-4 h-4"
+
+				<div class="">
+					<button
+						class="p-2.5 bg-gray-200 hover:bg-gray-300 dark:bg-gray-850 dark:hover:bg-gray-800 rounded-lg transition"
+						on:click={() => {
+							updateOllamaUrlsHandler();
+						}}
+						type="button"
 					>
-						<path
-							fill-rule="evenodd"
-							d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
-							clip-rule="evenodd"
-						/>
-					</svg>
-				</button>
+						<svg
+							xmlns="http://www.w3.org/2000/svg"
+							viewBox="0 0 20 20"
+							fill="currentColor"
+							class="w-4 h-4"
+						>
+							<path
+								fill-rule="evenodd"
+								d="M15.312 11.424a5.5 5.5 0 01-9.201 2.466l-.312-.311h2.433a.75.75 0 000-1.5H3.989a.75.75 0 00-.75.75v4.242a.75.75 0 001.5 0v-2.43l.31.31a7 7 0 0011.712-3.138.75.75 0 00-1.449-.39zm1.23-3.723a.75.75 0 00.219-.53V2.929a.75.75 0 00-1.5 0V5.36l-.31-.31A7 7 0 003.239 8.188a.75.75 0 101.448.389A5.5 5.5 0 0113.89 6.11l.311.31h-2.432a.75.75 0 000 1.5h4.243a.75.75 0 00.53-.219z"
+								clip-rule="evenodd"
+							/>
+						</svg>
+					</button>
+				</div>
 			</div>
 
 			<div class="mt-2 text-xs text-gray-400 dark:text-gray-500">