# proto.py — minimal Flask server exposing local llama.cpp models over HTTP.
  1. import json
  2. import os
  3. import threading
  4. from llama_cpp import Llama
  5. from flask import Flask, Response, stream_with_context, request
  6. from flask_cors import CORS
app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded: model name -> Llama instance.
llms = {}
# Guards llms against concurrent mutation (Flask may serve requests on
# multiple threads when run with threaded=True).
lock = threading.Lock()
  12. def load(model):
  13. with lock:
  14. if not os.path.exists(f"./models/{model}.bin"):
  15. return {"error": "The model does not exist."}
  16. if model not in llms:
  17. llms[model] = Llama(model_path=f"./models/{model}.bin")
  18. return None
  19. def unload(model):
  20. with lock:
  21. if not os.path.exists(f"./models/{model}.bin"):
  22. return {"error": "The model does not exist."}
  23. llms.pop(model, None)
  24. return None
  25. def generate(model, prompt):
  26. # auto load
  27. error = load(model)
  28. if error is not None:
  29. return error
  30. stream = llms[model](
  31. str(prompt), # TODO: optimize prompt based on model
  32. max_tokens=4096,
  33. stop=["Q:", "\n"],
  34. echo=True,
  35. stream=True,
  36. )
  37. for output in stream:
  38. yield json.dumps(output)
  39. def models():
  40. all_files = os.listdir("./models")
  41. bin_files = [
  42. file.replace(".bin", "") for file in all_files if file.endswith(".bin")
  43. ]
  44. return bin_files
  45. @app.route("/load", methods=["POST"])
  46. def load_route_handler():
  47. data = request.get_json()
  48. model = data.get("model")
  49. if not model:
  50. return Response("Model is required", status=400)
  51. error = load(model)
  52. if error is not None:
  53. return error
  54. return Response(status=204)
  55. @app.route("/unload", methods=["POST"])
  56. def unload_route_handler():
  57. data = request.get_json()
  58. model = data.get("model")
  59. if not model:
  60. return Response("Model is required", status=400)
  61. error = unload(model)
  62. if error is not None:
  63. return error
  64. return Response(status=204)
  65. @app.route("/generate", methods=["POST"])
  66. def generate_route_handler():
  67. data = request.get_json()
  68. model = data.get("model")
  69. prompt = data.get("prompt")
  70. if not model:
  71. return Response("Model is required", status=400)
  72. if not prompt:
  73. return Response("Prompt is required", status=400)
  74. if not os.path.exists(f"./models/{model}.bin"):
  75. return {"error": "The model does not exist."}, 400
  76. return Response(
  77. stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
  78. )
  79. @app.route("/models", methods=["GET"])
  80. def models_route_handler():
  81. bin_files = models()
  82. return Response(json.dumps(bin_files), mimetype="application/json")
  83. if __name__ == "__main__":
  84. app.run(debug=True, threaded=True, port=5001)
  85. app.run()