server.py

import json
import os

from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS
from llama_cpp import Llama

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}

@app.route("/load", methods=["POST"])
def load():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    if model not in llms:
        # cache one Llama instance per model; loading twice is a no-op
        llms[model] = Llama(model_path=f"../models/{model}.bin")
    return Response(status=204)

@app.route("/unload", methods=["POST"])
def unload():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    # drop the cached instance; not an error if the model was never loaded
    llms.pop(model, None)
    return Response(status=204)

@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")
    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    if model not in llms:
        # auto load
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        # each item is one completion chunk; serialize it as a JSON document
        for output in stream:
            yield json.dumps(output)

    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )

if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5001)
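
For reference, here is a minimal client sketch for the streaming /generate endpoint. It assumes the server above is running on localhost:5001 and that a model file such as ../models/llama-7b.bin exists on the server side ("llama-7b" is a placeholder name). Each chunk the server yields is one JSON-encoded llama-cpp-python completion object.

import json
import requests

# Hypothetical client for the server above; "llama-7b" is a placeholder
# model name, so ../models/llama-7b.bin must exist next to the server.
resp = requests.post(
    "http://localhost:5001/generate",
    json={"model": "llama-7b", "prompt": "Q: What is the capital of France?\nA:"},
    stream=True,
)
resp.raise_for_status()

# With chunk_size=None, requests yields each transfer chunk as it arrives.
# This assumes one JSON document per chunk; proxies may split or merge
# chunks, so production code should buffer and re-frame instead.
for chunk in resp.iter_content(chunk_size=None):
    if not chunk:
        continue
    output = json.loads(chunk)  # json.loads accepts UTF-8 bytes directly
    print(output["choices"][0]["text"], end="", flush=True)
print()

Note that although the response is labeled text/event-stream, the body is a bare sequence of JSON documents rather than SSE-framed "data:" lines, so a browser EventSource will not parse it; read it with a fetch/stream reader or a loop like the one above.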