@@ -1,5 +1,6 @@
 import json
 import os
+import threading
 from llama_cpp import Llama
 from flask import Flask, Response, stream_with_context, request
 from flask_cors import CORS
@@ -9,77 +10,99 @@ CORS(app)  # enable CORS for all routes

 # llms tracks which models are loaded
 llms = {}
+# the lock serializes load/unload so concurrent requests cannot race on llms
+lock = threading.Lock()
+
+
+def load(model):
+    with lock:
+        if not os.path.exists(f"./models/{model}.bin"):
+            return {"error": "The model does not exist."}
+        if model not in llms:
+            llms[model] = Llama(model_path=f"./models/{model}.bin")
+    return None
+
+
+def unload(model):
+    with lock:
+        if not os.path.exists(f"./models/{model}.bin"):
+            return {"error": "The model does not exist."}
+        llms.pop(model, None)
+    return None
+
+
+def generate(model, prompt):
+    # auto load
+    error = load(model)
+    if error is not None:
+        # a generator cannot return a value to its caller, so surface
+        # the load failure as a streamed JSON payload instead
+        yield json.dumps(error)
+        return
+    stream = llms[model](
+        str(prompt),  # TODO: optimize prompt based on model
+        max_tokens=4096,
+        stop=["Q:", "\n"],
+        echo=True,
+        stream=True,
+    )
+    for output in stream:
+        yield json.dumps(output)
+
+
+def models():
+    all_files = os.listdir("./models")
+    bin_files = [
+        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
+    ]
+    return bin_files


 @app.route("/load", methods=["POST"])
-def load():
+def load_route_handler():
     data = request.get_json()
     model = data.get("model")
-
     if not model:
         return Response("Model is required", status=400)
-    if not os.path.exists(f"./models/{model}.bin"):
-        return {"error": "The model does not exist."}, 400
-
-    if model not in llms:
-        llms[model] = Llama(model_path=f"./models/{model}.bin")
-
+    error = load(model)
+    if error is not None:
+        return error, 400
     return Response(status=204)


 @app.route("/unload", methods=["POST"])
-def unload():
+def unload_route_handler():
     data = request.get_json()
     model = data.get("model")
-
     if not model:
         return Response("Model is required", status=400)
-    if not os.path.exists(f"./models/{model}.bin"):
-        return {"error": "The model does not exist."}, 400
-
-    llms.pop(model, None)
-
+    error = unload(model)
+    if error is not None:
+        return error, 400
     return Response(status=204)


 @app.route("/generate", methods=["POST"])
-def generate():
+def generate_route_handler():
     data = request.get_json()
     model = data.get("model")
     prompt = data.get("prompt")
-
     if not model:
         return Response("Model is required", status=400)
     if not prompt:
         return Response("Prompt is required", status=400)
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
-
-    if model not in llms:
-        # auto load
-        llms[model] = Llama(model_path=f"./models/{model}.bin")
-
-    def stream_response():
-        stream = llms[model](
-            str(prompt),  # TODO: optimize prompt based on model
-            max_tokens=4096,
-            stop=["Q:", "\n"],
-            echo=True,
-            stream=True,
-        )
-        for output in stream:
-            yield json.dumps(output)
-
     return Response(
-        stream_with_context(stream_response()), mimetype="text/event-stream"
+        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
     )

+
 @app.route("/models", methods=["GET"])
-def models():
-    all_files = os.listdir("./models")
-    bin_files = [file.replace(".bin", "") for file in all_files if file.endswith(".bin")]
+def models_route_handler():
+    bin_files = models()
     return Response(json.dumps(bin_files), mimetype="application/json")

+
 if __name__ == "__main__":
     app.run(debug=True, threaded=True, port=5001)
-    app.run()
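
For a quick end-to-end check of the refactored routes, a client sketch along these lines should work. It assumes the server is running via the app.run(..., port=5001) call above and that at least one ./models/<name>.bin file exists; the model name and the choice of the requests library are illustrative, not part of this change:

import requests

BASE = "http://localhost:5001"
MODEL = "example-model"  # placeholder: any name with a matching ./models/<name>.bin

# list the available models
print(requests.get(f"{BASE}/models").json())

# load the model into memory; the server answers 204 on success
requests.post(f"{BASE}/load", json={"model": MODEL}).raise_for_status()

# stream a completion; the body is a sequence of JSON-encoded chunks
with requests.post(
    f"{BASE}/generate",
    json={"model": MODEL, "prompt": "Q: Why is the sky blue? A:"},
    stream=True,
) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)

# drop the model from memory when done
requests.post(f"{BASE}/unload", json={"model": MODEL}).raise_for_status()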