浏览代码

remove models home param

Bruce MacDonald 1 年之前
父节点
当前提交
a11cddbf99
共有 4 个文件被更改,包括 14 次插入19 次删除
  1. 1 5
      ollama/cmd/cli.py
  2. 1 4
      ollama/cmd/server.py
  3. 4 4
      ollama/engine.py
  4. 8 6
      ollama/model.py

+ 1 - 5
ollama/cmd/cli.py

@@ -1,6 +1,5 @@
 import os
 import os
 import sys
 import sys
-from pathlib import Path
 from argparse import ArgumentParser
 from argparse import ArgumentParser
 from yaspin import yaspin
 from yaspin import yaspin
 
 
@@ -10,12 +9,9 @@ from ollama.cmd import server
 
 
 def main():
 def main():
     parser = ArgumentParser()
     parser = ArgumentParser()
-    parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
 
 
     # create models home if it doesn't exist
     # create models home if it doesn't exist
-    models_home = parser.parse_known_args()[0].models_home
-    if not models_home.exists():
-        os.makedirs(models_home)
+    os.makedirs(model.models_home, exist_ok=True)
 
 
     subparsers = parser.add_subparsers()
     subparsers = parser.add_subparsers()
 
 

+ 1 - 4
ollama/cmd/server.py

@@ -11,7 +11,7 @@ def set_parser(parser):
     parser.set_defaults(fn=serve)
     parser.set_defaults(fn=serve)
 
 
 
 
-def serve(models_home=".", *args, **kwargs):
+def serve(*args, **kwargs):
     app = web.Application()
     app = web.Application()
 
 
     cors = aiohttp_cors.setup(
     cors = aiohttp_cors.setup(
@@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs):
     app.update(
     app.update(
         {
         {
             "llms": {},
             "llms": {},
-            "models_home": models_home,
         }
         }
     )
     )
 
 
@@ -54,7 +53,6 @@ async def load(request):
 
 
     kwargs = {
     kwargs = {
         "llms": request.app.get("llms"),
         "llms": request.app.get("llms"),
-        "models_home": request.app.get("models_home"),
     }
     }
 
 
     engine.load(model, **kwargs)
     engine.load(model, **kwargs)
@@ -86,7 +84,6 @@ async def generate(request):
 
 
     kwargs = {
     kwargs = {
         "llms": request.app.get("llms"),
         "llms": request.app.get("llms"),
-        "models_home": request.app.get("models_home"),
     }
     }
 
 
     for output in engine.generate(model, prompt, **kwargs):
     for output in engine.generate(model, prompt, **kwargs):

+ 4 - 4
ollama/engine.py

@@ -18,8 +18,8 @@ def suppress_stderr():
     os.dup2(stderr, sys.stderr.fileno())
     os.dup2(stderr, sys.stderr.fileno())
 
 
 
 
-def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
-    llm = load(model, models_home=models_home, llms=llms)
+def generate(model, prompt, llms={}, *args, **kwargs):
+    llm = load(model, llms=llms)
 
 
     prompt = ollama.prompt.template(model, prompt)
     prompt = ollama.prompt.template(model, prompt)
     if "max_tokens" not in kwargs:
     if "max_tokens" not in kwargs:
@@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
         yield output
         yield output
 
 
 
 
-def load(model, models_home=".", llms={}):
+def load(model, llms={}):
     llm = llms.get(model, None)
     llm = llms.get(model, None)
     if not llm:
     if not llm:
-        stored_model_path = path.join(models_home, model) + ".bin"
+        stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
         if path.exists(stored_model_path):
         if path.exists(stored_model_path):
             model_path = stored_model_path
             model_path = stored_model_path
         else:
         else:

+ 8 - 6
ollama/model.py

@@ -1,14 +1,16 @@
 import requests
 import requests
 import validators
 import validators
+from pathlib import Path
 from os import path, walk
 from os import path, walk
 from urllib.parse import urlsplit, urlunsplit
 from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm
 from tqdm import tqdm
 
 
 
 
 models_endpoint_url = 'https://ollama.ai/api/models'
 models_endpoint_url = 'https://ollama.ai/api/models'
+models_home = path.join(Path.home(), '.ollama', 'models')
 
 
 
 
-def models(models_home='.', *args, **kwargs):
+def models(*args, **kwargs):
     for _, _, files in walk(models_home):
     for _, _, files in walk(models_home):
         for file in files:
         for file in files:
             base, ext = path.splitext(file)
             base, ext = path.splitext(file)
@@ -27,7 +29,7 @@ def get_url_from_directory(model):
     return model
     return model
 
 
 
 
-def download_from_repo(url, file_name, models_home='.'):
+def download_from_repo(url, file_name):
     parts = urlsplit(url)
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')
     path_parts = parts.path.split('/tree/')
 
 
@@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'):
     json_response = response.json()
     json_response = response.json()
 
 
     download_url, file_size = find_bin_file(json_response, location, branch)
     download_url, file_size = find_bin_file(json_response, location, branch)
-    return download_file(download_url, models_home, file_name, file_size)
+    return download_file(download_url, file_name, file_size)
 
 
 
 
 def find_bin_file(json_response, location, branch):
 def find_bin_file(json_response, location, branch):
@@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch):
     return download_url, file_size
     return download_url, file_size
 
 
 
 
-def download_file(download_url, models_home, file_name, file_size):
+def download_file(download_url, file_name, file_size):
     local_filename = path.join(models_home, file_name) + '.bin'
     local_filename = path.join(models_home, file_name) + '.bin'
 
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
@@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size):
     return local_filename
     return local_filename
 
 
 
 
-def pull(model, models_home='.', *args, **kwargs):
+def pull(model, *args, **kwargs):
     if path.exists(model):
     if path.exists(model):
         # a file on the filesystem is being specified
         # a file on the filesystem is being specified
         return model
         return model
@@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs):
             return model
             return model
         raise Exception(f'Unknown model {model}')
         raise Exception(f'Unknown model {model}')
 
 
-    local_filename = download_from_repo(url, file_name, models_home)
+    local_filename = download_from_repo(url, file_name)
 
 
     return local_filename
     return local_filename