Merge pull request #11 from jmorganca/interactive-generate

interactive generate
Michael Yang 1 year ago
parent commit 5610405e77
2 changed files with 32 additions and 4 deletions
  1. ollama/cmd/cli.py (+31 -3)
  2. ollama/engine.py (+1 -1)

ollama/cmd/cli.py (+31 -3)

@@ -1,4 +1,5 @@
 import os
+import sys
 import json
 from pathlib import Path
 from argparse import ArgumentParser
@@ -20,7 +21,7 @@ def main():
 
     generate_parser = subparsers.add_parser("generate")
     generate_parser.add_argument("model")
-    generate_parser.add_argument("prompt")
+    generate_parser.add_argument("prompt", nargs="?")
     generate_parser.set_defaults(fn=generate)
 
     add_parser = subparsers.add_parser("add")
@@ -37,6 +38,8 @@ def main():
     try:
         fn = args.pop("fn")
         fn(**args)
+    except KeyboardInterrupt:
+        pass
     except KeyError:
         parser.print_help()
     except Exception as e:
@@ -49,12 +52,37 @@ def list_models(*args, **kwargs):
 
 
 def generate(*args, **kwargs):
+    if prompt := kwargs.get('prompt'):
+        print('>>>', prompt, flush=True)
+        print(flush=True)
+        generate_oneshot(*args, **kwargs)
+        print(flush=True)
+        return
+
+    return generate_interactive(*args, **kwargs)
+
+
+def generate_oneshot(*args, **kwargs):
     for output in engine.generate(*args, **kwargs):
         output = json.loads(output)
-
         choices = output.get("choices", [])
         if len(choices) > 0:
-            print(choices[0].get("text", ""), end="")
+            print(choices[0].get("text", ""), end="", flush=True)
+
+    print()
+
+
+def generate_interactive(*args, **kwargs):
+    print('>>> ', end='', flush=True)
+    for line in sys.stdin:
+        if not sys.stdin.isatty():
+            print(line, end='')
+
+        print(flush=True)
+        kwargs.update({'prompt': line})
+        generate_oneshot(*args, **kwargs)
+        print(flush=True)
+        print('>>> ', end='', flush=True)
 
 
 def add(model, models_home):

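Taken together, the cli.py changes make the prompt argument optional: when a prompt is given, generate() echoes it behind a ">>> " marker and streams a single one-shot completion; when it is omitted, generate_interactive() turns stdin into a simple read-generate-print loop. Below is a minimal, self-contained sketch of that loop shape, with stub_generate() standing in for engine.generate() (the stub and its output are illustrative, not the module's API):

import sys

def stub_generate(prompt):
    # Illustrative stand-in for engine.generate(): yields streamed text chunks.
    yield f"(model output for: {prompt.strip()})"

def interactive_loop(marker=">>> "):
    # Same shape as generate_interactive() in the diff: print the marker,
    # read a line, run a one-shot generation, then reprint the marker.
    print(marker, end="", flush=True)
    for line in sys.stdin:
        if not sys.stdin.isatty():
            # Input is piped rather than typed: echo the line so the
            # transcript still shows what was asked.
            print(line, end="")
        for chunk in stub_generate(line):
            print(chunk, end="", flush=True)
        print()
        print(marker, end="", flush=True)

if __name__ == "__main__":
    interactive_loop()

Run with a terminal attached, this behaves like the new generate subcommand invoked without a prompt; with piped input, each line is echoed back first because stdin is not a TTY, which keeps non-interactive transcripts readable.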
ollama/engine.py (+1 -1)

@@ -27,7 +27,7 @@ def generate(model, prompt, models_home=".", llms={}, *args, **kwargs):
         kwargs.update({"max_tokens": 16384})
 
     if "stop" not in kwargs:
-        kwargs.update({"stop": ["Q:", "\n"]})
+        kwargs.update({"stop": ["Q:"]})
 
     if "stream" not in kwargs:
         kwargs.update({"stream": True})
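On the engine side, removing "\n" from the default stop sequences means a completion is no longer cut off at the first newline, which the new interactive mode needs for multi-line answers; only "Q:" still terminates generation by default. A small sketch of the default-merging pattern engine.generate() uses (apply_generate_defaults is an illustrative helper, not part of the module):

def apply_generate_defaults(**kwargs):
    # Illustrative helper mirroring engine.generate(): setdefault() only
    # fills a key the caller did not supply, equivalent to the
    # `if key not in kwargs: kwargs.update(...)` form in the diff.
    kwargs.setdefault("max_tokens", 16384)
    kwargs.setdefault("stop", ["Q:"])  # "\n" dropped so replies can span lines
    kwargs.setdefault("stream", True)
    return kwargs

# A caller-supplied value wins over the default:
print(apply_generate_defaults(stop=["###"]))
# -> {'stop': ['###'], 'max_tokens': 16384, 'stream': True}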