Browse Source

fix: download allowed hosts

Timothy J. Baek 1 year ago
parent
commit
d72653cdea
2 changed files with 12 additions and 7 deletions
  1. 9 7
      backend/apps/ollama/main.py
  2. 3 0
      src/lib/components/chat/Settings/Models.svelte

+ 9 - 7
backend/apps/ollama/main.py

@@ -970,13 +970,6 @@ def parse_huggingface_url(hf_url):
 async def download_file_stream(
     ollama_url, file_url, file_path, file_name, chunk_size=1024 * 1024
 ):
-    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
-
-    if not any(file_url.startswith(host) for host in allowed_hosts):
-        raise ValueError(
-            "Invalid file_url. Only URLs from allowed hosts are permitted."
-        )
-
     done = False
 
     if os.path.exists(file_path):
@@ -1036,6 +1029,14 @@ async def download_model(
     url_idx: Optional[int] = None,
 ):
 
+    allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
+
+    if not any(form_data.url.startswith(host) for host in allowed_hosts):
+        raise HTTPException(
+            status_code=400,
+            detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
+        )
+
     if url_idx == None:
         url_idx = 0
     url = app.state.OLLAMA_BASE_URLS[url_idx]
@@ -1044,6 +1045,7 @@ async def download_model(
 
     if file_name:
         file_path = f"{UPLOAD_DIR}/{file_name}"
+
         return StreamingResponse(
             download_file_stream(url, form_data.url, file_path, file_name),
         )

+ 3 - 0
src/lib/components/chat/Settings/Models.svelte

@@ -258,6 +258,9 @@
 					console.log(error);
 				}
 			}
+		} else {
+			const error = await fileResponse?.json();
+			toast.error(error?.detail ?? error);
 		}
 
 		if (uploaded) {