Explorar o código

llama server wrapper

Bruce MacDonald hai 1 ano
pai
achega
0758cb2d4b
Modificáronse 7 ficheiros con 83 adicións e 138 borrados
  1. 34 0
      server/README.md
  2. 0 2
      server/build.sh
  3. 0 8
      server/go.mod
  4. 0 15
      server/go.sum
  5. 0 113
      server/main.go
  6. 2 0
      server/requirements.txt
  7. 47 0
      server/server.py

+ 34 - 0
server/README.md

@@ -0,0 +1,34 @@
+# Server
+
+🙊
+
+## Installation
+
+If using Apple silicon, you need a Python version that supports arm64:
+
+```bash
+wget https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-MacOSX-arm64.sh
+bash Miniforge3-MacOSX-arm64.sh
+```
+
+Get the dependencies:
+
+```bash
+pip install llama-cpp-python
+pip install -r requirements.txt
+```
+
+## Running
+
+Put your model in `models/` and run:
+
+```bash
+python server.py
+```
+
+## API
+
+### `POST /generate`
+
+model: `string` - The name of the model to use in the `models` folder.
+prompt: `string` - The prompt to use.

+ 0 - 2
server/build.sh

@@ -1,2 +0,0 @@
-LIBRARY_PATH=$PWD/go-llama.cpp C_INCLUDE_PATH=$PWD/go-llama.cpp go build .
-

+ 0 - 8
server/go.mod

@@ -1,8 +0,0 @@
-module github.com/keypairdev/keypair
-
-go 1.20
-
-require (
-	github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1
-	github.com/sashabaranov/go-openai v1.11.3
-)

+ 0 - 15
server/go.sum

@@ -1,15 +0,0 @@
-github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1 h1:UQ8y3kHxBgh3BnaW06y/X97fEN48yHPwWobMz8/aztU=
-github.com/go-skynet/go-llama.cpp v0.0.0-20230620192753-7a36befaece1/go.mod h1:tzi97YvT1bVQ+iTG39LvpDkKG1WbizgtljC+orSoM40=
-github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
-github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
-github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
-github.com/onsi/ginkgo/v2 v2.11.0 h1:WgqUCUt/lT6yXoQ8Wef0fsNn5cAuMK7+KT9UFRz2tcU=
-github.com/onsi/gomega v1.27.8 h1:gegWiwZjBsf2DgiSbf5hpokZ98JVDMcWkUiigk6/KXc=
-github.com/sashabaranov/go-openai v1.11.3 h1:bvwWF8hj4UhPlswBdL9/IfOpaHXfzGCJO8WY8ml9sGc=
-github.com/sashabaranov/go-openai v1.11.3/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
-golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M=
-golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s=
-golang.org/x/text v0.9.0 h1:2sjJmO8cDvYveuX97RDLsxlyUxLl+GHoLxBiRdHllBE=
-golang.org/x/tools v0.9.3 h1:Gn1I8+64MsuTb/HpH+LmQtNas23LhUVr3rYZ0eKuaMM=
-gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=

+ 0 - 113
server/main.go

@@ -1,113 +0,0 @@
-package main
-
-import (
-	"bytes"
-	"context"
-	"fmt"
-	"io"
-	"net/http"
-	"os"
-	"runtime"
-
-	"github.com/sashabaranov/go-openai"
-
-	llama "github.com/go-skynet/go-llama.cpp"
-)
-
-
-type Model interface {
-	Name() string
-	Handler(w http.ResponseWriter, r *http.Request)
-}
-
-type LLama7B struct {
-	llama *llama.LLama
-}
-
-func NewLLama7B() *LLama7B {
-	llama, err := llama.New("./models/7B/ggml-model-q4_0.bin", llama.EnableF16Memory, llama.SetContext(128), llama.EnableEmbeddings, llama.SetGPULayers(128))
-	if err != nil {
-		fmt.Println("Loading the model failed:", err.Error())
-		os.Exit(1)
-	}
-
-	return &LLama7B{
-		llama: llama,
-	}
-}
-
-func (l *LLama7B) Name() string {
-	return "LLaMA 7B"
-}
-
-func (m *LLama7B) Handler(w http.ResponseWriter, r *http.Request) {
-	var text bytes.Buffer
-	io.Copy(&text, r.Body)
-
-	_, err := m.llama.Predict(text.String(), llama.Debug, llama.SetTokenCallback(func(token string) bool {
-		w.Write([]byte(token))
-		return true
-	}), llama.SetTokens(512), llama.SetThreads(runtime.NumCPU()), llama.SetTopK(90), llama.SetTopP(0.86), llama.SetStopWords("llama"))
-
-	if err != nil {
-		fmt.Println("Predict failed:", err.Error())
-		os.Exit(1)
-	}
-
-	embeds, err := m.llama.Embeddings(text.String())
-	if err != nil {
-		fmt.Printf("Embeddings: error %s \n", err.Error())
-	}
-	fmt.Printf("Embeddings: %v", embeds)
-
-	w.Header().Set("Content-Type", "text/event-stream")
-    w.Header().Set("Cache-Control", "no-cache")
-    w.Header().Set("Connection", "keep-alive")
-}
-
-type GPT4 struct {
-	apiKey string
-}
-
-func (g *GPT4) Name() string {
-	return "OpenAI GPT-4"
-}
-
-func (g *GPT4) Handler(w http.ResponseWriter, r *http.Request) {
-	w.WriteHeader(http.StatusOK)
-	client := openai.NewClient("your token")
-	resp, err := client.CreateChatCompletion(
-		context.Background(),
-		openai.ChatCompletionRequest{
-			Model: openai.GPT3Dot5Turbo,
-			Messages: []openai.ChatCompletionMessage{
-				{
-					Role:    openai.ChatMessageRoleUser,
-					Content: "Hello!",
-				},
-			},
-		},
-	)
-	if err != nil {
-		fmt.Printf("chat completion error: %v\n", err)
-		return
-	}
-
-	fmt.Println(resp.Choices[0].Message.Content)
-
-	w.Header().Set("Content-Type", "text/plain; charset=utf-8")
-	w.WriteHeader(http.StatusOK)
-}
-
-// TODO: add subcommands to spawn different models
-func main() {
-	model := &LLama7B{}
-	
-	http.HandleFunc("/generate", model.Handler)
-
-	fmt.Println("Starting server on :8080")
-	if err := http.ListenAndServe(":8080", nil); err != nil {
-		fmt.Printf("Error starting server: %s\n", err)
-		return
-	}
-}

+ 2 - 0
server/requirements.txt

@@ -0,0 +1,2 @@
+Flask==2.3.2
+flask_cors==3.0.10

+ 47 - 0
server/server.py

@@ -0,0 +1,47 @@
+import json
+import os
+from llama_cpp import Llama
+from flask import Flask, Response, stream_with_context, request
+from flask_cors import CORS, cross_origin
+
+app = Flask(__name__)
+CORS(app)  # enable CORS for all routes
+
+# llms tracks which models are loaded
+llms = {}
+
+
+@app.route("/generate", methods=["POST"])
+def generate():
+    data = request.get_json()
+    model = data.get("model")
+    prompt = data.get("prompt")
+
+    if not model:
+        return Response("Model is required", status=400)
+    if not prompt:
+        return Response("Prompt is required", status=400)
+    if not os.path.exists(f"../models/{model}.bin"):
+        return {"error": "The model file does not exist."}, 400
+
+    if model not in llms:
+        llms[model] = Llama(model_path=f"../models/{model}.bin")
+
+    def stream_response():
+        stream = llms[model](
+            str(prompt),  # TODO: optimize prompt based on model
+            max_tokens=4096,
+            stop=["Q:", "\n"],
+            echo=True,
+            stream=True,
+        )
+        for output in stream:
+            yield json.dumps(output)
+
+    return Response(
+        stream_with_context(stream_response()), mimetype="text/event-stream"
+    )
+
+
+if __name__ == "__main__":
+    app.run(debug=True, threaded=True, port=5000)