瀏覽代碼

search command

Bruce MacDonald 1 年之前
父節點
當前提交
d01be075b6
共有 5 個文件被更改,包括 62 次插入24 次删除
  1. 0 2
      README.md
  2. 36 8
      ollama/cmd/cli.py
  3. 5 8
      ollama/engine.py
  4. 17 5
      ollama/model.py
  5. 4 1
      ollama/prompt.py

+ 0 - 2
README.md

@@ -87,8 +87,6 @@ Download a model
 ollama.pull("huggingface.co/thebloke/llama-7b-ggml")
 ```
 
-## Coming Soon
-
 ### `ollama.search("query")`
 
 Search for compatible models that Ollama can run

+ 36 - 8
ollama/cmd/cli.py

@@ -37,14 +37,6 @@ def main():
         title='commands',
     )
 
-    server.set_parser(
-        subparsers.add_parser(
-            "serve",
-            description="Start a persistent server to interact with models via the API.",
-            help="Start a persistent server to interact with models via the API.",
-        )
-    )
-
     list_parser = subparsers.add_parser(
         "models",
         description="List all available models stored locally.",
@@ -52,6 +44,18 @@ def main():
     )
     list_parser.set_defaults(fn=list_models)
 
+    search_parser = subparsers.add_parser(
+        "search",
+        description="Search for compatible models that Ollama can run.",
+        help="Search for compatible models that Ollama can run. Usage: search [model]",
+    )
+    search_parser.add_argument(
+        "query",
+        nargs="?",
+        help="Optional name of the model to search for.",
+    )
+    search_parser.set_defaults(fn=search)
+
     pull_parser = subparsers.add_parser(
         "pull",
         description="Download a specified model from a remote source.",
@@ -73,6 +77,14 @@ def main():
     )
     run_parser.set_defaults(fn=run)
 
+    server.set_parser(
+        subparsers.add_parser(
+            "serve",
+            description="Start a persistent server to interact with models via the API.",
+            help="Start a persistent server to interact with models via the API.",
+        )
+    )
+
     args = parser.parse_args()
     args = vars(args)
 
@@ -146,6 +158,22 @@ def generate_batch(*args, **kwargs):
         generate_oneshot(*args, **kwargs)
 
 
+def search(*args, **kwargs):
+    try:
+        model_names = model.search_directory(*args, **kwargs)
+        if len(model_names) == 0:
+            print("No models found.")
+            return
+        elif len(model_names) == 1:
+            print(f"Found {len(model_names)} available model:")
+        else:
+            print(f"Found {len(model_names)} available models:")
+        for model_name in model_names:
+            print(model_name.lower())
+    except Exception as e:
+        print("Failed to fetch available models, check your network connection")
+
+
 def pull(*args, **kwargs):
     model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
 

+ 5 - 8
ollama/engine.py

@@ -1,6 +1,7 @@
 import os
 import sys
 from os import path
+from pathlib import Path
 from contextlib import contextmanager
 from fuzzywuzzy import process
 from llama_cpp import Llama
@@ -30,7 +31,7 @@ def load(model_name, models={}):
     if not models.get(model_name, None):
         model_path = path.expanduser(model_name)
         if not path.exists(model_path):
-            model_path = MODELS_CACHE_PATH / model_name + ".bin"
+            model_path = str(MODELS_CACHE_PATH / (model_name + ".bin"))
 
         runners = {
             model_type: cls
@@ -52,14 +53,10 @@ def unload(model_name, models={}):
 
 
 class LlamaCppRunner:
-
     def __init__(self, model_path, model_type):
         try:
             with suppress(sys.stderr), suppress(sys.stdout):
-                self.model = Llama(model_path,
-                                   verbose=False,
-                                   n_gpu_layers=1,
-                                   seed=-1)
+                self.model = Llama(model_path, verbose=False, n_gpu_layers=1, seed=-1)
         except Exception:
             raise Exception("Failed to load model", model_path, model_type)
 
@@ -88,10 +85,10 @@ class LlamaCppRunner:
 
 
 class CtransformerRunner:
-
     def __init__(self, model_path, model_type):
         self.model = AutoModelForCausalLM.from_pretrained(
-            model_path, model_type=model_type, local_files_only=True)
+            model_path, model_type=model_type, local_files_only=True
+        )
 
     @staticmethod
     def model_types():

+ 17 - 5
ollama/model.py

@@ -18,13 +18,26 @@ def models(*args, **kwargs):
                 yield base
 
 
+# search the directory and return all models which contain the search term as a substring,
+# or all models if no search term is provided
+def search_directory(query):
+    response = requests.get(MODELS_MANIFEST)
+    response.raise_for_status()
+    directory = response.json()
+    model_names = []
+    for model_info in directory:
+        if not query or query.lower() in model_info.get('name', '').lower():
+            model_names.append(model_info.get('name'))
+    return model_names
+
+
 # get the url of the model from our curated directory
 def get_url_from_directory(model):
     response = requests.get(MODELS_MANIFEST)
     response.raise_for_status()
     directory = response.json()
     for model_info in directory:
-        if model_info.get('name') == model:
+        if model_info.get('name').lower() == model.lower():
             return model_info.get('url')
     return model
 
@@ -42,7 +55,6 @@ def download_from_repo(url, file_name):
     location = location.strip('/')
     if file_name == '':
         file_name = path.basename(location).lower()
-
     download_url = urlunsplit(
         (
             'https',
@@ -78,7 +90,7 @@ def find_bin_file(json_response, location, branch):
 
 
 def download_file(download_url, file_name, file_size):
-    local_filename = MODELS_CACHE_PATH / file_name + '.bin'
+    local_filename = MODELS_CACHE_PATH / str(file_name + '.bin')
 
     first_byte = path.getsize(local_filename) if path.exists(local_filename) else 0
 
@@ -111,7 +123,8 @@ def download_file(download_url, file_name, file_size):
 
 
 def pull(model_name, *args, **kwargs):
-    if path.exists(model_name):
+    maybe_existing_model_location = MODELS_CACHE_PATH / str(model_name + '.bin')
+    if path.exists(model_name) or path.exists(maybe_existing_model_location):
         # a file on the filesystem is being specified
         return model_name
     # check the remote model location and see if it needs to be downloaded
@@ -120,7 +133,6 @@ def pull(model_name, *args, **kwargs):
     if not validators.url(url) and not url.startswith('huggingface.co'):
         url = get_url_from_directory(model_name)
         file_name = model_name
-
     if not (url.startswith('http://') or url.startswith('https://')):
         url = f'https://{url}'
 

+ 4 - 1
ollama/prompt.py

@@ -1,9 +1,12 @@
+from os import path
 from difflib import get_close_matches
 from jinja2 import Environment, PackageLoader
 
 
 def template(name, prompt):
     environment = Environment(loader=PackageLoader(__name__, 'templates'))
-    best_templates = get_close_matches(name, environment.list_templates(), n=1, cutoff=0)
+    best_templates = get_close_matches(
+        path.basename(name), environment.list_templates(), n=1, cutoff=0
+    )
     template = environment.get_template(best_templates.pop())
     return template.render(prompt=prompt)