Browse Source

feat: ollama auth support

Timothy Jaeryang Baek 5 months ago
parent
commit
99446c4b76
1 changed files with 64 additions and 9 deletions
  1. 64 9
      backend/open_webui/apps/ollama/main.py

+ 64 - 9
backend/open_webui/apps/ollama/main.py

@@ -209,10 +209,18 @@ async def post_streaming_url(
         session = aiohttp.ClientSession(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )
+
+        api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+        key = api_config.get("key", None)
+
+        headers = {"Content-Type": "application/json"}
+        if key:
+            headers["Authorization"] = f"Bearer {key}"
+
         r = await session.post(
             url,
             data=payload,
-            headers={"Content-Type": "application/json"},
+            headers=headers,
         )
         r.raise_for_status()
 
@@ -275,9 +283,10 @@ async def get_all_models():
             else:
                 api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
                 enable = api_config.get("enable", True)
+                key = api_config.get("key", None)
 
                 if enable:
-                    tasks.append(aiohttp_get(f"{url}/api/tags"))
+                    tasks.append(aiohttp_get(f"{url}/api/tags", key))
                 else:
                     tasks.append(None)
 
@@ -341,9 +350,16 @@ async def get_ollama_tags(
     else:
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 
+        api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+        key = api_config.get("key", None)
+
+        headers = {}
+        if key:
+            headers["Authorization"] = f"Bearer {key}"
+
         r = None
         try:
-            r = requests.request(method="GET", url=f"{url}/api/tags")
+            r = requests.request(method="GET", url=f"{url}/api/tags", headers=headers)
             r.raise_for_status()
 
             return r.json()
@@ -371,7 +387,10 @@ async def get_ollama_versions(url_idx: Optional[int] = None):
         if url_idx is None:
             # returns lowest version
             tasks = [
-                aiohttp_get(f"{url}/api/version")
+                aiohttp_get(
+                    f"{url}/api/version",
+                    app.state.config.OLLAMA_API_CONFIGS.get(url, {}).get("key", None),
+                )
                 for url in app.state.config.OLLAMA_BASE_URLS
             ]
             responses = await asyncio.gather(*tasks)
@@ -511,10 +530,18 @@ async def copy_model(
 
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    key = api_config.get("key", None)
+
+    headers = {"Content-Type": "application/json"}
+    if key:
+        headers["Authorization"] = f"Bearer {key}"
+
     r = requests.request(
         method="POST",
         url=f"{url}/api/copy",
-        headers={"Content-Type": "application/json"},
+        headers=headers,
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
 
@@ -560,11 +587,18 @@ async def delete_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    key = api_config.get("key", None)
+
+    headers = {"Content-Type": "application/json"}
+    if key:
+        headers["Authorization"] = f"Bearer {key}"
+
     r = requests.request(
         method="DELETE",
         url=f"{url}/api/delete",
-        headers={"Content-Type": "application/json"},
         data=form_data.model_dump_json(exclude_none=True).encode(),
+        headers=headers,
     )
     try:
         r.raise_for_status()
@@ -601,10 +635,17 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    key = api_config.get("key", None)
+
+    headers = {"Content-Type": "application/json"}
+    if key:
+        headers["Authorization"] = f"Bearer {key}"
+
     r = requests.request(
         method="POST",
         url=f"{url}/api/show",
-        headers={"Content-Type": "application/json"},
+        headers=headers,
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
     try:
@@ -686,10 +727,17 @@ def generate_ollama_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    key = api_config.get("key", None)
+
+    headers = {"Content-Type": "application/json"}
+    if key:
+        headers["Authorization"] = f"Bearer {key}"
+
     r = requests.request(
         method="POST",
         url=f"{url}/api/embeddings",
-        headers={"Content-Type": "application/json"},
+        headers=headers,
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
     try:
@@ -743,10 +791,17 @@ def generate_ollama_batch_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
 
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    key = api_config.get("key", None)
+
+    headers = {"Content-Type": "application/json"}
+    if key:
+        headers["Authorization"] = f"Bearer {key}"
+
     r = requests.request(
         method="POST",
         url=f"{url}/api/embed",
-        headers={"Content-Type": "application/json"},
+        headers=headers,
         data=form_data.model_dump_json(exclude_none=True).encode(),
     )
     try: