runner.go: Shift context window when KV cache space is exceeded

Currently, once the KV cache is full, text generation stops. Instead,
we should shift out the oldest context so that new generation can
continue based on more recent context.

This uses the same context-shifting algorithm from llama.cpp that Ollama
currently gets through the server.cpp code. llama.cpp has other strategies
as well, but they are never enabled through Ollama, so this restores parity.

The algorithm is:
 - Retain a configurable number of tokens at the beginning (for things
   like beginning-of-sequence tokens)
 - Drop the oldest half of the remaining tokens
 - Shift the remaining, newer tokens down to close the gap, freeing space
   at the back of the cache (see the sketch below)
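
As a rough sketch (not part of this change; names are hypothetical), the
same arithmetic on a plain slice of cached tokens looks like the following.
The actual implementation in shiftContext below operates on the KV cache
in place via KvCacheSeqRm and KvCacheSeqAdd rather than copying tokens:

    package main

    import "fmt"

    // shiftWindow illustrates the shift: numKeep tokens stay pinned at the
    // front, the oldest half of the remainder is dropped, and the newer
    // half slides down to fill the gap.
    func shiftWindow(cache []int, numKeep int) []int {
    	numLeft := len(cache) - numKeep
    	numDiscard := numLeft / 2

    	out := append([]int{}, cache[:numKeep]...)       // retained prefix (e.g. BOS)
    	out = append(out, cache[numKeep+numDiscard:]...) // newer half, shifted down
    	return out
    }

    func main() {
    	cache := []int{0, 1, 2, 3, 4, 5, 6, 7} // a full cache of 8 positions
    	fmt.Println(shiftWindow(cache, 2))     // [0 1 5 6 7] -- room for 3 new tokens
    }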
Jesse Gross committed this 8 months ago (commit 69cc5795a7)

2 changed files with 79 additions and 15 deletions:
  1. llama/llama.go          (+14, -0)
  2. llama/runner/runner.go  (+65, -15)

llama/llama.go  (+14, -0)

@@ -157,6 +157,10 @@ func (c *Context) SampleTokenGreedy(logits []float32) int {
 	}))
 }
 
+func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
+	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
+}
+
 func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
 	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
 }
@@ -191,6 +195,16 @@ func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 
+func (m *Model) ShouldAddBOSToken() bool {
+	addBos := int(C.llama_add_bos_token(m.c))
+
+	if addBos != -1 {
+		return addBos != 0
+	} else {
+		return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM
+	}
+}
+
 func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))

llama/runner/runner.go  (+65, -15)

@@ -49,6 +49,9 @@ type Sequence struct {
 	// stop sequences
 	stop []string
 
+	// number of tokens to keep at the beginning when shifting context window
+	numKeep int
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
@@ -61,22 +64,38 @@ type Sequence struct {
 	n_prompt_tokens        int
 }
 
-func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+type NewSequenceParams struct {
+	numPredict     int
+	stop           []string
+	numKeep        int
+	samplingParams *llama.SamplingParams
+	embedding      bool
+}
+
+func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	// truncate to last n tokens
-	// TODO: this shouldn't happen and will severely impact generation
-	// quality. instead we should ensure to cut prompt in the API.
+	if params.numKeep < 0 {
+		params.numKeep = len(tokens)
+	}
+	// Subtracting 4 ensures that at least 1 token can be discarded during shift
+	params.numKeep = min(params.numKeep, s.numCtx-4)
+	params.numKeep += s.bosToken
+
+	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
-		tokens = tokens[:s.numCtx]
+		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
+		newTokens := tokens[:params.numKeep]
+		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
+		tokens = newTokens
 	}
 
 	var sc *llama.SamplingContext
-	if params != nil {
-		sc = llama.NewSamplingContext(*params)
+	if params.samplingParams != nil {
+		sc = llama.NewSamplingContext(*params.samplingParams)
 		for _, t := range tokens {
 			sc.Accept(s.lc, t, false)
 		}
@@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	return &Sequence{
 		tokens:          tokens,
 		n_prompt_tokens: len(tokens),
-		numPredict:      numPredict,
+		numPredict:      params.numPredict,
 		responses:       make(chan string, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
-		embeddingOnly:   embedding,
-		stop:            stop,
+		embeddingOnly:   params.embedding,
+		stop:            params.stop,
+		numKeep:         params.numKeep,
 	}
 }
 
@@ -111,6 +131,9 @@ type Server struct {
 	// context window size
 	numCtx int
 
+	// does this model require a beginning of sequence token?
+	bosToken int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
@@ -129,6 +152,21 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func (s *Server) shiftContext(seqIndex int) {
+	seq := s.seqs[seqIndex]
+
+	numLeft := seq.nPast - seq.numKeep
+	numDiscard := numLeft / 2
+
+	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
+		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
+
+	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
+	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)
+
+	seq.nPast -= numDiscard
+}
+
 func (s *Server) run(ctx context.Context) {
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
@@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
-				hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
-
 				// if past the num predict limit
-				if hitLimit || seq.nPast > s.numCtx {
+				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
 					seq.doneReason = "limit"
 					close(seq.responses)
 					s.lc.KvCacheSeqRm(i, 0, -1)
@@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				if seq.nPast+len(seq.tokens) > s.numCtx {
+					s.shiftContext(i)
+				}
+
 				if seq.t_start_process_prompt.IsZero() {
 					seq.t_start_process_prompt = time.Now()
 				}
@@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, NewSequenceParams{
+		numPredict:     req.NumPredict,
+		stop:           req.Stop,
+		numKeep:        req.NumKeep,
+		samplingParams: &samplingParams,
+		embedding:      false,
+	})
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embeddings := make([][]float32, len(req.Content))
 	var processed int
 	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
 	}
 
 	// TODO - refactor to go routines to add seq's and drain the responses
@@ -563,6 +609,10 @@ func main() {
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
+	if server.model.ShouldAddBOSToken() {
+		server.bosToken = 1
+	}
+
 	if *ppath != "" {
 		server.cc = llama.NewClipContext(*ppath)
 	}
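
For reference, a minimal sketch (not part of the diff; names are
hypothetical) of the truncation rule NewSequence now applies when a prompt
is longer than the context window: keep the first numKeep tokens, then the
most recent tokens that still fit.

    package main

    import "fmt"

    // truncatePrompt mirrors the truncation above: when the prompt exceeds
    // numCtx, keep the first numKeep tokens plus the newest tokens that fit.
    func truncatePrompt(tokens []int, numCtx, numKeep int) []int {
    	if len(tokens) <= numCtx {
    		return tokens
    	}
    	out := append([]int{}, tokens[:numKeep]...)
    	out = append(out, tokens[len(tokens)-numCtx+numKeep:]...)
    	return out
    }

    func main() {
    	tokens := []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
    	fmt.Println(truncatePrompt(tokens, 6, 2)) // [0 1 6 7 8 9]
    }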
 	}