Browse Source

update predict code

Michael Yang 1 year ago
parent
commit
3003fc03fc
3 changed files with 175 additions and 180 deletions
  1. 4 1
      api/types.go
  2. 162 81
      llama/llama.go
  3. 9 98
      llama/utils.go

+ 4 - 1
api/types.go

@@ -134,6 +134,7 @@ type Options struct {
 
 	// Model options
 	NumCtx        int  `json:"num_ctx,omitempty"`
+	NumKeep       int  `json:"num_keep,omitempty"`
 	NumBatch      int  `json:"num_batch,omitempty"`
 	NumGPU        int  `json:"num_gpu,omitempty"`
 	MainGPU       int  `json:"main_gpu,omitempty"`
@@ -158,6 +159,7 @@ type Options struct {
 	Mirostat         int     `json:"mirostat,omitempty"`
 	MirostatTau      float32 `json:"mirostat_tau,omitempty"`
 	MirostatEta      float32 `json:"mirostat_eta,omitempty"`
+	PenalizeNewline  bool    `json:"penalize_newline,omitempty"`
 
 	NumThread int `json:"num_thread,omitempty"`
 }
@@ -176,7 +178,7 @@ func DefaultOptions() Options {
 		UseMMap:  true,
 		UseMLock: false,
 
-		RepeatLastN:      512,
+		RepeatLastN:      64,
 		RepeatPenalty:    1.1,
 		FrequencyPenalty: 0.0,
 		PresencePenalty:  0.0,
@@ -188,6 +190,7 @@ func DefaultOptions() Options {
 		Mirostat:         0,
 		MirostatTau:      5.0,
 		MirostatEta:      0.1,
+		PenalizeNewline:  true,
 
 		NumThread: runtime.NumCPU(),
 	}

+ 162 - 81
llama/llama.go

@@ -1,8 +1,8 @@
 package llama
 
 /*
-#cgo CPPFLAGS: -O3 -DNDEBUG=1 -DGGML_USE_K_QUANTS
-#cgo CXXFLAGS: -std=c++11
+#cgo CPPFLAGS: -O3 -Wall -Wextra -Werror -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
+#cgo CXXFLAGS: -std=gnu++11
 #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
 #include <stdlib.h>
@@ -21,6 +21,7 @@ struct llama_sample_options
 	int mirostat;
 	float mirostat_tau;
 	float mirostat_eta;
+	bool penalize_newline;
 };
 
 llama_token llama_sample(
@@ -37,6 +38,8 @@ llama_token llama_sample(
 		false,
 	};
 
+	struct llama_token_data newline = candidates_p.data[llama_token_nl()];
+
 	llama_sample_repetition_penalty(
 		ctx, &candidates_p,
 		last_tokens, n_last_tokens,
@@ -47,6 +50,10 @@ llama_token llama_sample(
 		last_tokens, n_last_tokens,
 		opts->frequency_penalty, opts->presence_penalty);
 
+	if (!opts->penalize_newline) {
+		candidates_p.data[llama_token_nl()] = newline;
+	}
+
 	if (opts->temperature <= 0) {
 		return llama_sample_token_greedy(ctx, &candidates_p);
 	}
@@ -82,9 +89,9 @@ import (
 	"errors"
 	"fmt"
 	"io"
+	"log"
 	"os"
 	"strings"
-	"time"
 	"unicode/utf8"
 	"unsafe"
 
@@ -96,6 +103,10 @@ type LLM struct {
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
 
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
 	api.Options
 }
 
@@ -152,16 +163,98 @@ func (llm *LLM) Close() {
 }
 
 func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
-	if input := llm.tokenize(prompt); input != nil {
-		embd := make([]C.llama_token, len(ctx))
-		for i := range ctx {
-			embd[i] = C.llama_token(ctx[i])
+	C.llama_reset_timings(llm.ctx)
+
+	tokens := make([]C.llama_token, len(ctx))
+	for i := range tokens {
+		tokens[i] = C.llama_token(ctx[i])
+	}
+
+	if len(tokens) == 0 {
+		tokens = llm.tokenize(" ")
+	}
+
+	llm.marshalPrompt(tokens, prompt)
+
+	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
+
+	var b bytes.Buffer
+	for {
+		token, err := llm.next()
+		if errors.Is(err, io.EOF) {
+			break
+		} else if err != nil {
+			return err
+		}
+
+		b.WriteString(llm.detokenize(token))
+		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
+			fn(api.GenerateResponse{Response: b.String()})
+			b.Reset()
 		}
+	}
 
-		return llm.generate(append(embd, input...), fn)
+	last := make([]int, 0, len(llm.last))
+	for _, i := range llm.last {
+		if i != 0 {
+			last = append(last, int(i))
+		}
 	}
 
-	return errors.New("llama: tokenize")
+	timings := C.llama_get_timings(llm.ctx)
+	fn(api.GenerateResponse{
+		Done:               true,
+		Context:            last,
+		PromptEvalCount:    int(timings.n_p_eval),
+		PromptEvalDuration: parseDurationMs(float64(timings.t_p_eval_ms)),
+		EvalCount:          int(timings.n_eval),
+		EvalDuration:       parseDurationMs(float64(timings.t_eval_ms)),
+	})
+
+	return nil
+}
+
+func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+	tokens := append(ctx, llm.tokenize(prompt)...)
+	if llm.NumKeep < 0 {
+		llm.NumKeep = len(tokens)
+	}
+
+	// min(llm.NumCtx - 4, llm.NumKeep)
+	if llm.NumCtx-4 < llm.NumKeep {
+		llm.NumKeep = llm.NumCtx - 4
+	}
+
+	if len(tokens) >= llm.NumCtx {
+		// truncate input
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := tokens[:llm.NumKeep]
+		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
+		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+		copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+
+		tokens = truncated
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
+	} else {
+		llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
+		llm.last = append(llm.last, tokens...)
+	}
+
+	var i int
+	for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+		// noop
+	}
+
+	llm.embd = tokens
+	if i == len(tokens) {
+		// evaluate at least one token to generate logits
+		i--
+	}
+
+	llm.cursor = i
+
+	log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
+	return tokens
 }
 
 func (llm *LLM) tokenize(prompt string) []C.llama_token {
@@ -185,98 +278,86 @@ func (llm *LLM) detokenize(tokens ...C.llama_token) string {
 	return sb.String()
 }
 
-func (llm *LLM) generate(input []C.llama_token, fn func(api.GenerateResponse)) error {
-	var opts C.struct_llama_sample_options
-	opts.repeat_penalty = C.float(llm.RepeatPenalty)
-	opts.frequency_penalty = C.float(llm.FrequencyPenalty)
-	opts.presence_penalty = C.float(llm.PresencePenalty)
-	opts.temperature = C.float(llm.Temperature)
-	opts.top_k = C.int(llm.TopK)
-	opts.top_p = C.float(llm.TopP)
-	opts.tfs_z = C.float(llm.TFSZ)
-	opts.typical_p = C.float(llm.TypicalP)
-	opts.mirostat = C.int(llm.Mirostat)
-	opts.mirostat_tau = C.float(llm.MirostatTau)
-	opts.mirostat_eta = C.float(llm.MirostatEta)
-
-	output := deque[C.llama_token]{capacity: llm.NumCtx}
-
-	context := deque[int]{capacity: llm.NumCtx / 2}
-	for _, in := range input {
-		context.PushLeft(int(in))
-	}
+func (llm *LLM) next() (C.llama_token, error) {
+	if len(llm.embd) >= llm.NumCtx {
+		numLeft := (llm.NumCtx - llm.NumKeep) / 2
+		truncated := llm.embd[:llm.NumKeep]
+		truncated = append(truncated, llm.embd[len(llm.embd)-numLeft:]...)
 
-	var b bytes.Buffer
-	for C.llama_get_kv_cache_token_count(llm.ctx) < C.int(llm.NumCtx) {
-		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(input), C.int(len(input)), C.llama_get_kv_cache_token_count(llm.ctx), C.int(llm.NumThread)); retval != 0 {
-			return errors.New("llama: eval")
-		}
+		llm.embd = truncated
+		llm.cursor = llm.NumKeep
+		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d cursor=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated), llm.cursor)
+	}
 
-		token, err := llm.sample(output, &opts)
-		if errors.Is(err, io.EOF) {
+	for {
+		if llm.cursor >= len(llm.embd) {
 			break
-		} else if err != nil {
-			return err
 		}
 
-		b.WriteString(llm.detokenize(token))
-		if utf8.Valid(b.Bytes()) || b.Len() >= utf8.UTFMax {
-			// call the callback
-			fn(api.GenerateResponse{
-				Response: b.String(),
-			})
-
-			output.PushLeft(token)
-			context.PushLeft(int(token))
-			b.Reset()
+		numEval := len(llm.embd) - llm.cursor
+		if numEval > llm.NumBatch {
+			numEval = llm.NumBatch
 		}
 
-		input = []C.llama_token{token}
-	}
-
-	dur := func(ms float64) time.Duration {
-		d, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
-		if err != nil {
-			panic(err)
+		if retval := C.llama_eval(llm.ctx, unsafe.SliceData(llm.embd[llm.cursor:]), C.int(numEval), C.int(llm.cursor), C.int(llm.NumThread)); retval != 0 {
+			return 0, fmt.Errorf("llama_eval: %d", retval)
 		}
 
-		return d
+		llm.cursor += numEval
 	}
 
-	timings := C.llama_get_timings(llm.ctx)
-	fn(api.GenerateResponse{
-		Done:               true,
-		Context:            context.Data(),
-		PromptEvalCount:    int(timings.n_p_eval),
-		PromptEvalDuration: dur(float64(timings.t_p_eval_ms)),
-		EvalCount:          int(timings.n_eval),
-		EvalDuration:       dur(float64(timings.t_eval_ms)),
-	})
-
-	return nil
-}
-
-func (llm *LLM) sample(output deque[C.llama_token], opts *C.struct_llama_sample_options) (C.llama_token, error) {
-	numVocab := int(C.llama_n_vocab(llm.ctx))
+	var sampleOpts C.struct_llama_sample_options
+	sampleOpts.repeat_penalty = C.float(llm.RepeatPenalty)
+	sampleOpts.frequency_penalty = C.float(llm.FrequencyPenalty)
+	sampleOpts.presence_penalty = C.float(llm.PresencePenalty)
+	sampleOpts.temperature = C.float(llm.Temperature)
+	sampleOpts.top_k = C.int(llm.TopK)
+	sampleOpts.top_p = C.float(llm.TopP)
+	sampleOpts.tfs_z = C.float(llm.TFSZ)
+	sampleOpts.typical_p = C.float(llm.TypicalP)
+	sampleOpts.mirostat = C.int(llm.Mirostat)
+	sampleOpts.mirostat_tau = C.float(llm.MirostatTau)
+	sampleOpts.mirostat_eta = C.float(llm.MirostatEta)
+	sampleOpts.penalize_newline = C.bool(llm.PenalizeNewline)
+
+	numVocab := C.llama_n_vocab(llm.ctx)
 	logits := unsafe.Slice(C.llama_get_logits(llm.ctx), numVocab)
 
-	candidates := deque[C.struct_llama_token_data]{capacity: numVocab}
-	for i := 0; i < candidates.Cap(); i++ {
-		candidates.PushLeft(C.struct_llama_token_data{
+	// TODO: logit bias
+
+	candidates := make([]C.llama_token_data, numVocab)
+	for i := range logits {
+		candidates[i] = C.llama_token_data{
 			id:    C.int(i),
 			logit: logits[i],
 			p:     0,
-		})
+		}
 	}
 
+	repeatLastN := llm.RepeatLastN
+	if len(llm.last) < repeatLastN {
+		repeatLastN = len(llm.last)
+	}
+
+	if llm.NumCtx < repeatLastN {
+		repeatLastN = llm.NumCtx
+	}
+
+	lastN := llm.last[len(llm.last)-repeatLastN:]
+
 	token := C.llama_sample(
 		llm.ctx,
-		unsafe.SliceData(candidates.Data()), C.size_t(candidates.Len()),
-		unsafe.SliceData(output.Data()), C.size_t(output.Len()),
-		opts)
-	if token != C.llama_token_eos() {
-		return token, nil
+		unsafe.SliceData(candidates), C.size_t(len(candidates)),
+		unsafe.SliceData(lastN), C.size_t(len(lastN)),
+		&sampleOpts,
+	)
+
+	llm.last = append(llm.last, token)
+	llm.embd = append(llm.embd, token)
+
+	if token == C.llama_token_eos() {
+		return 0, io.EOF
 	}
 
-	return 0, io.EOF
+	return token, nil
 }

+ 9 - 98
llama/utils.go

@@ -1,104 +1,15 @@
 package llama
 
-type node[T any] struct {
-	t    T
-	next *node[T]
-	prev *node[T]
-}
-
-type deque[T any] struct {
-	head     *node[T]
-	tail     *node[T]
-	size     int
-	capacity int
-}
-
-func (d *deque[T]) Empty() bool {
-	return d.size == 0
-}
-
-func (d *deque[T]) Len() int {
-	return d.size
-}
-
-func (d *deque[T]) Cap() int {
-	return d.capacity
-}
-
-func (d *deque[T]) Push(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.PopLeft()
-	}
-
-	n := node[T]{t: t}
-	if d.head != nil {
-		n.next = d.head
-		d.head.prev = &n
-		d.head = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) PushLeft(t T) {
-	if d.capacity > 0 && d.size >= d.capacity {
-		d.Pop()
-	}
-
-	n := node[T]{t: t}
-	if d.tail != nil {
-		n.prev = d.tail
-		d.tail.next = &n
-		d.tail = &n
-	} else {
-		d.head = &n
-		d.tail = &n
-	}
-
-	d.size++
-}
-
-func (d *deque[T]) Pop() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	head := d.head
-	d.head = head.next
-	if d.head != nil {
-		d.head.prev = nil
-	} else {
-		d.tail = nil
-	}
-
-	d.size--
-	return &head.t
-}
-
-func (d *deque[T]) PopLeft() *T {
-	if d.Empty() {
-		return nil
-	}
-
-	tail := d.tail
-	d.tail = tail.prev
-	if d.tail != nil {
-		d.tail.next = nil
-	} else {
-		d.head = nil
-	}
-
-	d.size--
-	return &tail.t
-}
+import (
+	"fmt"
+	"time"
+)
 
-func (d *deque[T]) Data() (data []T) {
-	for n := d.head; n != nil; n = n.next {
-		data = append(data, n.t)
+func parseDurationMs(ms float64) time.Duration {
+	dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
+	if err != nil {
+		panic(err)
 	}
 
-	return data
+	return dur
 }