```python
import json
import os

from flask import Flask, Response, request, stream_with_context
from flask_cors import CORS
from llama_cpp import Llama

app = Flask(__name__)
CORS(app)  # enable CORS for all routes so browser clients can call the API

# Cache of loaded models, keyed by model name, so each model file
# is read from disk only once per process.
llms = {}


@app.route("/generate", methods=["POST"])
def generate():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")
    if not model:
        return {"error": "Model is required."}, 400
    if not prompt:
        return {"error": "Prompt is required."}, 400
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model file does not exist."}, 400
    # Load the model on first use and keep it resident for later requests.
    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    def stream_response():
        # With stream=True, llama-cpp-python returns a generator of
        # completion chunks; forward each one as a JSON object.
        stream = llms[model](
            str(prompt),  # TODO: optimize prompt based on model
            max_tokens=4096,
            stop=["Q:", "\n"],
            echo=True,
            stream=True,
        )
        for output in stream:
            yield json.dumps(output)

    # stream_with_context keeps the request context alive while the
    # client is still consuming the generator.
    return Response(
        stream_with_context(stream_response()), mimetype="text/event-stream"
    )


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5000)
```
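To exercise the endpoint, a minimal streaming client might look like the sketch below. It assumes the server is running locally on port 5000, and the model name `example-model` is a placeholder for a file you actually have at `../models/example-model.bin`. Because the server emits back-to-back JSON objects with no delimiter between them, the client uses `json.JSONDecoder.raw_decode` to peel complete objects off a buffer instead of assuming one object per network chunk.

```python
import json

import requests

# Hypothetical usage: "example-model" must match a file at ../models/example-model.bin.
resp = requests.post(
    "http://localhost:5000/generate",
    json={"model": "example-model", "prompt": "Q: What is the capital of France? A:"},
    stream=True,  # read the response incrementally instead of buffering it all
)
resp.raise_for_status()
resp.encoding = "utf-8"  # the stream declares no charset, so set one explicitly

decoder = json.JSONDecoder()
buf = ""
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    buf += chunk
    # Peel off as many complete JSON objects as the buffer currently holds.
    while buf:
        try:
            obj, end = decoder.raw_decode(buf)
        except json.JSONDecodeError:
            break  # incomplete object; wait for the next chunk
        buf = buf[end:].lstrip()
        print(obj["choices"][0]["text"], end="", flush=True)
```

In production you would likely frame each chunk explicitly, for example as newline-delimited JSON, or as proper `data:` lines if the `text/event-stream` MIME type is meant to be consumed by a browser `EventSource`, rather than relying on raw chunk boundaries.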