Browse Source

use difflib.get_close_matches

Michael Yang 1 year ago
parent
commit
ba3cbbbb4c
1 changed files with 3 additions and 13 deletions
  1. 3 13
      ollama/prompt.py

+ 3 - 13
ollama/prompt.py

@@ -1,19 +1,9 @@
-from os import path
-from difflib import SequenceMatcher
+from difflib import get_close_matches
 from jinja2 import Environment, PackageLoader
 from jinja2 import Environment, PackageLoader
 
 
 
 
 def template(name, prompt):
 def template(name, prompt):
-    best_ratio = 0
-    best_template = ''
-
     environment = Environment(loader=PackageLoader(__name__, 'templates'))
     environment = Environment(loader=PackageLoader(__name__, 'templates'))
-    for template in environment.list_templates():
-        base, _ = path.splitext(template)
-        ratio = SequenceMatcher(None, path.basename(name).lower(), base).ratio()
-        if ratio > best_ratio:
-            best_ratio = ratio
-            best_template = template
-
-    template = environment.get_template(best_template)
+    best_templates = get_close_matches(name, environment.list_templates(), n=1, cutoff=0)
+    template = environment.get_template(best_templates.pop())
     return template.render(prompt=prompt)
     return template.render(prompt=prompt)