浏览代码

runner.go: Shift context window when KV cache space is exceeded

Currently, once the KV cache is full, text generation stops. Instead,
we should shift out the oldest context so that new generation can
continue based on more recent context.

This uses the algorithm from llama.cpp that is currently used by Ollama
with the server.cpp code. There are others but they are never turned
on through Ollama, so this restores parity.

The algorithm is:
 - Retain a configurable number of tokens at the beginning (for things
like beginning of sequence tokens
 - Drop the oldest half of the remaining tokens
 - Shift the remaining new tokens to the back of the cache
Jesse Gross 8 月之前
父节点
当前提交
69cc5795a7
共有 2 个文件被更改,包括 79 次插入15 次删除
  1. 14 0
      llama/llama.go
  2. 65 15
      llama/runner/runner.go

+ 14 - 0
llama/llama.go

@@ -157,6 +157,10 @@ func (c *Context) SampleTokenGreedy(logits []float32) int {
 	}))
 }
 
+func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
+	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
+}
+
 func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
 	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
 }
@@ -191,6 +195,16 @@ func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 
+func (m *Model) ShouldAddBOSToken() bool {
+	addBos := int(C.llama_add_bos_token(m.c))
+
+	if addBos != -1 {
+		return addBos != 0
+	} else {
+		return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM
+	}
+}
+
 func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
 	cLoraPath := C.CString(loraPath)
 	defer C.free(unsafe.Pointer(cLoraPath))

+ 65 - 15
llama/runner/runner.go

@@ -49,6 +49,9 @@ type Sequence struct {
 	// stop sequences
 	stop []string
 
+	// number of tokens to keep at the beginning when shifting context window
+	numKeep int
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
@@ -61,22 +64,38 @@ type Sequence struct {
 	n_prompt_tokens        int
 }
 
-func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+type NewSequenceParams struct {
+	numPredict     int
+	stop           []string
+	numKeep        int
+	samplingParams *llama.SamplingParams
+	embedding      bool
+}
+
+func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	// truncate to last n tokens
-	// TODO: this shouldn't happen and will severely impact generation
-	// quality. instead we should ensure to cut prompt in the API.
+	if params.numKeep < 0 {
+		params.numKeep = len(tokens)
+	}
+	// Subtracting 4 ensures that at least 1 token can be discarded during shift
+	params.numKeep = min(params.numKeep, s.numCtx-4)
+	params.numKeep += s.bosToken
+
+	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
-		tokens = tokens[:s.numCtx]
+		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
+		newTokens := tokens[:params.numKeep]
+		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
+		tokens = newTokens
 	}
 
 	var sc *llama.SamplingContext
-	if params != nil {
-		sc = llama.NewSamplingContext(*params)
+	if params.samplingParams != nil {
+		sc = llama.NewSamplingContext(*params.samplingParams)
 		for _, t := range tokens {
 			sc.Accept(s.lc, t, false)
 		}
@@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	return &Sequence{
 		tokens:          tokens,
 		n_prompt_tokens: len(tokens),
-		numPredict:      numPredict,
+		numPredict:      params.numPredict,
 		responses:       make(chan string, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
-		embeddingOnly:   embedding,
-		stop:            stop,
+		embeddingOnly:   params.embedding,
+		stop:            params.stop,
+		numKeep:         params.numKeep,
 	}
 }
 
@@ -111,6 +131,9 @@ type Server struct {
 	// context window size
 	numCtx int
 
+	// does this model require a beginning of sequence token?
+	bosToken int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
@@ -129,6 +152,21 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func (s *Server) shiftContext(seqIndex int) {
+	seq := s.seqs[seqIndex]
+
+	numLeft := seq.nPast - seq.numKeep
+	numDiscard := numLeft / 2
+
+	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
+		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
+
+	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
+	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)
+
+	seq.nPast -= numDiscard
+}
+
 func (s *Server) run(ctx context.Context) {
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
@@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
-				hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
-
 				// if past the num predict limit
-				if hitLimit || seq.nPast > s.numCtx {
+				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
 					seq.doneReason = "limit"
 					close(seq.responses)
 					s.lc.KvCacheSeqRm(i, 0, -1)
@@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				if seq.nPast+len(seq.tokens) > s.numCtx {
+					s.shiftContext(i)
+				}
+
 				if seq.t_start_process_prompt.IsZero() {
 					seq.t_start_process_prompt = time.Now()
 				}
@@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, NewSequenceParams{
+		numPredict:     req.NumPredict,
+		stop:           req.Stop,
+		numKeep:        req.NumKeep,
+		samplingParams: &samplingParams,
+		embedding:      false,
+	})
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embeddings := make([][]float32, len(req.Content))
 	var processed int
 	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
 	}
 
 	// TODO - refactor to go routines to add seq's and drain the responses
@@ -563,6 +609,10 @@ func main() {
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
+	if server.model.ShouldAddBOSToken() {
+		server.bosToken = 1
+	}
+
 	if *ppath != "" {
 		server.cc = llama.NewClipContext(*ppath)
 	}