فهرست منبع

Allow models to force a new batch

This is useful for a few things:
 - Work around bugs, such as having 2 images in one batch
 - Keep the image in a single batch for fully connected attention
 - Improve performance by not evaluating embeddings multiple times
Jesse Gross 1 ماه پیش
والد
کامیت
06007c0a18
4 فایل‌های تغییر یافته به همراه 10 افزوده شده و 14 حذف شده
  1. 6 0
      model/input/input.go
  2. 2 2
      model/models/gemma3/model.go
  3. 1 1
      runner/ollamarunner/runner.go
  4. 1 11
      server/prompt.go

+ 6 - 0
model/input/input.go

@@ -15,6 +15,12 @@ type Input struct {
 	// stored in Multimodal, used for caching and comparing
 	// equality.
 	MultimodalHash uint64
+
+	// BatchBreak forces a new batch to be started with this
+	// input. For example, this can be used to align images
+	// with batches. Note that batches may be divided in additional
+	// locations as well.
+	BatchBreak bool
 }
 
 // MultimodalIndex is a multimodal element (such as an image)

+ 2 - 2
model/models/gemma3/model.go

@@ -112,8 +112,8 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 			result = append(result, inp)
 		} else {
 			imageInputs := []input.Input{
-				{Token: 108},    // "\n\n"
-				{Token: 255999}, // "<start_of_image>""
+				{Token: 108},                      // "\n\n"
+				{Token: 255999, BatchBreak: true}, // "<start_of_image>""
 			}
 			result = append(result, imageInputs...)
 

+ 1 - 1
runner/ollamarunner/runner.go

@@ -363,7 +363,7 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			if j >= s.batchSize {
+			if j >= s.batchSize || (inp.BatchBreak && len(seq.pendingInputs) != 0) {
 				break
 			}
 

+ 1 - 11
server/prompt.go

@@ -26,7 +26,6 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	var system []api.Message
 
 	isMllama := checkMllamaModelFamily(m)
-	isGemma3 := checkGemma3ModelFamily(m)
 
 	var imageNumTokens int
 	// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -41,7 +40,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
 	n := len(msgs) - 1
 	// in reverse, find all messages that fit into context window
 	for i := n; i >= 0; i-- {
-		if (isMllama || isGemma3) && len(msgs[i].Images) > 1 {
+		if isMllama && len(msgs[i].Images) > 1 {
 			return "", nil, errTooManyImages
 		}
 
@@ -158,12 +157,3 @@ func checkMllamaModelFamily(m *Model) bool {
 	}
 	return false
 }
-
-func checkGemma3ModelFamily(m *Model) bool {
-	for _, arch := range m.Config.ModelFamilies {
-		if arch == "gemma3" {
-			return true
-		}
-	}
-	return false
-}