cli.py 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import os
  2. import sys
  3. from argparse import ArgumentParser
  4. from yaspin import yaspin
  5. from ollama import model, engine
  6. from ollama.cmd import server
  7. def main():
  8. parser = ArgumentParser()
  9. # create models home if it doesn't exist
  10. os.makedirs(model.models_home, exist_ok=True)
  11. subparsers = parser.add_subparsers()
  12. server.set_parser(subparsers.add_parser("serve"))
  13. list_parser = subparsers.add_parser("list")
  14. list_parser.set_defaults(fn=list_models)
  15. generate_parser = subparsers.add_parser("generate")
  16. generate_parser.add_argument("model")
  17. generate_parser.add_argument("prompt", nargs="?")
  18. generate_parser.set_defaults(fn=generate)
  19. pull_parser = subparsers.add_parser("pull")
  20. pull_parser.add_argument("model")
  21. pull_parser.set_defaults(fn=pull)
  22. run_parser = subparsers.add_parser("run")
  23. run_parser.add_argument("model")
  24. run_parser.add_argument("prompt", nargs="?")
  25. run_parser.set_defaults(fn=run)
  26. args = parser.parse_args()
  27. args = vars(args)
  28. try:
  29. fn = args.pop("fn")
  30. fn(**args)
  31. except KeyboardInterrupt:
  32. pass
  33. except KeyError:
  34. parser.print_help()
  35. except Exception as e:
  36. print(e)
  37. def list_models(*args, **kwargs):
  38. for m in model.models(*args, **kwargs):
  39. print(m)
  40. def generate(*args, **kwargs):
  41. if prompt := kwargs.get("prompt"):
  42. print(">>>", prompt, flush=True)
  43. generate_oneshot(*args, **kwargs)
  44. return
  45. if sys.stdin.isatty():
  46. return generate_interactive(*args, **kwargs)
  47. return generate_batch(*args, **kwargs)
  48. def generate_oneshot(*args, **kwargs):
  49. print(flush=True)
  50. spinner = yaspin()
  51. spinner.start()
  52. spinner_running = True
  53. try:
  54. for output in engine.generate(*args, **kwargs):
  55. choices = output.get("choices", [])
  56. if len(choices) > 0:
  57. if spinner_running:
  58. spinner.stop()
  59. spinner_running = False
  60. print("\r", end="") # move cursor back to beginning of line again
  61. print(choices[0].get("text", ""), end="", flush=True)
  62. except Exception:
  63. spinner.stop()
  64. raise
  65. # end with a new line
  66. print(flush=True)
  67. print(flush=True)
  68. def generate_interactive(*args, **kwargs):
  69. while True:
  70. print(">>> ", end="", flush=True)
  71. line = next(sys.stdin)
  72. if not line:
  73. return
  74. kwargs.update({"prompt": line})
  75. generate_oneshot(*args, **kwargs)
  76. def generate_batch(*args, **kwargs):
  77. for line in sys.stdin:
  78. print(">>> ", line, end="", flush=True)
  79. kwargs.update({"prompt": line})
  80. generate_oneshot(*args, **kwargs)
  81. def pull(*args, **kwargs):
  82. model.pull(*args, **kwargs)
  83. def run(*args, **kwargs):
  84. name = model.pull(*args, **kwargs)
  85. kwargs.update({"model": name})
  86. print(f"Running {name}...")
  87. generate(*args, **kwargs)