# ollama.py

import json
import os
import threading

import click
from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS

from template import template

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}
lock = threading.Lock()


def load(model):
    with lock:
        if not os.path.exists(f"./models/{model}.bin"):
            return {"error": "The model does not exist."}
        if model not in llms:
            llms[model] = Llama(model_path=f"./models/{model}.bin")
    return None


def unload(model):
    with lock:
        if not os.path.exists(f"./models/{model}.bin"):
            return {"error": "The model does not exist."}
        llms.pop(model, None)
    return None


def query(model, prompt):
    # auto load
    error = load(model)
    if error is not None:
        # query is a generator, so a plain `return error` would be swallowed
        # by StopIteration; yield the error as a JSON chunk so consumers see it
        yield json.dumps(error)
        return
    generated = llms[model](
        str(prompt),  # TODO: optimize prompt based on model
        max_tokens=4096,
        stop=["Q:", "\n"],
        echo=True,
        stream=True,
    )
    for output in generated:
        yield json.dumps(output)
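
# Each chunk llama-cpp-python yields while streaming is an OpenAI-style
# completion dict; the `generate` CLI command below relies on this shape:
#   {"choices": [{"text": "..."}], ...}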


def models():
    all_files = os.listdir("./models")
    bin_files = [
        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
    ]
    return bin_files


@app.route("/load", methods=["POST"])
def load_route_handler():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    error = load(model)
    if error is not None:
        return error, 400  # match /generate: a missing model is a client error
    return Response(status=204)


@app.route("/unload", methods=["POST"])
def unload_route_handler():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    error = unload(model)
    if error is not None:
        return error, 400
    return Response(status=204)


@app.route("/generate", methods=["POST"])
def generate_route_handler():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")
    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"./models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    # stream the generator's JSON chunks to the client as they are produced
    return Response(
        stream_with_context(query(model, prompt)), mimetype="text/event-stream"
    )


@app.route("/models", methods=["GET"])
def models_route_handler():
    bin_files = models()
    return Response(json.dumps(bin_files), mimetype="application/json")


@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    # allows the script to respond to command line input when executed directly
    if ctx.invoked_subcommand is None:
        click.echo(ctx.get_help())


@cli.command()
@click.option("--port", default=5000, help="Port to run the server on")
@click.option("--debug", default=False, help="Enable debug mode")
def serve(port, debug):
    print(f"Serving on http://localhost:{port}")
    app.run(host="0.0.0.0", port=port, debug=debug)


@cli.command()
@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
@click.option("--prompt", default="", help="The prompt for the model")
def generate(model, prompt):
    if prompt == "":
        prompt = input("Prompt: ")
    output = ""
    prompt = template(model, prompt)
    for generated in query(model, prompt):
        generated_json = json.loads(generated)
        text = generated_json["choices"][0]["text"]
        output += text
        print(f"\r{output}", end="", flush=True)
    print()  # end the line once streaming finishes


if __name__ == "__main__":
    cli()
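
# Example usage, assuming a ./models/vicuna-7b-v1.3.ggmlv3.q8_0.bin file exists
# (the model name and prompt below are placeholders):
#
#   python ollama.py serve --port 5000
#   curl http://localhost:5000/models
#   curl -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "vicuna-7b-v1.3.ggmlv3.q8_0", "prompt": "Q: What is the capital of France? A:"}'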