import json
import os

from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}


@app.route("/load", methods=["POST"])
def load():
    # load a model into memory so later /generate calls can reuse it
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")
    return Response(status=204)


@app.route("/unload", methods=["POST"])
def unload():
    # drop a loaded model from memory; unloading an unknown model is a no-op
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    llms.pop(model, None)
    return Response(status=204)


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")
    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    if model not in llms:
        # auto load
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        # run the completion in streaming mode and forward each chunk as JSON
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        for output in stream:
            yield json.dumps(output)

    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5001)
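For reference, here is a minimal client sketch for the /generate endpoint. It assumes the server above is running locally on port 5001, that the requests library is installed, and that a model file such as ../models/llama-2-7b.bin exists (the model name and prompt are only illustrative). It also assumes each yielded JSON object arrives as its own HTTP chunk, which holds for the Flask development server but is not guaranteed by every WSGI setup.

import json
import requests

# Assumptions: server on localhost:5001, an illustrative model named
# "llama-2-7b" under ../models/, and one JSON object per HTTP chunk.
resp = requests.post(
    "http://localhost:5001/generate",
    json={
        "model": "llama-2-7b",
        "prompt": "Q: Name the planets in the solar system? A: ",
    },
    stream=True,  # keep the connection open and read chunks as they arrive
)
resp.raise_for_status()

for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    # each chunk is one JSON-encoded completion object yielded by the server
    data = json.loads(chunk)
    print(data["choices"][0]["text"], end="", flush=True)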