|
@@ -31,19 +31,19 @@ def unload(model):
|
|
return None
|
|
return None
|
|
|
|
|
|
|
|
|
|
def query(model, prompt):
    """Stream completion chunks for ``prompt`` from ``model`` as JSON strings.

    The model is loaded on demand through ``load``. Every chunk produced by
    the ``llms[model]`` call is serialized with ``json.dumps`` and yielded in
    the order it is produced.
    """
    # auto load
    load_error = load(model)
    if load_error is not None:
        # NOTE(review): this function is a generator, so the value returned
        # here ends the iteration without being yielded; a caller that just
        # iterates never sees it. Confirm whether that is intended.
        return load_error

    llm = llms[model]
    completion_options = {
        "max_tokens": 4096,
        "stop": ["Q:", "\n"],
        "echo": True,
        "stream": True,
    }
    chunks = llm(
        str(prompt),  # TODO: optimize prompt based on model
        **completion_options,
    )

    # Serialize lazily, one chunk at a time, as the stream is consumed.
    yield from map(json.dumps, chunks)
|
|
|
|
|
|
|
|
|
|
@@ -91,7 +91,7 @@ def generate_route_handler():
|
|
if not os.path.exists(f"./models/{model}.bin"):
|
|
if not os.path.exists(f"./models/{model}.bin"):
|
|
return {"error": "The model does not exist."}, 400
|
|
return {"error": "The model does not exist."}, 400
|
|
return Response(
|
|
return Response(
|
|
- stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
|
|
|
|
|
|
+ stream_with_context(query(model, prompt)), mimetype="text/event-stream"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
|
|
@@ -117,5 +117,19 @@ def serve(port, debug):
|
|
app.run(host="0.0.0.0", port=port, debug=debug)
|
|
app.run(host="0.0.0.0", port=port, debug=debug)
|
|
|
|
|
|
|
|
|
|
|
|
@cli.command()
@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
@click.option("--prompt", default="", help="The prompt for the model")
def generate(model, prompt):
    """Generate a completion for a prompt and print it as it streams in.

    When no prompt is given on the command line, it is read interactively
    from standard input.
    """
    if not prompt:
        prompt = input("Prompt: ")

    for generated in query(model, prompt):
        chunk = json.loads(generated)
        # Emit only the newly received text. Redrawing the whole accumulated
        # output after a carriage return breaks as soon as the text wraps
        # past the terminal width, and rewrites everything on every chunk.
        print(chunk["choices"][0]["text"], end="", flush=True)

    # Terminate the line so the shell prompt does not land on the output.
    print()
|
|
|
|
+
|
|
|
|
+
|
|
# Only dispatch to the command-line interface when run as a script,
# not when the module is imported.
if __name__ == "__main__":
    cli()
|