import json
import os
import threading

from llama_cpp import Llama
from flask import Flask, Response, stream_with_context, request
from flask_cors import CORS

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are currently loaded
llms = {}
lock = threading.Lock()


def load(model):
    """Load a model into memory if it is not already loaded."""
    with lock:
        if not os.path.exists(f"./models/{model}.bin"):
            return {"error": "The model does not exist."}
        if model not in llms:
            llms[model] = Llama(model_path=f"./models/{model}.bin")
        return None


def unload(model):
    """Remove a model from memory."""
    with lock:
        if not os.path.exists(f"./models/{model}.bin"):
            return {"error": "The model does not exist."}
        llms.pop(model, None)
        return None


def generate(model, prompt):
    # auto-load the model before generating
    error = load(model)
    if error is not None:
        # This function is a generator, so a plain `return error` would be
        # silently discarded; yield the error to the client instead.
        yield json.dumps(error)
        return
    stream = llms[model](
        str(prompt),  # TODO: optimize prompt based on model
        max_tokens=4096,
        stop=["Q:", "\n"],
        echo=True,
        stream=True,
    )
    for output in stream:
        yield json.dumps(output)


def models():
    """List the model names in ./models, without the .bin extension."""
    all_files = os.listdir("./models")
    bin_files = [
        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
    ]
    return bin_files


@app.route("/load", methods=["POST"])
def load_route_handler():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    error = load(model)
    if error is not None:
        return error, 400
    return Response(status=204)


@app.route("/unload", methods=["POST"])
def unload_route_handler():
    data = request.get_json()
    model = data.get("model")
    if not model:
        return Response("Model is required", status=400)
    error = unload(model)
    if error is not None:
        return error, 400
    return Response(status=204)


@app.route("/generate", methods=["POST"])
def generate_route_handler():
    data = request.get_json()
    model = data.get("model")
    prompt = data.get("prompt")
    if not model:
        return Response("Model is required", status=400)
    if not prompt:
        return Response("Prompt is required", status=400)
    if not os.path.exists(f"./models/{model}.bin"):
        return {"error": "The model does not exist."}, 400
    return Response(
        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
    )


@app.route("/models", methods=["GET"])
def models_route_handler():
    bin_files = models()
    return Response(json.dumps(bin_files), mimetype="application/json")


if __name__ == "__main__":
    app.run(debug=True, threaded=True, port=5001)
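

# --- Hypothetical client sketch (not part of the server above) -------------
# A minimal way to exercise the endpoints, assuming the server is running
# locally on port 5001 (as configured in app.run) and that the `requests`
# package is installed. The model name "llama-7b" is only an example and
# must match a .bin file in ./models.
import requests

BASE = "http://localhost:5001"

# List available models (names of .bin files in ./models).
print(requests.get(f"{BASE}/models").json())

# Load a model up front; the server replies 204 on success.
print(requests.post(f"{BASE}/load", json={"model": "llama-7b"}).status_code)

# Stream a completion: the response body is a sequence of JSON objects,
# one per generated chunk, written as they are produced.
with requests.post(
    f"{BASE}/generate",
    json={"model": "llama-7b", "prompt": "Q: What is the capital of France? A:"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

# Unload the model when finished.
requests.post(f"{BASE}/unload", json={"model": "llama-7b"})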