# proto.py — Flask server exposing llama.cpp models over HTTP, plus a click CLI.
  1. import json
  2. import os
  3. import threading
  4. import click
  5. from llama_cpp import Llama
  6. from flask import Flask, Response, stream_with_context, request
  7. from flask_cors import CORS
app = Flask(__name__)
CORS(app)  # enable CORS for all routes
# llms tracks which models are loaded: model name -> Llama instance
llms = {}
# guards concurrent reads/writes of the llms cache across request threads
lock = threading.Lock()
  13. def load(model):
  14. with lock:
  15. if not os.path.exists(f"./models/{model}.bin"):
  16. return {"error": "The model does not exist."}
  17. if model not in llms:
  18. llms[model] = Llama(model_path=f"./models/{model}.bin")
  19. return None
  20. def unload(model):
  21. with lock:
  22. if not os.path.exists(f"./models/{model}.bin"):
  23. return {"error": "The model does not exist."}
  24. llms.pop(model, None)
  25. return None
  26. def query(model, prompt):
  27. # auto load
  28. error = load(model)
  29. if error is not None:
  30. return error
  31. generated = llms[model](
  32. str(prompt), # TODO: optimize prompt based on model
  33. max_tokens=4096,
  34. stop=["Q:", "\n"],
  35. echo=True,
  36. stream=True,
  37. )
  38. for output in generated:
  39. yield json.dumps(output)
  40. def models():
  41. all_files = os.listdir("./models")
  42. bin_files = [
  43. file.replace(".bin", "") for file in all_files if file.endswith(".bin")
  44. ]
  45. return bin_files
  46. @app.route("/load", methods=["POST"])
  47. def load_route_handler():
  48. data = request.get_json()
  49. model = data.get("model")
  50. if not model:
  51. return Response("Model is required", status=400)
  52. error = load(model)
  53. if error is not None:
  54. return error
  55. return Response(status=204)
  56. @app.route("/unload", methods=["POST"])
  57. def unload_route_handler():
  58. data = request.get_json()
  59. model = data.get("model")
  60. if not model:
  61. return Response("Model is required", status=400)
  62. error = unload(model)
  63. if error is not None:
  64. return error
  65. return Response(status=204)
  66. @app.route("/generate", methods=["POST"])
  67. def generate_route_handler():
  68. data = request.get_json()
  69. model = data.get("model")
  70. prompt = data.get("prompt")
  71. if not model:
  72. return Response("Model is required", status=400)
  73. if not prompt:
  74. return Response("Prompt is required", status=400)
  75. if not os.path.exists(f"./models/{model}.bin"):
  76. return {"error": "The model does not exist."}, 400
  77. return Response(
  78. stream_with_context(query(model, prompt)), mimetype="text/event-stream"
  79. )
  80. @app.route("/models", methods=["GET"])
  81. def models_route_handler():
  82. bin_files = models()
  83. return Response(json.dumps(bin_files), mimetype="application/json")
@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    """Command-line entry point for the server and one-shot generation."""
    # allows the script to respond to command line input when executed
    # directly: with no subcommand, print the group's help text instead
    # of doing nothing
    if ctx.invoked_subcommand is None:
        click.echo(ctx.get_help())
  90. @cli.command()
  91. @click.option("--port", default=5000, help="Port to run the server on")
  92. @click.option("--debug", default=False, help="Enable debug mode")
  93. def serve(port, debug):
  94. print("Serving on http://localhost:{port}")
  95. app.run(host="0.0.0.0", port=port, debug=debug)
  96. @cli.command()
  97. @click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
  98. @click.option("--prompt", default="", help="The prompt for the model")
  99. def generate(model, prompt):
  100. if prompt == "":
  101. prompt = input("Prompt: ")
  102. output = ""
  103. for generated in query(model, prompt):
  104. generated_json = json.loads(generated)
  105. text = generated_json["choices"][0]["text"]
  106. output += text
  107. print(f"\r{output}", end="", flush=True)
if __name__ == "__main__":
    # Dispatch to the click CLI when executed as a script.
    cli()