@@ -34,9 +34,6 @@ type input struct {
 }

 type Sequence struct {
-	// number of inputs evaluated
-	numPast int
-
 	// batch index
 	iBatch int

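Dropping numPast is the heart of this change: instead of each Sequence hand-counting how many inputs have been evaluated, the position is derived from the cache slot's input list (len(seq.cache.Inputs), used in the processBatch hunks below), leaving one source of truth for what is resident in the KV cache.
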
@@ -112,21 +109,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		params.numKeep = len(inputs)
 	}

-	if !params.embedding {
-		// Subtracting 4 ensures that at least 1 input can be discarded during shift
-		params.numKeep = min(params.numKeep, s.cache.numCtx-4)
-		params.numKeep += s.bosToken
-	} else {
-		// Embeddings are 1 shot - just truncate to the context window, without ever shifting
-		params.numKeep = min(params.numKeep, s.cache.numCtx)
+	if s.model.AddBOSToken() {
+		params.numKeep += 1
 	}

-	// truncate to fit in context window
+	// Ensure that at least 1 input can be discarded during shift
+	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
+
 	if len(inputs) > s.cache.numCtx {
-		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
-		newInputs := inputs[:params.numKeep]
-		newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
-		inputs = newInputs
+		slog.Warn("input exceeds context length", "prompt", len(inputs), "limit", s.cache.numCtx)
 	}

 	var sc *llama.SamplingContext

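The rewritten bookkeeping asks the model directly via AddBOSToken() instead of caching a bosToken count, applies one policy to all sequences (the embedding special case is gone), and clamps numKeep to numCtx-1 so a shift can always discard at least one input. An oversized prompt is no longer truncated up front; it logs a warning and relies on context shifting during processing. A runnable sketch of the arithmetic with illustrative numbers (the min builtin needs Go 1.21+; addBOS stands in for s.model.AddBOSToken()):

    package main

    import "fmt"

    func main() {
        // Illustrative values mirroring the fields in the diff.
        numCtx := 8    // context window size
        numKeep := 10  // caller asked to keep the whole prompt
        addBOS := true // stands in for s.model.AddBOSToken()

        if addBOS {
            numKeep += 1 // account for the BOS token prepended to the prompt
        }
        // Clamp so at least one input can always be discarded during a shift.
        numKeep = min(numKeep, numCtx-1)

        fmt.Println(numKeep) // prints 7
    }
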
@@ -231,9 +222,6 @@ type Server struct {
 	// KV cache
 	cache *InputCache

-	// does this model require a beginning of sequence token?
-	bosToken int
-
 	// next sequence for prompt processing to avoid starvation
 	nextSeq int

@@ -258,18 +246,6 @@ func (s *Server) allNil() bool {
 	return true
 }

-func (s *Server) shiftContext(seq *Sequence) {
-	numLeft := seq.numPast - seq.numKeep
-	numDiscard := numLeft / 2
-
-	slog.Debug("context limit hit - shifting", "limit", s.cache.numCtx, "numPast", seq.numPast,
-		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
-
-	s.cache.ShiftCacheSlot(seq.cache, seq.numKeep, numDiscard, seq.numPast)
-
-	seq.numPast -= numDiscard
-}
-
 func flushPending(seq *Sequence) bool {
 	joined := strings.Join(seq.pendingResponses, "")
 	seq.pendingResponses = []string{}
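ShiftCacheSlot now takes just the slot and numKeep, which implies the discard arithmetic moved inside the cache, where len(slot.Inputs) supplies what numPast used to. A hypothetical sketch of the relocated logic, reconstructed from the removed code above; the InputCacheSlot type name and the body are assumptions, not the actual implementation:

    // Hypothetical: the removed shiftContext math, relocated behind the cache API.
    func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) {
        numPast := len(slot.Inputs)           // position derived from cached inputs
        numDiscard := (numPast - numKeep) / 2 // drop half of the shiftable window

        // ... evict positions [numKeep, numKeep+numDiscard) from the KV cache
        // and slide the remaining entries left ...

        slot.Inputs = append(slot.Inputs[:numKeep], slot.Inputs[numKeep+numDiscard:]...)
    }
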
@@ -374,12 +350,19 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 		}

-		if seq.numPast+len(seq.inputs) > s.cache.numCtx {
-			s.shiftContext(seq)
-		}
-
 		var numInputsProcessed int
+		shifted := false
+
 		for i, input := range seq.inputs {
+			if len(seq.cache.Inputs)+1 > s.cache.numCtx {
+				if !shifted {
+					s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
+					shifted = true
+				} else {
+					break
+				}
+			}
+
 			embedding := input.embed != nil

 			// If we don't currently have a batch, use one of the correct type and
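The shift check moves from once per sequence to once per input: before each token joins the batch, the loop asks whether the cache slot can hold one more entry. Shifting at most once per processBatch pass (the shifted flag) bounds the work done in a single iteration; if a second shift would be needed, the loop breaks and the remaining inputs are picked up on the next pass.
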
@@ -403,13 +386,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			}

 			crossAttention = seq.crossAttention
-			batch.Add(input.token, input.embed, seq.numPast, numInputsProcessed+1 == len(seq.inputs), seq.cache.Id)
-			seq.numPast++
+			batch.Add(input.token, input.embed, len(seq.cache.Inputs), i+1 == len(seq.inputs), seq.cache.Id)
+			seq.cache.Inputs = append(seq.cache.Inputs, input)
 			numInputsProcessed++
 		}

 		if numInputsProcessed > 0 {
-			seq.cache.Inputs = append(seq.cache.Inputs, seq.inputs[:numInputsProcessed]...)
 			seq.inputs = seq.inputs[numInputsProcessed:]
 			seq.iBatch = batch.NumTokens() - 1
 		}
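Two accounting changes land together here. The position passed to batch.Add is now len(seq.cache.Inputs) rather than the hand-maintained numPast, and each input is appended to seq.cache.Inputs the moment it enters the batch instead of in bulk after the loop. The immediate append matters for the shift above: a mid-loop ShiftCacheSlot operates on an input list that already reflects everything added so far.
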
@@ -632,7 +614,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -715,7 +697,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	s.mu.Lock()
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, seq.numPast, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, req.CachePrompt)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
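Both the completion and embeddings handlers pick up the narrowed LoadCacheSlot signature. The dropped numPast return value is redundant now: the returned slot's Inputs field already records how much of the prompt is resident in the cache.
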
@@ -802,10 +784,6 @@ func (s *Server) loadModel(
 		}
 	}

-	if s.model.AddBOSToken() {
-		s.bosToken = 1
-	}
-
 	if ppath != "" {
 		var err error
 		s.image, err = NewImageContext(s.lc, ppath)
@@ -814,7 +792,10 @@ func (s *Server) loadModel(
 		}
 	}

-	s.cache = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
+	s.cache, err = NewInputCache(s.lc, kvSize, s.parallel, multiUserCache)
+	if err != nil {
+		panic(err)
+	}

 	s.status = ServerStatusReady
 	s.ready.Done()
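
NewInputCache can now fail, presumably when the requested KV cache size can't support the configured parallel slots, and loadModel treats that as fatal. Panicking rather than returning an error is defensible at this call site: loadModel runs once at startup, before the server flips to ServerStatusReady, so there is no in-flight request to fail gracefully.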