import json
import os
import threading

import click
from flask import Flask, Response, request, stream_with_context
from flask_cors import CORS
from llama_cpp import Llama

from template import template

app = Flask(__name__)
CORS(app)  # enable CORS for all routes

# llms tracks which models are loaded
llms = {}
lock = threading.Lock()
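
# NOTE: a single lock serializes every access to `llms`, so concurrent
# requests will block while a model is being loaded from disk. That is a
# reasonable trade-off for a local, single-user server.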


def load(model):
    # load the model into memory if it exists on disk and is not already resident
    with lock:
        if not os.path.exists(f"./models/{model}.bin"):
            return {"error": "The model does not exist."}
        if model not in llms:
            llms[model] = Llama(model_path=f"./models/{model}.bin")
    return None


def unload(model):
    # drop our reference so the model can be garbage-collected
    with lock:
        if not os.path.exists(f"./models/{model}.bin"):
            return {"error": "The model does not exist."}
        llms.pop(model, None)
    return None


def query(model, prompt):
    # auto load the model on first use
    error = load(model)
    if error is not None:
        # this is a generator, so a bare `return error` would be swallowed by
        # StopIteration; stream the error back as a JSON payload instead
        yield json.dumps(error)
        return
    generated = llms[model](
        str(prompt),  # TODO: optimize prompt based on model
        max_tokens=4096,
        stop=["Q:", "\n"],
        echo=True,
        stream=True,
    )
    for output in generated:
        yield json.dumps(output)
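
# Each streamed chunk is a llama-cpp-python completion dict; the exact shape
# can vary by version, but roughly (an assumption, trimmed for brevity):
#   {"choices": [{"text": " blue", "index": 0, "finish_reason": None}], ...}
# The /generate route and the CLI below both read choices[0]["text"].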


def models():
    # list the available models by stripping the .bin extension
    all_files = os.listdir("./models")
    bin_files = [
        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
    ]
    return bin_files
- @app.route("/load", methods=["POST"])
- def load_route_handler():
- data = request.get_json()
- model = data.get("model")
- if not model:
- return Response("Model is required", status=400)
- error = load(model)
- if error is not None:
- return error
- return Response(status=204)
- @app.route("/unload", methods=["POST"])
- def unload_route_handler():
- data = request.get_json()
- model = data.get("model")
- if not model:
- return Response("Model is required", status=400)
- error = unload(model)
- if error is not None:
- return error
- return Response(status=204)
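
# Example request against /load (the model name matches the CLI default
# below; any ./models/<name>.bin file works):
#   curl -X POST http://localhost:5000/load \
#        -H "Content-Type: application/json" \
#        -d '{"model": "vicuna-7b-v1.3.ggmlv3.q8_0"}'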
- @app.route("/generate", methods=["POST"])
- def generate_route_handler():
- data = request.get_json()
- model = data.get("model")
- prompt = data.get("prompt")
- if not model:
- return Response("Model is required", status=400)
- if not prompt:
- return Response("Prompt is required", status=400)
- if not os.path.exists(f"./models/{model}.bin"):
- return {"error": "The model does not exist."}, 400
- return Response(
- stream_with_context(query(model, prompt)), mimetype="text/event-stream"
- )
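
# Example streaming request (curl -N disables buffering so chunks print as
# they arrive):
#   curl -N -X POST http://localhost:5000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "vicuna-7b-v1.3.ggmlv3.q8_0", "prompt": "Q: Why is the sky blue? A:"}'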
- @app.route("/models", methods=["GET"])
- def models_route_handler():
- bin_files = models()
- return Response(json.dumps(bin_files), mimetype="application/json")


@click.group(invoke_without_command=True)
@click.pass_context
def cli(ctx):
    # print help when the script is executed directly with no subcommand
    if ctx.invoked_subcommand is None:
        click.echo(ctx.get_help())


@cli.command()
@click.option("--port", default=5000, help="Port to run the server on")
@click.option("--debug", is_flag=True, default=False, help="Enable debug mode")
def serve(port, debug):
    print(f"Serving on http://localhost:{port}")
    app.run(host="0.0.0.0", port=port, debug=debug)


@cli.command()
@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
@click.option("--prompt", default="", help="The prompt for the model")
def generate(model, prompt):
    if prompt == "":
        prompt = input("Prompt: ")
    output = ""
    prompt = template(model, prompt)
    for generated in query(model, prompt):
        generated_json = json.loads(generated)
        if "error" in generated_json:
            print(generated_json["error"])
            return
        text = generated_json["choices"][0]["text"]
        output += text
        # rewrite the line in place as tokens stream in (echo=True means the
        # prompt is included in the output)
        print(f"\r{output}", end="", flush=True)
    print()


if __name__ == "__main__":
    cli()
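
# Example usage, assuming this file is saved as server.py (filename assumed),
# llama-cpp-python, flask, flask-cors, and click are installed, and a GGML
# model file sits at ./models/<name>.bin:
#   python server.py serve --port 5000
#   python server.py generate --prompt "Why is the sky blue?"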