Selaa lähdekoodia

runner.go: Add unit tests for context shifting

This also makes it easier to truncate long inputs the same as
shifting but does not actually implement it. This type of
truncation has a trade off between quality and time to first
token.
Jesse Gross 5 kuukautta sitten
vanhempi
commit
2cd11ae365
3 muutettua tiedostoa jossa 82 lisäystä ja 7 poistoa
  1. 15 5
      llama/runner/cache.go
  2. 63 0
      llama/runner/cache_test.go
  3. 4 2
      llama/runner/runner.go

+ 15 - 5
llama/runner/cache.go

@@ -199,6 +199,20 @@ func countCommonPrefix(a []input, b []input) int {
 	return count
 	return count
 }
 }
 
 
+func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
+	targetFree := (c.numCtx - numKeep) / 2
+	targetFree = max(targetFree, 1)
+
+	currentFree := c.numCtx - inputLen
+	discard := targetFree - currentFree
+
+	if discard < 0 {
+		discard = 0
+	}
+
+	return discard
+}
+
 // Frees up space in the KV cache by deleting the oldest half of history and shifting
 // Frees up space in the KV cache by deleting the oldest half of history and shifting
 // the newest half into that space (saving numKeep inputs at the beginning).
 // the newest half into that space (saving numKeep inputs at the beginning).
 //
 //
@@ -208,11 +222,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
 		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
 		return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
 	}
 	}
 
 
-	targetFree := (c.numCtx - numKeep) / 2
-	targetFree = max(targetFree, 1)
-
-	currentFree := c.numCtx - len(slot.Inputs)
-	discard := targetFree - currentFree
+	discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
 
 
 	if discard <= 0 {
 	if discard <= 0 {
 		return nil
 		return nil

+ 63 - 0
llama/runner/cache_test.go

@@ -227,3 +227,66 @@ func TestFindCacheSlot(t *testing.T) {
 		})
 		})
 	}
 	}
 }
 }
+
+func TestShiftDiscard(t *testing.T) {
+	tests := []struct {
+		name     string
+		numCtx   int
+		numKeep  int
+		inputLen int
+		expected int
+	}{
+		{
+			name:     "Shift",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 2048,
+			expected: 1021,
+		},
+		{
+			name:     "Max Keep",
+			numCtx:   2048,
+			numKeep:  2047,
+			inputLen: 2048,
+			expected: 1,
+		},
+		{
+			name:     "No Keep",
+			numCtx:   2048,
+			numKeep:  0,
+			inputLen: 2048,
+			expected: 1024,
+		},
+		{
+			name:     "Truncate",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 5000,
+			expected: 3973,
+		},
+		{
+			name:     "Truncate Keep",
+			numCtx:   2048,
+			numKeep:  2047,
+			inputLen: 5000,
+			expected: 2953,
+		},
+		{
+			name:     "No Op",
+			numCtx:   2048,
+			numKeep:  5,
+			inputLen: 512,
+			expected: 0,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			c := InputCache{numCtx: tt.numCtx}
+			result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
+			if result != tt.expected {
+				t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
+			}
+		})
+	}
+}

+ 4 - 2
llama/runner/runner.go

@@ -122,9 +122,11 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 	params.numKeep = min(params.numKeep, s.cache.numCtx-1)
 
 
 	if len(inputs) > s.cache.numCtx {
 	if len(inputs) > s.cache.numCtx {
-		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "numKeep", params.numKeep)
+		discard := len(inputs) - s.cache.numCtx
 		newInputs := inputs[:params.numKeep]
 		newInputs := inputs[:params.numKeep]
-		newInputs = append(newInputs, inputs[len(inputs)-s.cache.numCtx+params.numKeep:]...)
+		newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
+
+		slog.Warn("truncating input prompt", "limit", s.cache.numCtx, "prompt", len(inputs), "keep", params.numKeep, "new", len(newInputs))
 		inputs = newInputs
 		inputs = newInputs
 	}
 	}