# cli.py
  1. import os
  2. import sys
  3. from argparse import ArgumentParser, HelpFormatter, PARSER
  4. from yaspin import yaspin
  5. from ollama import model, engine
  6. from ollama.cmd import server
  7. class CustomHelpFormatter(HelpFormatter):
  8. """
  9. This class is used to customize the way the argparse help text is displayed.
  10. We specifically override the _format_action method to exclude the line that
  11. shows all the subparser command options in the help text. This line is typically
  12. in the form "{serve,models,pull,run}".
  13. """
  14. def _format_action(self, action):
  15. # get the original help text
  16. parts = super()._format_action(action)
  17. if action.nargs == PARSER:
  18. # remove the unwanted first line
  19. parts = "\n".join(parts.split("\n")[1:])
  20. return parts
  21. def main():
  22. parser = ArgumentParser(
  23. description='Ollama: Run any large language model on any machine.',
  24. formatter_class=CustomHelpFormatter,
  25. )
  26. # create models home if it doesn't exist
  27. os.makedirs(model.MODELS_CACHE_PATH, exist_ok=True)
  28. subparsers = parser.add_subparsers(
  29. title='commands',
  30. )
  31. list_parser = subparsers.add_parser(
  32. "models",
  33. description="List all available models stored locally.",
  34. help="List all available models stored locally.",
  35. )
  36. list_parser.set_defaults(fn=list_models)
  37. search_parser = subparsers.add_parser(
  38. "search",
  39. description="Search for compatible models that Ollama can run.",
  40. help="Search for compatible models that Ollama can run. Usage: search [model]",
  41. )
  42. search_parser.add_argument(
  43. "query",
  44. nargs="?",
  45. help="Optional name of the model to search for.",
  46. )
  47. search_parser.set_defaults(fn=search)
  48. pull_parser = subparsers.add_parser(
  49. "pull",
  50. description="Download a specified model from a remote source.",
  51. help="Download a specified model from a remote source. Usage: pull [model]",
  52. )
  53. pull_parser.add_argument("model", help="Name of the model to download.")
  54. pull_parser.set_defaults(fn=pull)
  55. run_parser = subparsers.add_parser(
  56. "run",
  57. description="Run a model and submit prompts.",
  58. help="Run a model and submit prompts. Usage: run [model] [prompt]",
  59. )
  60. run_parser.add_argument("model", help="Name of the model to run.")
  61. run_parser.add_argument(
  62. "prompt",
  63. nargs="?",
  64. help="Optional prompt for the model, interactive mode enabled when not specified.",
  65. )
  66. run_parser.set_defaults(fn=run)
  67. server.set_parser(
  68. subparsers.add_parser(
  69. "serve",
  70. description="Start a persistent server to interact with models via the API.",
  71. help="Start a persistent server to interact with models via the API.",
  72. )
  73. )
  74. args = parser.parse_args()
  75. args = vars(args)
  76. try:
  77. fn = args.pop("fn")
  78. fn(**args)
  79. except KeyboardInterrupt:
  80. pass
  81. except KeyError:
  82. parser.print_help()
  83. except Exception as e:
  84. print(e)
  85. def list_models(*args, **kwargs):
  86. for m in model.models(*args, **kwargs):
  87. print(m)
  88. def generate(*args, **kwargs):
  89. if prompt := kwargs.get("prompt"):
  90. print(">>>", prompt, flush=True)
  91. generate_oneshot(*args, **kwargs)
  92. return
  93. if sys.stdin.isatty():
  94. return generate_interactive(*args, **kwargs)
  95. return generate_batch(*args, **kwargs)
  96. def generate_oneshot(*args, **kwargs):
  97. print(flush=True)
  98. spinner = yaspin()
  99. spinner.start()
  100. spinner_running = True
  101. try:
  102. for output in engine.generate(model_name=kwargs.pop('model'), *args, **kwargs):
  103. choices = output.get("choices", [])
  104. if len(choices) > 0:
  105. if spinner_running:
  106. spinner.stop()
  107. spinner_running = False
  108. print("\r", end="") # move cursor back to beginning of line again
  109. print(choices[0].get("text", ""), end="", flush=True)
  110. except Exception:
  111. spinner.stop()
  112. raise
  113. # end with a new line
  114. print(flush=True)
  115. print(flush=True)
  116. def generate_interactive(*args, **kwargs):
  117. while True:
  118. print(">>> ", end="", flush=True)
  119. line = next(sys.stdin)
  120. if not line:
  121. return
  122. kwargs.update({"prompt": line})
  123. generate_oneshot(*args, **kwargs)
  124. def generate_batch(*args, **kwargs):
  125. for line in sys.stdin:
  126. print(">>> ", line, end="", flush=True)
  127. kwargs.update({"prompt": line})
  128. generate_oneshot(*args, **kwargs)
  129. def search(*args, **kwargs):
  130. try:
  131. model_names = model.search_directory(*args, **kwargs)
  132. if len(model_names) == 0:
  133. print("No models found.")
  134. return
  135. elif len(model_names) == 1:
  136. print(f"Found {len(model_names)} available model:")
  137. else:
  138. print(f"Found {len(model_names)} available models:")
  139. for model_name in model_names:
  140. print(model_name.lower())
  141. except Exception as e:
  142. print("Failed to fetch available models, check your network connection")
  143. def pull(*args, **kwargs):
  144. model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
  145. def run(*args, **kwargs):
  146. name = model.pull(model_name=kwargs.pop('model'), *args, **kwargs)
  147. kwargs.update({"model": name})
  148. print(f"Running {name}...")
  149. generate(*args, **kwargs)