浏览代码

consistency between generate and add naming

Bruce MacDonald 1 年之前
父节点
当前提交
01c31aac78
共有 3 个文件被更改，包括 38 次插入和 30 次删除
  1. +12 -8
      ollama/cmd/cli.py
  2. +11 -9
      ollama/engine.py
  3. +15 -13
      ollama/model.py

+ 12 - 8
ollama/cmd/cli.py

@@ -79,14 +79,18 @@ def generate_oneshot(*args, **kwargs):
     spinner = yaspin()
     spinner = yaspin()
     spinner.start()
     spinner.start()
     spinner_running = True
     spinner_running = True
-    for output in engine.generate(*args, **kwargs):
-        choices = output.get("choices", [])
-        if len(choices) > 0:
-            if spinner_running:
-                spinner.stop()
-                spinner_running = False
-                print("\r", end="")  # move cursor back to beginning of line again
-            print(choices[0].get("text", ""), end="", flush=True)
+    try:
+        for output in engine.generate(*args, **kwargs):
+            choices = output.get("choices", [])
+            if len(choices) > 0:
+                if spinner_running:
+                    spinner.stop()
+                    spinner_running = False
+                    print("\r", end="")  # move cursor back to beginning of line again
+                print(choices[0].get("text", ""), end="", flush=True)
+    except Exception:
+        spinner.stop()
+        raise
 
 
     # end with a new line
     # end with a new line
     print(flush=True)
     print(flush=True)

+ 11 - 9
ollama/engine.py

@@ -1,5 +1,4 @@
-import os
-import json
+from os import path, dup, dup2, devnull
 import sys
 import sys
 from contextlib import contextmanager
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM
 from llama_cpp import Llama as LLM
@@ -10,12 +9,12 @@ import ollama.prompt
 
 
 @contextmanager
 @contextmanager
 def suppress_stderr():
 def suppress_stderr():
-    stderr = os.dup(sys.stderr.fileno())
-    with open(os.devnull, "w") as devnull:
-        os.dup2(devnull.fileno(), sys.stderr.fileno())
+    stderr = dup(sys.stderr.fileno())
+    with open(devnull, "w") as devnull:
+        dup2(devnull.fileno(), sys.stderr.fileno())
         yield
         yield
 
 
-    os.dup2(stderr, sys.stderr.fileno())
+    dup2(stderr, sys.stderr.fileno())
 
 
 
 
 def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
 def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
@@ -38,12 +37,15 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
 def load(model, models_home=".", llms={}):
 def load(model, models_home=".", llms={}):
     llm = llms.get(model, None)
     llm = llms.get(model, None)
     if not llm:
     if not llm:
-        stored_model_path = os.path.join(models_home, model, ".bin")
-        if os.path.exists(stored_model_path):
+        stored_model_path = path.join(models_home, model) + ".bin"
+        if path.exists(stored_model_path):
             model_path = stored_model_path
             model_path = stored_model_path
         else:
         else:
             # try loading this as a path to a model, rather than a model name
             # try loading this as a path to a model, rather than a model name
-            model_path = os.path.abspath(model)
+            model_path = path.abspath(model)
+
+        if not path.exists(model_path):
+            raise Exception(f"Model not found: {model}")
 
 
         try:
         try:
             # suppress LLM's output
             # suppress LLM's output

+ 15 - 13
ollama/model.py

@@ -1,6 +1,6 @@
-import os
 import requests
 import requests
 import validators
 import validators
+from os import path, walk
 from urllib.parse import urlsplit, urlunsplit
 from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm
 from tqdm import tqdm
 
 
@@ -9,9 +9,9 @@ models_endpoint_url = 'https://ollama.ai/api/models'
 
 
 
 
 def models(models_home='.', *args, **kwargs):
 def models(models_home='.', *args, **kwargs):
-    for _, _, files in os.walk(models_home):
+    for _, _, files in walk(models_home):
         for file in files:
         for file in files:
-            base, ext = os.path.splitext(file)
+            base, ext = path.splitext(file)
             if ext == '.bin':
             if ext == '.bin':
                 yield base
                 yield base
 
 
@@ -27,7 +27,7 @@ def get_url_from_directory(model):
     return model
     return model
 
 
 
 
-def download_from_repo(url, models_home='.'):
+def download_from_repo(url, file_name, models_home='.'):
     parts = urlsplit(url)
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')
     path_parts = parts.path.split('/tree/')
 
 
@@ -38,6 +38,8 @@ def download_from_repo(url, models_home='.'):
         location, branch = path_parts
         location, branch = path_parts
 
 
     location = location.strip('/')
     location = location.strip('/')
+    if file_name == '':
+        file_name = path.basename(location)
 
 
     download_url = urlunsplit(
     download_url = urlunsplit(
         (
         (
@@ -53,7 +55,7 @@ def download_from_repo(url, models_home='.'):
     json_response = response.json()
     json_response = response.json()
 
 
     download_url, file_size = find_bin_file(json_response, location, branch)
     download_url, file_size = find_bin_file(json_response, location, branch)
-    return download_file(download_url, models_home, location, file_size)
+    return download_file(download_url, models_home, file_name, file_size)
 
 
 
 
 def find_bin_file(json_response, location, branch):
 def find_bin_file(json_response, location, branch):
@@ -73,17 +75,15 @@ def find_bin_file(json_response, location, branch):
     return download_url, file_size
     return download_url, file_size
 
 
 
 
-def download_file(download_url, models_home, location, file_size):
-    local_filename = os.path.join(models_home, os.path.basename(location)) + '.bin'
+def download_file(download_url, models_home, file_name, file_size):
+    local_filename = path.join(models_home, file_name) + '.bin'
 
 
-    first_byte = (
-        os.path.getsize(local_filename) if os.path.exists(local_filename) else 0
-    )
+    first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
 
 
     if first_byte >= file_size:
     if first_byte >= file_size:
         return local_filename
         return local_filename
 
 
-    print(f'Pulling {os.path.basename(location)}...')
+    print(f'Pulling {file_name}...')
 
 
     header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
     header = {'Range': f'bytes={first_byte}-'} if first_byte != 0 else {}
 
 
@@ -109,13 +109,15 @@ def download_file(download_url, models_home, location, file_size):
 
 
 
 
 def pull(model, models_home='.', *args, **kwargs):
 def pull(model, models_home='.', *args, **kwargs):
-    if os.path.exists(model):
+    if path.exists(model):
         # a file on the filesystem is being specified
         # a file on the filesystem is being specified
         return model
         return model
     # check the remote model location and see if it needs to be downloaded
     # check the remote model location and see if it needs to be downloaded
     url = model
     url = model
+    file_name = ""
     if not validators.url(url) and not url.startswith('huggingface.co'):
     if not validators.url(url) and not url.startswith('huggingface.co'):
         url = get_url_from_directory(model)
         url = get_url_from_directory(model)
+        file_name = model
 
 
     if not (url.startswith('http://') or url.startswith('https://')):
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
         url = f'https://{url}'
@@ -126,6 +128,6 @@ def pull(model, models_home='.', *args, **kwargs):
             return model
             return model
         raise Exception(f'Unknown model {model}')
         raise Exception(f'Unknown model {model}')
 
 
-    local_filename = download_from_repo(url, models_home)
+    local_filename = download_from_repo(url, file_name, models_home)
 
 
     return local_filename
     return local_filename