|
@@ -1,6 +1,6 @@
|
|
-import os
|
|
|
|
import requests
|
|
import requests
|
|
import validators
|
|
import validators
|
|
|
|
+from os import path, walk
|
|
from urllib.parse import urlsplit, urlunsplit
|
|
from urllib.parse import urlsplit, urlunsplit
|
|
from tqdm import tqdm
|
|
from tqdm import tqdm
|
|
|
|
|
|
@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
|
|
|
|
|
|
|
|
|
|
def models(models_home='.', *args, **kwargs):
|
|
def models(models_home='.', *args, **kwargs):
|
|
- for _, _, files in os.walk(models_home):
|
|
|
|
|
|
+ for _, _, files in walk(models_home):
|
|
for file in files:
|
|
for file in files:
|
|
- base, ext = os.path.splitext(file)
|
|
|
|
|
|
+ base, ext = path.splitext(file)
|
|
if ext == '.bin':
|
|
if ext == '.bin':
|
|
yield base
|
|
yield base
|
|
|
|
|
|
@@ -27,7 +27,7 @@ def get_url_from_directory(model):
|
|
return model
|
|
return model
|
|
|
|
|
|
|
|
|
|
-def download_from_repo(url, models_home='.'):
|
|
|
|
|
|
+def download_from_repo(url, file_name, models_home='.'):
|
|
parts = urlsplit(url)
|
|
parts = urlsplit(url)
|
|
path_parts = parts.path.split('/tree/')
|
|
path_parts = parts.path.split('/tree/')
|
|
|
|
|
|
@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
|
|
location, branch = path_parts
|
|
location, branch = path_parts
|
|
|
|
|
|
location = location.strip('/')
|
|
location = location.strip('/')
|
|
|
|
+ if file_name == '':
|
|
|
|
+ file_name = path.basename(location)
|
|
|
|
|
|
download_url = urlunsplit(
|
|
download_url = urlunsplit(
|
|
(
|
|
(
|
|
@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
|
|
json_response = response.json()
|
|
json_response = response.json()
|
|
|
|
|
|
download_url, file_size = find_bin_file(json_response, location, branch)
|
|
download_url, file_size = find_bin_file(json_response, location, branch)
|
|
- return download_file(download_url, models_home, location, file_size)
|
|
|
|
|
|
+ return download_file(download_url, models_home, file_name, file_size)
|
|
|
|
|
|
|
|
|
|
def find_bin_file(json_response, location, branch):
|
|
def find_bin_file(json_response, location, branch):
|
|
@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
|
|
return download_url, file_size
|
|
return download_url, file_size
|
|
|
|
|
|
|
|
|
|
-def download_file(download_url, models_home, location, file_size):
|
|
|
|
- local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
|
|
|
|
|
|
+def download_file(download_url, models_home, file_name, file_size):
|
|
|
|
+ local_filename = path.join(models_home, file_name) + '.bin'
|
|
|
|
|
|
- first_byte = (
|
|
|
|
- os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
|
|
|
|
- )
|
|
|
|
|
|
+ first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
|
|
|
|
|
|
if first_byte >= file_size:
|
|
if first_byte >= file_size:
|
|
return local_filename
|
|
return local_filename
|
|
|
|
|
|
- print(f'Pulling {os.path.basename(location)}...')
|
|
|
|
|
|
+ print(f'Pulling {file_name}...')
|
|
|
|
|
|
header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
|
|
header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
|
|
|
|
|
|
@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):
|
|
|
|
|
|
|
|
|
|
def pull(model, models_home='.', *args, **kwargs):
|
|
def pull(model, models_home='.', *args, **kwargs):
|
|
- if os.path.exists(model):
|
|
|
|
|
|
+ if path.exists(model):
|
|
# a file on the filesystem is being specified
|
|
# a file on the filesystem is being specified
|
|
return model
|
|
return model
|
|
# check the remote model location and see if it needs to be downloaded
|
|
# check the remote model location and see if it needs to be downloaded
|
|
url = model
|
|
url = model
|
|
|
|
+ file_name = ""
|
|
if not validators.url(url) and not url.startswith('huggingface.co'):
|
|
if not validators.url(url) and not url.startswith('huggingface.co'):
|
|
url = get_url_from_directory(model)
|
|
url = get_url_from_directory(model)
|
|
|
|
+ file_name = model
|
|
|
|
|
|
if not (url.startswith('http://') or url.startswith('https://')):
|
|
if not (url.startswith('http://') or url.startswith('https://')):
|
|
url = f'https://{url}'
|
|
url = f'https://{url}'
|
|
@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
|
|
return model
|
|
return model
|
|
raise Exception(f'Unknown model {model}')
|
|
raise Exception(f'Unknown model {model}')
|
|
|
|
|
|
- local_filename = download_from_repo(url, models_home)
|
|
|
|
|
|
+ local_filename = download_from_repo(url, file_name, models_home)
|
|
|
|
|
|
return local_filename
|
|
return local_filename
|