瀏覽代碼

Merge pull request #7326 from bnodnarb/fix/ollama-authentication

fix: Include Authorization header in /api/pull and /api/chat requests
Timothy Jaeryang Baek 5 月之前
父節點
當前提交
da4676de2e
共有 1 個文件被更改,包括 35 次插入11 次删除
  1. 35 11
      backend/open_webui/apps/ollama/main.py

+ 35 - 11
backend/open_webui/apps/ollama/main.py

@@ -195,7 +195,10 @@ async def post_streaming_url(
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
             trust_env=True, timeout=aiohttp.ClientTimeout(total=AIOHTTP_CLIENT_TIMEOUT)
         )
         )
 
 
-        api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+        parsed_url = urlparse(url)
+        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
         key = api_config.get("key", None)
         key = api_config.get("key", None)
 
 
         headers = {"Content-Type": "application/json"}
         headers = {"Content-Type": "application/json"}
@@ -210,13 +213,13 @@ async def post_streaming_url(
         r.raise_for_status()
         r.raise_for_status()
 
 
         if stream:
         if stream:
-            headers = dict(r.headers)
+            response_headers = dict(r.headers)
             if content_type:
             if content_type:
-                headers["Content-Type"] = content_type
+                response_headers["Content-Type"] = content_type
             return StreamingResponse(
             return StreamingResponse(
                 r.content,
                 r.content,
                 status_code=r.status,
                 status_code=r.status,
-                headers=headers,
+                headers=response_headers,
                 background=BackgroundTask(
                 background=BackgroundTask(
                     cleanup_response, response=r, session=session
                     cleanup_response, response=r, session=session
                 ),
                 ),
@@ -324,7 +327,10 @@ async def get_ollama_tags(
     else:
     else:
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
         url = app.state.config.OLLAMA_BASE_URLS[url_idx]
 
 
-        api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+        parsed_url = urlparse(url)
+        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+        api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
         key = api_config.get("key", None)
         key = api_config.get("key", None)
 
 
         headers = {}
         headers = {}
@@ -525,7 +531,10 @@ async def copy_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
     key = api_config.get("key", None)
 
 
     headers = {"Content-Type": "application/json"}
     headers = {"Content-Type": "application/json"}
@@ -584,7 +593,10 @@ async def delete_model(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
     key = api_config.get("key", None)
 
 
     headers = {"Content-Type": "application/json"}
     headers = {"Content-Type": "application/json"}
@@ -635,7 +647,10 @@ async def show_model_info(form_data: ModelNameForm, user=Depends(get_verified_us
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
     key = api_config.get("key", None)
 
 
     headers = {"Content-Type": "application/json"}
     headers = {"Content-Type": "application/json"}
@@ -730,7 +745,10 @@ async def generate_ollama_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
     key = api_config.get("key", None)
 
 
     headers = {"Content-Type": "application/json"}
     headers = {"Content-Type": "application/json"}
@@ -797,7 +815,10 @@ async def generate_ollama_batch_embeddings(
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     url = app.state.config.OLLAMA_BASE_URLS[url_idx]
     log.info(f"url: {url}")
     log.info(f"url: {url}")
 
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     key = api_config.get("key", None)
     key = api_config.get("key", None)
 
 
     headers = {"Content-Type": "application/json"}
     headers = {"Content-Type": "application/json"}
@@ -974,7 +995,10 @@ async def generate_chat_completion(
     log.info(f"url: {url}")
     log.info(f"url: {url}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
     log.debug(f"generate_chat_completion() - 2.payload = {payload}")
 
 
-    api_config = app.state.config.OLLAMA_API_CONFIGS.get(url, {})
+    parsed_url = urlparse(url)
+    base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"
+
+    api_config = app.state.config.OLLAMA_API_CONFIGS.get(base_url, {})
     prefix_id = api_config.get("prefix_id", None)
     prefix_id = api_config.get("prefix_id", None)
     if prefix_id:
     if prefix_id:
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")
         payload["model"] = payload["model"].replace(f"{prefix_id}.", "")