Browse Source

basic progress

jmorganca 11 months ago
parent
commit
43efc893d7
2 changed files with 93 additions and 26 deletions
  1. 24 4
      llama/llama.go
  2. 69 22
      llama/runner/runner.go

+ 24 - 4
llama/llama.go

@@ -4,10 +4,10 @@ package llama
 // #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
 // #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
 // #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
+// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
 // #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
 // #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
-// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
+// #cgo darwin,amd64 LDFLAGS: -framework Foundation -framework Accelerate
 // #cgo linux CFLAGS: -D_GNU_SOURCE
 // #cgo linux CXXFLAGS: -D_GNU_SOURCE
 // #cgo windows LDFLAGS: -lmsvcrt
@@ -29,11 +29,14 @@ package llama
 // #include "clip.h"
 // #include "llava.h"
 // #include "sampling_ext.h"
+//
+// bool llamaProgressCallback(float progress, void *user_data);
 import "C"
 import (
 	"errors"
 	"fmt"
 	"runtime"
+	"runtime/cgo"
 	"strings"
 	"unsafe"
 )
@@ -65,10 +68,26 @@ type ModelParams struct {
 	c C.struct_llama_model_params
 }
 
-func NewModelParams(numGpuLayers int, mainGpu int) ModelParams {
+//export llamaProgressCallback
+func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
+	handle := cgo.Handle(userData)
+	callback := handle.Value().(func(float32))
+	callback(float32(progress))
+	return true
+}
+
+func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
 	params := C.llama_model_default_params()
 	params.n_gpu_layers = C.int(numGpuLayers)
 	params.main_gpu = C.int32_t(mainGpu)
+
+	handle := cgo.NewHandle(callback)
+	params.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
+	params.progress_callback_user_data = unsafe.Pointer(handle)
+	runtime.SetFinalizer(&params, func(p *C.struct_llama_model_params) {
+		handle.Delete()
+	})
+
 	return ModelParams{c: params}
 }
 
@@ -233,7 +252,8 @@ func (m *Model) TokenToPiece(token int) string {
 	return strings.TrimRight(string(buf), "\x00")
 }
 
-func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
+func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
+	maxTokens := len(text) + 2
 	cTokens := make([]C.llama_token, maxTokens)
 	cText := C.CString(text)
 	defer C.free(unsafe.Pointer(cText))

+ 69 - 22
llama/runner/runner.go

@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"runtime"
@@ -28,6 +29,9 @@ type Sequence struct {
 	// channel to send responses over
 	responses chan string
 
+	// number of tokens to predict
+	numPredict int
+
 	samplingCtx *llama.SamplingContext
 
 	// channel to send back the embedding if embedding only
@@ -38,6 +42,8 @@ type Sequence struct {
 
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
+
+	doneReason string
 }
 
 // prompt returns true if the prompt is still being processed
@@ -46,11 +52,18 @@ func (s *Sequence) prompt() bool {
 }
 
 func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
-	tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
+	tokens, err := s.lc.Model().Tokenize(prompt, false, true)
 	if err != nil {
 		panic(err)
 	}
 
+	// truncate to last n tokens
+	// TODO: this shouldn't happen and will severely impact generation
+	// quality. instead we should ensure to cut prompt in the API.
+	if len(tokens) > s.numCtx {
+		tokens = tokens[:s.numCtx]
+	}
+
 	var sc *llama.SamplingContext
 	if params != nil {
 		sc = llama.NewSamplingContext(*params)
@@ -83,9 +96,16 @@ type Server struct {
 	// TODO (jmorganca): this can probably be moved into run()
 	seqs []*Sequence
 
+	// context window size
+	numCtx int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
+
+	progress float32
+
+	status string
 }
 
 func (s *Server) allNil() bool {
@@ -183,6 +203,15 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				// we've reached the context limit
+				if seq.nPast > s.numCtx {
+					seq.doneReason = "limit"
+					close(seq.responses)
+					s.lc.KvCacheSeqRm(i, 0, -1)
+					s.seqs[i] = nil
+					continue
+				}
+
 				for j, t := range seq.tokens {
 					// todo: make this n_batch
 					if j > s.batchSize {
@@ -252,6 +281,7 @@ func (s *Server) run(ctx context.Context) {
 					// as it's important for the /api/generate context
 					// seq.responses <- piece
 
+					seq.doneReason = "stop"
 					close(seq.responses)
 					seq.samplingCtx.Free()
 					pieces[i] = []string{}
@@ -273,6 +303,7 @@ func (s *Server) run(ctx context.Context) {
 					}
 
 					s.lc.KvCacheSeqRm(i, 0, -1)
+					seq.doneReason = "stop"
 					close(seq.responses)
 					seq.samplingCtx.Free()
 					pieces[i] = []string{}
@@ -411,6 +442,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type HealthResponse struct {
+	Status   string  `json:"status"`
+	Progress float32 `json:"progress"`
+}
+
+// TODO (jmorganca): is it safe to do this concurrently with decoding?
+func (s *Server) health(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Content-Type", "application/json")
+
+	if err := json.NewEncoder(w).Encode(&HealthResponse{
+		Status:   s.status,
+		Progress: s.progress,
+	}); err != nil {
+		log.Println("Failed to encode result:", err)
+		return
+	}
+}
+
 func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
@@ -425,36 +474,31 @@ func main() {
 	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 	flag.Parse()
 
+	server := &Server{
+		numCtx:    *numCtx,
+		batchSize: *batchSize,
+		parallel:  *parallel,
+		seqs:      make([]*Sequence, *parallel),
+		status:    "loading",
+	}
+
 	// load the model
 	llama.BackendInit()
-	params := llama.NewModelParams(*nGpuLayers, *mainGpu)
-	model := llama.LoadModelFromFile(*mpath, params)
+	params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
+		slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
+		server.progress = progress
+	})
+	server.model = llama.LoadModelFromFile(*mpath, params)
 
 	if *lpath != "" {
-		model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
+		server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
 	}
 
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
-	lc := llama.NewContextWithModel(model, ctxParams)
-	if lc == nil {
-		panic("Failed to create context")
-	}
+	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
-	var cc *llama.ClipContext
 	if *ppath != "" {
-		cc = llama.NewClipContext(*ppath)
-		if cc == nil {
-			panic("Failed to create clip context")
-		}
-	}
-
-	server := &Server{
-		model:     model,
-		lc:        lc,
-		cc:        cc,
-		batchSize: *batchSize,
-		parallel:  *parallel,
-		seqs:      make([]*Sequence, *parallel),
+		server.cc = llama.NewClipContext(*ppath)
 	}
 
 	server.cond = sync.NewCond(&server.mu)
@@ -473,11 +517,14 @@ func main() {
 	mux := http.NewServeMux()
 	mux.HandleFunc("/embeddings", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
+	mux.HandleFunc("/health", server.health)
 
 	httpServer := http.Server{
 		Handler: mux,
 	}
 
+	server.status = "ready"
+
 	log.Println("Server listening on", addr)
 	if err := httpServer.Serve(listener); err != nil {
 		log.Fatal("server error:", err)