# ollama.py (5.4 KB)
import json
import os
import threading
from pathlib import Path

import click
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS
from llama_cpp import Llama
from tqdm import tqdm

from template import template
  11. app = Flask(__name__)
  12. CORS(app) # enable CORS for all routes
  13. # llms tracks which models are loaded
  14. llms = {}
  15. lock = threading.Lock()
  16. def models_directory():
  17. home_dir = Path.home()
  18. models_dir = home_dir / ".ollama/models"
  19. if not models_dir.exists():
  20. models_dir.mkdir(parents=True)
  21. return models_dir
  22. def load(model):
  23. """
  24. Load a model.
  25. Args:
  26. model (str): The name or path of the model to load.
  27. Returns:
  28. str or None: The name of the model
  29. dict or None: If the model cannot be loaded, a dictionary with an 'error' key is returned.
  30. If the model is successfully loaded, None is returned.
  31. """
  32. with lock:
  33. load_from = ""
  34. if os.path.exists(model) and model.endswith(".bin"):
  35. # model is being referenced by path rather than name directly
  36. path = os.path.abspath(model)
  37. base = os.path.basename(path)
  38. load_from = path
  39. name = os.path.splitext(base)[0] # Split the filename and extension
  40. else:
  41. # model is being loaded from the ollama models directory
  42. dir = models_directory()
  43. # TODO: download model from a repository if it does not exist
  44. load_from = str(dir / f"{model}.bin")
  45. name = model
  46. if load_from == "":
  47. return None, {"error": "Model not found."}
  48. if not os.path.exists(load_from):
  49. return None, {"error": f"The model {load_from} does not exist."}
  50. if name not in llms:
  51. llms[name] = Llama(model_path=load_from)
  52. return name, None
  53. def unload(model):
  54. """
  55. Unload a model.
  56. Remove a model from the list of loaded models. If the model is not loaded, this is a no-op.
  57. Args:
  58. model (str): The name of the model to unload.
  59. """
  60. llms.pop(model, None)
  61. def generate(model, prompt):
  62. # auto load
  63. name, error = load(model)
  64. if error is not None:
  65. return error
  66. generated = llms[name](
  67. str(prompt), # TODO: optimize prompt based on model
  68. max_tokens=4096,
  69. stop=["Q:", "\n"],
  70. stream=True,
  71. )
  72. for output in generated:
  73. yield json.dumps(output)
  74. def models():
  75. dir = models_directory()
  76. all_files = os.listdir(dir)
  77. bin_files = [
  78. file.replace(".bin", "") for file in all_files if file.endswith(".bin")
  79. ]
  80. return bin_files
  81. @app.route("/load", methods=["POST"])
  82. def load_route_handler():
  83. data = request.get_json()
  84. model = data.get("model")
  85. if not model:
  86. return Response("Model is required", status=400)
  87. error = load(model)
  88. if error is not None:
  89. return error
  90. return Response(status=204)
  91. @app.route("/unload", methods=["POST"])
  92. def unload_route_handler():
  93. data = request.get_json()
  94. model = data.get("model")
  95. if not model:
  96. return Response("Model is required", status=400)
  97. unload(model)
  98. return Response(status=204)
  99. @app.route("/generate", methods=["POST"])
  100. def generate_route_handler():
  101. data = request.get_json()
  102. model = data.get("model")
  103. prompt = data.get("prompt")
  104. if not model:
  105. return Response("Model is required", status=400)
  106. if not prompt:
  107. return Response("Prompt is required", status=400)
  108. if not os.path.exists(f"{model}"):
  109. return {"error": "The model does not exist."}, 400
  110. return Response(
  111. stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
  112. )
  113. @app.route("/models", methods=["GET"])
  114. def models_route_handler():
  115. bin_files = models()
  116. return Response(json.dumps(bin_files), mimetype="application/json")
  117. @click.group(invoke_without_command=True)
  118. @click.pass_context
  119. def cli(ctx):
  120. # allows the script to respond to command line input when executed directly
  121. if ctx.invoked_subcommand is None:
  122. click.echo(ctx.get_help())
  123. @cli.command()
  124. @click.option("--port", default=7734, help="Port to run the server on")
  125. @click.option("--debug", default=False, help="Enable debug mode")
  126. def serve(port, debug):
  127. print("Serving on http://localhost:{port}")
  128. app.run(host="0.0.0.0", port=port, debug=debug)
  129. @cli.command(name="load")
  130. @click.argument("model")
  131. @click.option("--file", default=False, help="Indicates that a file path is provided")
  132. def load_cli(model, file):
  133. if file:
  134. error = load(path=model)
  135. else:
  136. error = load(model)
  137. if error is not None:
  138. print(error)
  139. return
  140. print("Model loaded")
  141. @cli.command(name="generate")
  142. @click.argument("model")
  143. @click.option("--prompt", default="", help="The prompt for the model")
  144. def generate_cli(model, prompt):
  145. if prompt == "":
  146. prompt = input("Prompt: ")
  147. output = ""
  148. prompt = template(model, prompt)
  149. for generated in generate(model, prompt):
  150. generated_json = json.loads(generated)
  151. text = generated_json["choices"][0]["text"]
  152. output += text
  153. print(f"\r{output}", end="", flush=True)
  154. @cli.command(name="models")
  155. def models_cli():
  156. print(models())
  157. @cli.command(name="pull")
  158. @click.argument("model")
  159. def pull_cli(model):
  160. print("not implemented")
  161. @cli.command(name="import")
  162. @click.argument("model")
  163. def import_cli(model):
  164. print("not implemented")
  165. if __name__ == "__main__":
  166. cli()