# ollama.py
import json
import os
import threading
import click
from transformers import AutoModel
from tqdm import tqdm
from pathlib import Path
from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS
from template import template

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}
lock = threading.Lock()

def models_directory():
    home_dir = Path.home()
    models_dir = home_dir / ".ollama/models"
    if not models_dir.exists():
        models_dir.mkdir(parents=True)
    return models_dir

def load(model=None, path=None):
    """
    Load a model.

    The model can be specified by providing either the path or the model name,
    but not both. If both are provided, this function will raise a ValueError.
    If the model does not exist or could not be loaded, this function returns an error.

    Args:
        model (str, optional): The name of the model to load.
        path (str, optional): The path to the model file.

    Returns:
        dict or None: If the model cannot be loaded, a dictionary with an 'error' key is returned.
        If the model is successfully loaded, None is returned.
    """
    with lock:
        if path is not None and model is not None:
            raise ValueError(
                "Both path and model are specified. Please provide only one of them."
            )
        elif path is not None:
            name = os.path.basename(path)
            load_from = path
        elif model is not None:
            name = model
            dir = models_directory()
            load_from = str(dir / f"{model}.bin")
        else:
            raise ValueError("Either path or model must be specified.")

        if not os.path.exists(load_from):
            return {"error": f"The model at {load_from} does not exist."}

        if name not in llms:
            # TODO: download model from a repository if it does not exist
            llms[name] = Llama(model_path=load_from)

    # TODO: this should start a persistent instance of ollama with the model loaded
    return None
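
# Illustrative usage of load(); "llama" is a hypothetical model name:
#   load(model="llama")          # resolves to ~/.ollama/models/llama.bin
#   load(path="/tmp/llama.bin")  # loads directly from an explicit file path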

def unload(model):
    """
    Unload a model.

    Remove a model from the list of loaded models. If the model is not loaded, this is a no-op.

    Args:
        model (str): The name of the model to unload.
    """
    llms.pop(model, None)

def generate(model, prompt):
    # auto load
    error = load(model)
    if error is not None:
        # generate() is a generator: yield the load error so callers iterating
        # over the stream actually see it (a bare return would be swallowed)
        yield json.dumps(error)
        return
    generated = llms[model](
        str(prompt),  # TODO: optimize prompt based on model
        max_tokens=4096,
        stop=["Q:", "\n"],
        stream=True,
    )
    for output in generated:
        yield json.dumps(output)
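
# Each item yielded by generate() is a JSON string; with llama-cpp-python the
# payload looks roughly like this (illustrative, fields abridged):
#   {"choices": [{"text": " Paris", "index": 0, "finish_reason": null}], ...}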

def models():
    dir = models_directory()
    all_files = os.listdir(dir)
    bin_files = [
        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
    ]
    return bin_files

@app.route("/load", methods=["POST"])
def load_route_handler():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    error = load(model)
    if error is not None:
        return error
    return Response(status=204)
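
# Illustrative request, assuming the server is running on the default port 5000
# and "llama" is a hypothetical model name:
#   curl -X POST http://localhost:5000/load \
#        -H "Content-Type: application/json" -d '{"model": "llama"}'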

@app.route("/unload", methods=["POST"])
def unload_route_handler():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    unload(model)
    return Response(status=204)

@app.route("/generate", methods=["POST"])
def generate_route_handler():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")
    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    # resolve the model file the same way load() does
    if not (models_directory() / f"{model}.bin").exists():
        return {"error": "The model does not exist."}, 400
    return Response(
        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
    )
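
# Illustrative streaming request ("llama" is a hypothetical model name):
#   curl -N -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "llama", "prompt": "Q: What is the capital of France? A:"}'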

@app.route("/models", methods=["GET"])
def models_route_handler():
    bin_files = models()
    return Response(json.dumps(bin_files), mimetype="application/json")
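
# Illustrative request: list the models available in ~/.ollama/models:
#   curl http://localhost:5000/models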

@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    # allows the script to respond to command line input when executed directly
    if ctx.invoked_subcommand is None:
        click.echo(ctx.get_help())

@cli.command()
@click.option("--port", default=5000, help="Port to run the server on")
@click.option("--debug", is_flag=True, help="Enable debug mode")
def serve(port, debug):
    print(f"Serving on http://localhost:{port}")
    app.run(host="0.0.0.0", port=port, debug=debug)
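
# Illustrative invocation:
#   python ollama.py serve --port 5000 --debug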

@cli.command(name="load")
@click.argument("model")
@click.option("--file", is_flag=True, help="Indicates that a file path is provided")
def load_cli(model, file):
    if file:
        error = load(path=model)
    else:
        error = load(model)
    if error is not None:
        print(error)
        return
    print("Model loaded")

@cli.command(name="generate")
@click.argument("model")
@click.option("--prompt", default="", help="The prompt for the model")
def generate_cli(model, prompt):
    if prompt == "":
        prompt = input("Prompt: ")
    output = ""
    prompt = template(model, prompt)
    for generated in generate(model, prompt):
        generated_json = json.loads(generated)
        if "error" in generated_json:
            print(generated_json["error"])
            return
        text = generated_json["choices"][0]["text"]
        output += text
        print(f"\r{output}", end="", flush=True)
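
# Illustrative invocation ("llama" is a hypothetical model name):
#   python ollama.py generate llama --prompt "Q: Why is the sky blue? A:"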

# download a model from the Hugging Face hub into the models directory
# (not wired up to a CLI command yet; see the unimplemented "pull" below)
def download_model(model_name):
    dir = models_directory()
    AutoModel.from_pretrained(model_name, cache_dir=dir)

@cli.command(name="models")
def models_cli():
    print(models())


@cli.command(name="pull")
@click.argument("model")
def pull_cli(model):
    print("not implemented")


@cli.command(name="import")
@click.argument("model")
def import_cli(model):
    print("not implemented")


if __name__ == "__main__":
    cli()