
runner: remove cache prompt flag from ollama runner (#9826)

We do not need to bypass prompt caching in the ollama runner yet, since
only embedding models would need to bypass it. When embedding models are
implemented, they can skip initializing this cache completely.
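
When that happens, the embedding path can simply avoid the cache rather than threading a flag through it. A minimal sketch of the idea, assuming a hypothetical embedding handler (processSequence and the embedding method are illustrative names, not part of this change):

    // Hypothetical: an embedding handler that never touches the input
    // cache, instead of passing cachePrompt=false through LoadCacheSlot.
    func (s *Server) embedding(prompt []input.Input) error {
        seq := &Sequence{inputs: prompt}
        // No s.cache.LoadCacheSlot call: an embedding request processes
        // its full prompt once, so there is no cached prefix to reuse.
        return s.processSequence(seq)
    }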
Bruce MacDonald, 1 month ago
Parent commit: 95e271d98f
3 changed files with 130 additions and 7 deletions
  1. runner/ollamarunner/cache.go (+1 -6)
  2. runner/ollamarunner/cache_test.go (+128 -0)
  3. runner/ollamarunner/runner.go (+1 -1)

runner/ollamarunner/cache.go (+1 -6)

@@ -89,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
 		return nil, nil, err
 	}
 
-	// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
-	if !cachePrompt {
-		numPast = 0
-	}
-
 	slot.InUse = true
 	slot.lastUsed = time.Now()
 
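
With the flag removed, callers pass only the prompt and always reuse any cached prefix. A minimal sketch of the new call shape (process is a stand-in for the caller's decode loop, not a function in this change):

    slot, remaining, err := cache.LoadCacheSlot(prompt)
    if err != nil {
        return err
    }
    // remaining holds only the suffix not already present in the slot,
    // so a prefix hit means fewer tokens to evaluate.
    process(slot, remaining)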

runner/ollamarunner/cache_test.go (+128 -0)

@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
 		})
 	}
 }
+
+func TestLoadCacheSlot(t *testing.T) {
+	tests := []struct {
+		name           string
+		cache          InputCache
+		prompt         []input.Input
+		wantErr        bool
+		expectedSlotId int
+		expectedPrompt int // expected length of remaining prompt
+	}{
+		{
+			name: "Basic cache hit - single user",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Basic cache hit - multi user",
+			cache: InputCache{
+				multiUserCache: true,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+					{
+						Id:       1,
+						Inputs:   []input.Input{},
+						InUse:    false,
+						lastUsed: time.Now().Add(-2 * time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Only token 3 remains
+		},
+		{
+			name: "Exact match - leave one input",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    false,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}},
+			wantErr:        false,
+			expectedSlotId: 0,
+			expectedPrompt: 1, // Should leave 1 token for sampling
+		},
+		{
+			name: "No available slots",
+			cache: InputCache{
+				multiUserCache: false,
+				slots: []InputCacheSlot{
+					{
+						Id:       0,
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
+						InUse:    true,
+						lastUsed: time.Now().Add(-time.Second),
+					},
+				},
+			},
+			prompt:         []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			wantErr:        true,
+			expectedSlotId: -1,
+			expectedPrompt: -1,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
+
+			// Check error state
+			if (err != nil) != tt.wantErr {
+				t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
+				return
+			}
+
+			if tt.wantErr {
+				return // Skip further checks if we expected an error
+			}
+
+			// Verify slot ID
+			if slot.Id != tt.expectedSlotId {
+				t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
+			}
+
+			// Verify slot is now marked in use
+			if !slot.InUse {
+				t.Errorf("LoadCacheSlot() slot not marked InUse")
+			}
+
+			// Verify remaining prompt length
+			if len(remainingPrompt) != tt.expectedPrompt {
+				t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
+					len(remainingPrompt), tt.expectedPrompt)
+			}
+		})
+	}
+}
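
Note the "Exact match - leave one input" case: even on a full prefix hit, LoadCacheSlot keeps one input out of the cached prefix so the runner still has a token to evaluate for sampling. The new table-driven tests can be run on their own with: go test -run TestLoadCacheSlot ./runner/ollamarunner/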

runner/ollamarunner/runner.go (+1 -1)

@@ -590,7 +590,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	found := false
 	for i, sq := range s.seqs {
 		if sq == nil {
-			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
+			seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
 			if err != nil {
 				s.mu.Unlock()
 				http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)