|
@@ -1029,6 +1029,14 @@ async def download_model(
|
|
url_idx: Optional[int] = None,
|
|
url_idx: Optional[int] = None,
|
|
):
|
|
):
|
|
|
|
|
|
|
|
+ allowed_hosts = ["https://huggingface.co/", "https://github.com/"]
|
|
|
|
+
|
|
|
|
+ if not any(form_data.url.startswith(host) for host in allowed_hosts):
|
|
|
|
+ raise HTTPException(
|
|
|
|
+ status_code=400,
|
|
|
|
+ detail="Invalid file_url. Only URLs from allowed hosts are permitted.",
|
|
|
|
+ )
|
|
|
|
+
|
|
if url_idx == None:
|
|
if url_idx == None:
|
|
url_idx = 0
|
|
url_idx = 0
|
|
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
|
url = app.state.OLLAMA_BASE_URLS[url_idx]
|
|
@@ -1037,6 +1045,7 @@ async def download_model(
|
|
|
|
|
|
if file_name:
|
|
if file_name:
|
|
file_path = f"{UPLOAD_DIR}/{file_name}"
|
|
file_path = f"{UPLOAD_DIR}/{file_name}"
|
|
|
|
+
|
|
return StreamingResponse(
|
|
return StreamingResponse(
|
|
download_file_stream(url, form_data.url, file_path, file_name),
|
|
download_file_stream(url, form_data.url, file_path, file_name),
|
|
)
|
|
)
|