Quellcode durchsuchen

load and unload model endpoints

Bruce MacDonald vor 1 Jahr
Ursprung
Commit
ebec1c61db
1 geänderte Datei mit 33 neuen und 1 gelöschten Zeilen
  1. 33 1
      server/server.py

+ 33 - 1
server/server.py

@@ -11,6 +11,37 @@ CORS(app)  # enable CORS for all routes
 llms = {}
 
 
@app.route("/load", methods=["POST"])
def load():
    """Load a model into memory so later requests can use it.

    Expects a JSON body of the form ``{"model": "<name>"}``; the name maps
    to the file ``../models/<name>.bin``.

    Returns:
        204 on success (idempotent: reloading a loaded model is a no-op),
        400 with a plain-text body if "model" is missing or the request has
        no JSON body, or 400 with a JSON error if the model file is absent.
    """
    # silent=True: a missing/invalid JSON body yields None instead of a 500.
    data = request.get_json(silent=True) or {}
    model = data.get("model")

    if not model:
        return Response("Model is required", status=400)
    # Security: the name is untrusted input interpolated into a filesystem
    # path. Reject anything containing path components (e.g. "../../etc/x")
    # so only files directly inside ../models are reachable.
    if os.path.basename(model) != model:
        return {"error": "The model does not exist."}, 400
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400

    if model not in llms:
        llms[model] = Llama(model_path=f"../models/{model}.bin")

    return Response(status=204)
+
+
@app.route("/unload", methods=["POST"])
def unload():
    """Release a previously loaded model from memory.

    Expects a JSON body of the form ``{"model": "<name>"}``.

    Returns:
        204 on success (idempotent: unloading a model that is not loaded
        is a no-op), 400 with a plain-text body if "model" is missing or
        the request has no JSON body, or 400 with a JSON error if the
        model file is absent.
    """
    # silent=True: a missing/invalid JSON body yields None instead of a 500.
    data = request.get_json(silent=True) or {}
    model = data.get("model")

    if not model:
        return Response("Model is required", status=400)
    # Security: reject path components in the untrusted model name before
    # it is interpolated into a filesystem path (path-traversal guard).
    if os.path.basename(model) != model:
        return {"error": "The model does not exist."}, 400
    # NOTE(review): requiring the .bin file to exist in order to *unload*
    # an in-memory entry is debatable, but it matches the existing contract.
    if not os.path.exists(f"../models/{model}.bin"):
        return {"error": "The model does not exist."}, 400

    # pop with a default: no KeyError if the model was never loaded.
    llms.pop(model, None)

    return Response(status=204)
+
+
 @app.route("/generate", methods=["POST"])
 def generate():
     data = request.get_json()
@@ -22,9 +53,10 @@ def generate():
     if not prompt:
         return Response("Prompt is required", status=400)
     if not os.path.exists(f"../models/{model}.bin"):
-        return {"error": "The model file does not exist."}, 400
+        return {"error": "The model does not exist."}, 400
 
     if model not in llms:
+        # auto load
         llms[model] = Llama(model_path=f"../models/{model}.bin")
 
     def stream_response():