Browse Source

runner: remove cache prompt flag from ollama runner (#9826)

We do not need to bypass the prompt caching in the ollama runner yet, as
only embedding models needed to bypass the prompt caching. When embedding
models are implemented they can skip initializing this cache completely.
Bruce MacDonald 1 month ago
parent
commit
95e271d98f
3 changed files with 130 additions and 7 deletions
  1. 1 6
      runner/ollamarunner/cache.go
  2. 128 0
      runner/ollamarunner/cache_test.go
  3. 1 1
      runner/ollamarunner/runner.go

+ 1 - 6
runner/ollamarunner/cache.go

@@ -89,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}
 
-	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()
 

+ 128 - 0
runner/ollamarunner/cache_test.go

@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCacheSlot(t *testing.T) {
+	tests := []struct {
+		name           string
+		cache          InputCache
+		prompt         []input.Input
+		wantErr        bool
+		expectedSlotId int
+		expectedPrompt int // expected length of remaining prompt
+	}{
+		{
+			name: "Basic cache hit - single user",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Basic cache hit - multi user",
+			cache: InputCache{
+				multiUserCache: true,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Exact match - leave one input",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Should leave 1 token for sampling
+		},
+		{
+			name: "No available slots",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    true,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        true,
+			expectedSlotId: -1,
+			expectedPrompt: -1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+
+			// Check error state
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr {
+				return // Skip further checks if we expected an error
+			}
+
+			// Verify slot ID
+			if slot.Id != tt.expectedSlotId {
+				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
+			}
+
+			// Verify slot is now marked in use
+			if !slot.InUse {
+				t.Errorf("LoadCacheSlot() slot not marked InUse")
+			}
+
+			// Verify remaining prompt length
+			if len(remainingPrompt) != tt.expectedPrompt {
+				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
+					len(remainingPrompt), tt.expectedPrompt)
+			}
+		})
+	}
+}

+ 1 - 1
runner/ollamarunner/runner.go

@@ -590,7 +590,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)