Bläddra i källkod

restructure proto.py

Bruce MacDonald 1 år sedan
förälder
incheckning
2e99e7d5cb
1 ändrade filer med 57 tillägg och 37 borttagningar
  1. 57 37
      proto.py

+ 57 - 37
proto.py

@@ -1,5 +1,6 @@
 import json
 import os
+import threading
 from llama_cpp import Llama
 from flask import Flask, Response, stream_with_context, request
 from flask_cors import CORS
@@ -9,77 +10,96 @@ CORS(app)  # enable CORS for all routes
 
 # llms tracks which models are loaded
 llms = {}
+lock = threading.Lock()
+
+
+def load(model):
+    with lock:
+        if not os.path.exists(f"./models/{model}.bin"):
+            return {"error": "The model does not exist."}
+        if model not in llms:
+            llms[model] = Llama(model_path=f"./models/{model}.bin")
+    return None
+
+
+def unload(model):
+    with lock:
+        if not os.path.exists(f"./models/{model}.bin"):
+            return {"error": "The model does not exist."}
+        llms.pop(model, None)
+    return None
+
+
+def generate(model, prompt):
+    # auto load
+    error = load(model)
+    if error is not None:
+        return error
+    stream = llms[model](
+        str(prompt),  # TODO: optimize prompt based on model
+        max_tokens=4096,
+        stop=["Q:", "\n"],
+        echo=True,
+        stream=True,
+    )
+    for output in stream:
+        yield json.dumps(output)
+
+
+def models():
+    all_files = os.listdir("./models")
+    bin_files = [
+        file.replace(".bin", "") for file in all_files if file.endswith(".bin")
+    ]
+    return bin_files
 
 
 @app.route("/load", methods=["POST"])
-def load():
+def load_route_handler():
     data = request.get_json()
     model = data.get("model")
-
     if not model:
         return Response("Model is required", status=400)
-    if not os.path.exists(f"./models/{model}.bin"):
-        return {"error": "The model does not exist."}, 400
-
-    if model not in llms:
-        llms[model] = Llama(model_path=f"./models/{model}.bin")
-
+    error = load(model)
+    if error is not None:
+        return error
     return Response(status=204)
 
 
 @app.route("/unload", methods=["POST"])
-def unload():
+def unload_route_handler():
     data = request.get_json()
     model = data.get("model")
-
     if not model:
         return Response("Model is required", status=400)
-    if not os.path.exists(f"./models/{model}.bin"):
-        return {"error": "The model does not exist."}, 400
-
-    llms.pop(model, None)
-
+    error = unload(model)
+    if error is not None:
+        return error
     return Response(status=204)
 
 
 @app.route("/generate", methods=["POST"])
-def generate():
+def generate_route_handler():
     data = request.get_json()
     model = data.get("model")
     prompt = data.get("prompt")
-
     if not model:
         return Response("Model is required", status=400)
     if not prompt:
         return Response("Prompt is required", status=400)
     if not os.path.exists(f"./models/{model}.bin"):
         return {"error": "The model does not exist."}, 400
-
-    if model not in llms:
-        # auto load
-        llms[model] = Llama(model_path=f"./models/{model}.bin")
-
-    def stream_response():
-        stream = llms[model](
-            str(prompt),  # TODO: optimize prompt based on model
-            max_tokens=4096,
-            stop=["Q:", "\n"],
-            echo=True,
-            stream=True,
-        )
-        for output in stream:
-            yield json.dumps(output)
-
     return Response(
-        stream_with_context(stream_response()), mimetype="text/event-stream"
+        stream_with_context(generate(model, prompt)), mimetype="text/event-stream"
     )
 
+
 @app.route("/models", methods=["GET"])
-def models():
-    all_files = os.listdir("./models")
-    bin_files = [file.replace(".bin", "") for file in all_files if file.endswith(".bin")]
+def models_route_handler():
+    bin_files = models()
     return Response(json.dumps(bin_files), mimetype="application/json")
 
+
 if __name__ == "__main__":
     app.run(debug=True, threaded=True, port=5001)
     app.run()