Переглянути джерело

Merge pull request #29 from jmorganca/pull-model-name

Pull model name
Michael Yang 1 рік тому
батько
коміт
9e5dfc66a3
3 змінених файлів з 13 додано та 13 видалено
  1. 1 1
      ollama/cmd/cli.py
  2. 1 1
      ollama/engine.py
  3. 11 11
      ollama/model.py

+ 1 - 1
ollama/cmd/cli.py

@@ -151,7 +151,7 @@ def pull(*args, **kwargs):
 
 
 def run(*args, **kwargs):
-    name = model.pull(*args, **kwargs)
+    name = model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
     kwargs.update({"model": name})
     print(f"Running {name}...")
     generate(*args, **kwargs)

+ 1 - 1
ollama/engine.py

@@ -30,7 +30,7 @@ def load(model_name, models={}):
     if not models.get(model_name, None):
         model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            model_path = path.join(MODELS_CACHE_PATH, model_name + ".bin")
+            model_path = MODELS_CACHE_PATH / model_name + ".bin"
 
         runners = {
             model_type: cls

+ 11 - 11
ollama/model.py

@@ -7,7 +7,7 @@ from tqdm import tqdm
 
 
 MODELS_MANIFEST = 'https://ollama.ai/api/models'
-MODELS_CACHE_PATH = path.join(Path.home(), '.ollama', 'models')
+MODELS_CACHE_PATH = Path.home() / '.ollama' / 'models'
 
 
 def models(*args, **kwargs):
@@ -78,7 +78,7 @@ def find_bin_file(json_response, location, branch):
 
 
 def download_file(download_url, file_name, file_size):
-    local_filename = path.join(MODELS_CACHE_PATH, file_name) + '.bin'
+    local_filename = MODELS_CACHE_PATH / file_name + '.bin'
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
 
@@ -110,25 +110,25 @@ def download_file(download_url, file_name, file_size):
     return local_filename
 
 
-def pull(model, *args, **kwargs):
-    if path.exists(model):
+def pull(model_name, *args, **kwargs):
+    if path.exists(model_name):
         # a file on the filesystem is being specified
-        return model
+        return model_name
     # check the remote model location and see if it needs to be downloaded
-    url = model
+    url = model_name
     file_name = ""
     if not validators.url(url) and not url.startswith('huggingface.co'):
-        url = get_url_from_directory(model)
-        file_name = model
+        url = get_url_from_directory(model_name)
+        file_name = model_name
 
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
 
     if not validators.url(url):
-        if model in models(MODELS_CACHE_PATH):
+        if model_name in models(MODELS_CACHE_PATH):
             # the model is already downloaded, and specified by name
-            return model
-        raise Exception(f'Unknown model {model}')
+            return model_name
+        raise Exception(f'Unknown model {model_name}')
 
     local_filename = download_from_repo(url, file_name)