Procházet zdrojové kódy

remove models home param

Bruce MacDonald před 1 rokem
rodič
revize
a11cddbf99
4 změnil soubory, kde provedl 14 přidání a 19 odebrání
  1. 1 5
      ollama/cmd/cli.py
  2. 1 4
      ollama/cmd/server.py
  3. 4 4
      ollama/engine.py
  4. 8 6
      ollama/model.py

+ 1 - 5
ollama/cmd/cli.py

@@ -1,6 +1,5 @@
 import os
 import sys
-from pathlib import Path
 from argparse import ArgumentParser
 from yaspin import yaspin
 
@@ -10,12 +9,9 @@ from ollama.cmd import server
 
 def main():
     parser = ArgumentParser()
-    parser.add_argument("--models-home", default=Path.home() / ".ollama" / "models")
 
     # create models home if it doesn't exist
-    models_home = parser.parse_known_args()[0].models_home
-    if not models_home.exists():
-        os.makedirs(models_home)
+    os.makedirs(model.models_home, exist_ok=True)
 
     subparsers = parser.add_subparsers()
 

+ 1 - 4
ollama/cmd/server.py

@@ -11,7 +11,7 @@ def set_parser(parser):
     parser.set_defaults(fn=serve)
 
 
-def serve(models_home=".", *args, **kwargs):
+def serve(*args, **kwargs):
     app = web.Application()
 
     cors = aiohttp_cors.setup(
@@ -39,7 +39,6 @@ def serve(models_home=".", *args, **kwargs):
     app.update(
         {
             "llms": {},
-            "models_home": models_home,
         }
     )
 
@@ -54,7 +53,6 @@ async def load(request):
 
     kwargs = {
         "llms": request.app.get("llms"),
-        "models_home": request.app.get("models_home"),
     }
 
     engine.load(model, **kwargs)
@@ -86,7 +84,6 @@ async def generate(request):
 
     kwargs = {
         "llms": request.app.get("llms"),
-        "models_home": request.app.get("models_home"),
     }
 
     for output in engine.generate(model, prompt, **kwargs):

+ 4 - 4
ollama/engine.py

@@ -18,8 +18,8 @@ def suppress_stderr():
     os.dup2(stderr, sys.stderr.fileno())
 
 
-def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
-    llm = load(model, models_home=models_home, llms=llms)
+def generate(model, prompt, llms={}, *args, **kwargs):
+    llm = load(model, llms=llms)
 
     prompt = ollama.prompt.template(model, prompt)
     if "max_tokens" not in kwargs:
@@ -35,10 +35,10 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
         yield output
 
 
-def load(model, models_home=".", llms={}):
+def load(model, llms={}):
     llm = llms.get(model, None)
     if not llm:
-        stored_model_path = path.join(models_home, model) + ".bin"
+        stored_model_path = path.join(ollama.model.models_home, model) + ".bin"
         if path.exists(stored_model_path):
             model_path = stored_model_path
         else:

+ 8 - 6
ollama/model.py

@@ -1,14 +1,16 @@
 import requests
 import validators
+from pathlib import Path
 from os import path, walk
 from urllib.parse import urlsplit, urlunsplit
 from tqdm import tqdm
 
 
 models_endpoint_url = 'https://ollama.ai/api/models'
+models_home = path.join(Path.home(), '.ollama', 'models')
 
 
-def models(models_home='.', *args, **kwargs):
+def models(*args, **kwargs):
     for _, _, files in walk(models_home):
         for file in files:
             base, ext = path.splitext(file)
@@ -27,7 +29,7 @@ def get_url_from_directory(model):
     return model
 
 
-def download_from_repo(url, file_name, models_home='.'):
+def download_from_repo(url, file_name):
     parts = urlsplit(url)
     path_parts = parts.path.split('/tree/')
 
@@ -55,7 +57,7 @@ def download_from_repo(url, file_name, models_home='.'):
     json_response = response.json()
 
     download_url, file_size = find_bin_file(json_response, location, branch)
-    return download_file(download_url, models_home, file_name, file_size)
+    return download_file(download_url, file_name, file_size)
 
 
 def find_bin_file(json_response, location, branch):
@@ -75,7 +77,7 @@ def find_bin_file(json_response, location, branch):
     return download_url, file_size
 
 
-def download_file(download_url, models_home, file_name, file_size):
+def download_file(download_url, file_name, file_size):
     local_filename = path.join(models_home, file_name) + '.bin'
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
@@ -108,7 +110,7 @@ def download_file(download_url, models_home, file_name, file_size):
     return local_filename
 
 
-def pull(model, models_home='.', *args, **kwargs):
+def pull(model, *args, **kwargs):
     if path.exists(model):
         # a file on the filesystem is being specified
         return model
@@ -128,6 +130,6 @@ def pull(model, models_home='.', *args, **kwargs):
             return model
         raise Exception(f'Unknown model {model}')
 
-    local_filename = download_from_repo(url, file_name, models_home)
+    local_filename = download_from_repo(url, file_name)
 
     return local_filename