
load from file path

Bruce MacDonald, 1 year ago
commit 843ccf5070
1 file changed, 18 insertions and 12 deletions
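This change routes the prompt through template(model, prompt) and, per the commit title, lets load() fall back to treating the model argument as a path to a model file on disk when the name is not found under models_home. A minimal usage sketch (not part of the commit; the model name and file path below are hypothetical):

    import ollama.engine

    prompt = "Q: Why is the sky blue?"

    # Load by registered model name, resolved via ollama.model.models(models_home).
    for chunk in ollama.engine.generate("llama-7b", prompt):
        print(chunk)

    # After this commit, a direct path to a model file is also accepted.
    for chunk in ollama.engine.generate("./models/llama-7b.bin", prompt):
        print(chunk)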

ollama/engine.py: +18 -12

@@ -3,6 +3,7 @@ import json
 import sys
 from contextlib import contextmanager
 from llama_cpp import Llama as LLM
+from template import template
 
 import ollama.model
 
@@ -10,39 +11,44 @@ import ollama.model
 @contextmanager
 def suppress_stderr():
     stderr = os.dup(sys.stderr.fileno())
-    with open(os.devnull, 'w') as devnull:
+    with open(os.devnull, "w") as devnull:
         os.dup2(devnull.fileno(), sys.stderr.fileno())
         yield
 
     os.dup2(stderr, sys.stderr.fileno())
 
 
-def generate(model, prompt, models_home='.', llms={}, *args, **kwargs):
+def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
     llm = load(model, models_home=models_home, llms=llms)
 
-    if 'max_tokens' not in kwargs:
-        kwargs.update({'max_tokens': 16384})
+    prompt = template(model, prompt)
 
-    if 'stop' not in kwargs:
-        kwargs.update({'stop': ['Q:', '\n']})
+    if "max_tokens" not in kwargs:
+        kwargs.update({"max_tokens": 16384})
 
-    if 'stream' not in kwargs:
-        kwargs.update({'stream': True})
+    if "stop" not in kwargs:
+        kwargs.update({"stop": ["Q:", "\n"]})
+
+    if "stream" not in kwargs:
+        kwargs.update({"stream": True})
 
     for output in llm(prompt, *args, **kwargs):
         yield json.dumps(output)
 
 
-def load(model, models_home='.', llms={}):
+def load(model, models_home=".", llms={}):
     llm = llms.get(model, None)
     if not llm:
         model_path = {
-            name: path
-            for name, path in ollama.model.models(models_home)
+            name: path for name, path in ollama.model.models(models_home)
         }.get(model, None)
 
         if model_path is None:
-            raise ValueError('Model not found')
+            # try loading this as a path to a model, rather than a model name
+            if os.path.isfile(model):
+                model_path = model
+            else:
+                raise ValueError("Model not found")
 
         # suppress LLM's output
         with suppress_stderr():