server.py

import json
import os

from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")

    # Validate the request before touching the model cache.
    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model file does not exist."}, 400

    # Load the model on first use and keep it cached for later requests.
    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        # Emit each completion chunk as JSON as soon as it is generated.
        for output in stream:
            yield json.dumps(output)

    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5000)
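
To exercise the endpoint, here is a minimal client sketch. It assumes the server above is running locally on port 5000 and that a matching model file exists under ../models/; the model name "llama-2-7b" is a placeholder, not something the server defines.

import requests  # third-party HTTP client: pip install requests

# "llama-2-7b" is a placeholder -- substitute the name of any file that
# exists as ../models/<name>.bin relative to the server.
resp = requests.post(
    "http://localhost:5000/generate",
    json={
        "model": "llama-2-7b",
        "prompt": "Q: What is the capital of France? A:",
    },
    stream=True,  # read the response incrementally instead of buffering it
)
resp.raise_for_status()

# Each chunk carries one or more JSON-encoded completion objects yielded
# by the server's stream_response generator; print them as they arrive.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)

Passing stream=True to requests is what lets the client consume tokens as the server produces them, rather than waiting for the full completion.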