
add generate command

Bruce MacDonald · 1 year ago
parent commit 3ca8f72327
1 changed file with 18 additions and 4 deletions

proto.py (+18 -4)

@@ -31,19 +31,19 @@ def unload(model):
     return None
 
 
-def generate(model, prompt):
+def query(model, prompt):
     # auto load
     error = load(model)
     if error is not None:
         return error
-    stream = llms[model](
+    generated = llms[model](
         str(prompt),  # TODO: optimize prompt based on model
         max_tokens=4096,
         stop=["Q:", "\n"],
         echo=True,
         stream=True,
     )
-    for output in stream:
+    for output in generated:
         yield json.dumps(output)
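
Each item yielded by query is a json.dumps()-encoded streaming chunk. The call signature above (max_tokens, stop, echo, stream=True) matches llama-cpp-python's Llama call, so each chunk should carry the new text at ["choices"][0]["text"], the same path the generate command added below reads. A minimal in-process consumer, as a sketch: it assumes this module is importable as proto (the filename comes from the diff header) and that the model file exists under ./models/.

    import json

    from proto import query

    # Sketch only, not part of the commit: stream a completion and
    # print each piece as it arrives. The model name is the default
    # used by the new CLI command below.
    for line in query("vicuna-7b-v1.3.ggmlv3.q8_0", "Q: What is 2+2? A:"):
        chunk = json.loads(line)
        print(chunk["choices"][0]["text"], end="", flush=True)
    print()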
 
 
@@ -91,7 +91,7 @@ def generate_route_handler():
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
     return Response(
-        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
+        stream_with_context(query(model, prompt)), mimetype="text/event-stream"
     )
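
The handler streams query's output through stream_with_context with an SSE mimetype, but the chunks are bare JSON strings with no data: framing or newline delimiters, so a client has to decode the body as concatenated JSON objects. A client sketch follows; the /generate path, the request payload shape, and the port are assumptions, since the diff shows only the handler body:

    import json

    import requests

    # Sketch only: endpoint path, payload shape, and port are guesses
    # (the route decorator and request parsing are not in this diff).
    resp = requests.post(
        "http://localhost:5000/generate",
        json={"model": "vicuna-7b-v1.3.ggmlv3.q8_0", "prompt": "Q: What is 2+2? A:"},
        stream=True,
    )
    decoder = json.JSONDecoder()
    buf = ""
    for raw in resp.iter_content(chunk_size=None):
        buf += raw.decode("utf-8")  # assumes chunks split on character boundaries
        while buf:
            try:
                obj, end = decoder.raw_decode(buf)  # peel one object off the front
            except json.JSONDecodeError:
                break  # partial object; wait for more bytes
            buf = buf[end:].lstrip()
            print(obj["choices"][0]["text"], end="", flush=True)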
 
 
@@ -117,5 +117,19 @@ def serve(port, debug):
     app.run(host="0.0.0.0", port=port, debug=debug)
 
 
+@cli.command()
+@click.option("--model", default="vicuna-7b-v1.3.ggmlv3.q8_0", help="The model to use")
+@click.option("--prompt", default="", help="The prompt for the model")
+def generate(model, prompt):
+    if prompt == "":
+        prompt = input("Prompt: ")
+    output = ""
+    for generated in query(model, prompt):
+        generated_json = json.loads(generated)
+        text = generated_json["choices"][0]["text"]
+        output += text
+        print(f"\r{output}", end="", flush=True)
+
+
 if __name__ == "__main__":
     cli()
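
With this change the model can be queried straight from the terminal, no server needed; --prompt may be omitted, in which case the command asks for one interactively. A usage sketch (the proto.py filename comes from the diff header):

    python proto.py generate --prompt "Q: Why is the sky blue? A:"

One quirk worth noting: the loop redraws the accumulated output with \r on every chunk and never prints a trailing newline, so the shell prompt ends up on the same line as the completion.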