model.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. import requests
  2. import validators
  3. from pathlib import Path
  4. from os import path, walk
  5. from urllib.parse import urlsplit, urlunsplit
  6. from tqdm import tqdm
  7. MODELS_MANIFEST = 'https://ollama.ai/api/models'
  8. MODELS_CACHE_PATH = Path.home() / '.ollama' / 'models'
  9. def models(*args, **kwargs):
  10. for _, _, files in walk(MODELS_CACHE_PATH):
  11. for file in files:
  12. base, ext = path.splitext(file)
  13. if ext == '.bin':
  14. yield base
  15. # get the url of the model from our curated directory
  16. def get_url_from_directory(model):
  17. response = requests.get(MODELS_MANIFEST)
  18. response.raise_for_status()
  19. directory = response.json()
  20. for model_info in directory:
  21. if model_info.get('name') == model:
  22. return model_info.get('url')
  23. return model
  24. def download_from_repo(url, file_name):
  25. parts = urlsplit(url)
  26. path_parts = parts.path.split('/tree/')
  27. if len(path_parts) == 1:
  28. location = path_parts[0]
  29. branch = 'main'
  30. else:
  31. location, branch = path_parts
  32. location = location.strip('/')
  33. if file_name == '':
  34. file_name = path.basename(location).lower()
  35. download_url = urlunsplit(
  36. (
  37. 'https',
  38. parts.netloc,
  39. f'/api/models/{location}/tree/{branch}',
  40. parts.query,
  41. parts.fragment,
  42. )
  43. )
  44. response = requests.get(download_url)
  45. response.raise_for_status()
  46. json_response = response.json()
  47. download_url, file_size = find_bin_file(json_response, location, branch)
  48. return download_file(download_url, file_name, file_size)
  49. def find_bin_file(json_response, location, branch):
  50. download_url = None
  51. file_size = 0
  52. for file_info in json_response:
  53. if file_info.get('type') == 'file' and file_info.get('path').endswith('.bin'):
  54. f_path = file_info.get('path')
  55. download_url = (
  56. f'https://huggingface.co/{location}/resolve/{branch}/{f_path}'
  57. )
  58. file_size = file_info.get('size')
  59. if download_url is None:
  60. raise Exception('No model found')
  61. return download_url, file_size
  62. def download_file(download_url, file_name, file_size):
  63. local_filename = MODELS_CACHE_PATH / file_name + '.bin'
  64. first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
  65. if first_byte >= file_size:
  66. return local_filename
  67. print(f'Pulling {file_name}...')
  68. header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
  69. response = requests.get(download_url, headers=header, stream=True)
  70. response.raise_for_status()
  71. total_size = int(response.headers.get('content-length', 0)) + first_byte
  72. with open(local_filename, 'ab' if first_byte else 'wb') as file, tqdm(
  73. total=total_size,
  74. unit='iB',
  75. unit_scale=True,
  76. unit_divisor=1024,
  77. initial=first_byte,
  78. ascii=' ==',
  79. bar_format='Downloading [{bar}] {percentage:3.2f}% {rate_fmt}{postfix}',
  80. ) as bar:
  81. for data in response.iter_content(chunk_size=1024):
  82. size = file.write(data)
  83. bar.update(size)
  84. return local_filename
  85. def pull(model_name, *args, **kwargs):
  86. if path.exists(model_name):
  87. # a file on the filesystem is being specified
  88. return model_name
  89. # check the remote model location and see if it needs to be downloaded
  90. url = model_name
  91. file_name = ""
  92. if not validators.url(url) and not url.startswith('huggingface.co'):
  93. url = get_url_from_directory(model_name)
  94. file_name = model_name
  95. if not (url.startswith('http://') or url.startswith('https://')):
  96. url = f'https://{url}'
  97. if not validators.url(url):
  98. if model_name in models(MODELS_CACHE_PATH):
  99. # the model is already downloaded, and specified by name
  100. return model_name
  101. raise Exception(f'Unknown model {model_name}')
  102. local_filename = download_from_repo(url, file_name)
  103. return local_filename