Browse Source

fix pull model name

Michael Yang 1 year ago
parent
commit
3ce3caaf65
2 changed files with 10 additions and 10 deletions
  1. 1 1
      ollama/cmd/cli.py
  2. 9 9
      ollama/model.py

+ 1 - 1
ollama/cmd/cli.py

@@ -151,7 +151,7 @@ def pull(*args, **kwargs):
 
 
 
 
 def run(*args, **kwargs):
 def run(*args, **kwargs):
-    name = model.pull(*args, **kwargs)
+    name = model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
     kwargs.update({"model": name})
     kwargs.update({"model": name})
     print(f"Running {name}...")
     print(f"Running {name}...")
     generate(*args, **kwargs)
     generate(*args, **kwargs)

+ 9 - 9
ollama/model.py

@@ -110,25 +110,25 @@ def download_file(download_url, file_name, file_size):
     return local_filename
     return local_filename
 
 
 
 
-def pull(model, *args, **kwargs):
-    if path.exists(model):
+def pull(model_name, *args, **kwargs):
+    if path.exists(model_name):
         # a file on the filesystem is being specified
         # a file on the filesystem is being specified
-        return model
+        return model_name
     # check the remote model location and see if it needs to be downloaded
     # check the remote model location and see if it needs to be downloaded
-    url = model
+    url = model_name
     file_name = ""
     file_name = ""
     if not validators.url(url) and not url.startswith('huggingface.co'):
     if not validators.url(url) and not url.startswith('huggingface.co'):
-        url = get_url_from_directory(model)
-        file_name = model
+        url = get_url_from_directory(model_name)
+        file_name = model_name
 
 
     if not (url.startswith('http://') or url.startswith('https://')):
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
         url = f'https://{url}'
 
 
     if not validators.url(url):
     if not validators.url(url):
-        if model in models(MODELS_CACHE_PATH):
+        if model_name in models(MODELS_CACHE_PATH):
             # the model is already downloaded, and specified by name
             # the model is already downloaded, and specified by name
-            return model
-        raise Exception(f'Unknown model {model}')
+            return model_name
+        raise Exception(f'Unknown model {model_name}')
 
 
     local_filename = download_from_repo(url, file_name)
     local_filename = download_from_repo(url, file_name)