浏览代码

use all caps for constants

Michael Yang 1 年之前
父节点
当前提交
07d8d56177
共有 3 个文件被更改,包括 9 次插入9 次删除
  1. 1 1
      ollama/cmd/cli.py
  2. 2 2
      ollama/engine.py
  3. 6 6
      ollama/model.py

+ 1 - 1
ollama/cmd/cli.py

@@ -31,7 +31,7 @@ def main():
     )
 
     # create models home if it doesn't exist
-    os.makedirs(model.models_home, exist_ok=True)
+    os.makedirs(model.MODELS_CACHE_PATH, exist_ok=True)
 
     subparsers = parser.add_subparsers(
         title='commands',

+ 2 - 2
ollama/engine.py

@@ -7,7 +7,7 @@ from llama_cpp import Llama
 from ctransformers import AutoModelForCausalLM
 
 import ollama.prompt
-from ollama.model import models_home
+from ollama.model import MODELS_CACHE_PATH
 
 
 @contextmanager
@@ -30,7 +30,7 @@ def load(model_name, models={}):
     if not models.get(model_name, None):
         model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            model_path = path.join(models_home, model_name + ".bin")
+            model_path = path.join(MODELS_CACHE_PATH, model_name + ".bin")
 
         runners = {
             model_type: cls

+ 6 - 6
ollama/model.py

@@ -6,12 +6,12 @@ from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm
 
 
-models_endpoint_url = 'https://ollama.ai/api/models'
-models_home = path.join(Path.home(), '.ollama', 'models')
+MODELS_MANIFEST = 'https://ollama.ai/api/models'
+MODELS_CACHE_PATH = path.join(Path.home(), '.ollama', 'models')
 
 
 def models(*args, **kwargs):
-    for _, _, files in walk(models_home):
+    for _, _, files in walk(MODELS_CACHE_PATH):
         for file in files:
             base, ext = path.splitext(file)
             if ext == '.bin':
@@ -20,7 +20,7 @@ def models(*args, **kwargs):
 
 # get the url of the model from our curated directory
 def get_url_from_directory(model):
-    response = requests.get(models_endpoint_url)
+    response = requests.get(MODELS_MANIFEST)
     response.raise_for_status()
     directory = response.json()
     for model_info in directory:
@@ -78,7 +78,7 @@ def find_bin_file(json_response, location, branch):
 
 
 def download_file(download_url, file_name, file_size):
-    local_filename = path.join(models_home, file_name) + '.bin'
+    local_filename = path.join(MODELS_CACHE_PATH, file_name) + '.bin'
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
 
@@ -125,7 +125,7 @@ def pull(model, *args, **kwargs):
         url = f'https://{url}'
 
     if not validators.url(url):
-        if model in models(models_home):
+        if model in models(MODELS_CACHE_PATH):
             # the model is already downloaded, and specified by name
             return model
         raise Exception(f'Unknown model {model}')