Преглед изворни кода

ollamarunner: Use a separate context per multimodal input

Currently there is a single context per sequence, shared all by
all multimodal inputs. Since we build a vision encoder graph per
image, with a large number of inputs we can eventually hit the
maximum number of graph nodes per context.

This changes to use a separate context for each image, ensuring
that available resource limits are consistent.
Jesse Gross пре 1 месец
родитељ
комит
282bfaaa95
4 измењених фајлова са 33 додато и 19 уклоњено
  1. 1 1
      model/model.go
  2. 1 1
      model/models/gemma3/model.go
  3. 7 4
      model/models/mllama/model.go
  4. 24 13
      runner/ollamarunner/runner.go

+ 1 - 1
model/model.go

@@ -60,7 +60,7 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+	PostTokenize([]input.Input) ([]input.Input, error)
 }
 
 // Base implements the common fields and methods for all models

+ 1 - 1
model/models/gemma3/model.go

@@ -111,7 +111,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var result []input.Input
 
 	for _, inp := range inputs {

+ 7 - 4
model/models/mllama/model.go

@@ -106,17 +106,17 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	var images []input.Input
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
 		if inputs[i].Multimodal == nil {
 			if len(images) > 0 {
-				inputs[i].Multimodal = images[0].Multimodal
+				inputs[i].Multimodal = []ml.Tensor{images[0].Multimodal.(ml.Tensor)}
 				inputs[i].MultimodalHash = images[0].MultimodalHash
 				for j := 1; j < len(images); j++ {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
+					inputs[i].Multimodal = append(inputs[i].Multimodal.([]ml.Tensor), images[0].Multimodal.(ml.Tensor))
 					fnvHash.Reset()
 					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
 					binary.Write(fnvHash, binary.NativeEndian, inputs[j].MultimodalHash)
@@ -138,7 +138,10 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Inpu
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
 	if len(opts.Multimodal) > 0 {
-		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
+		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+		if len(images) > 0 {
+			crossAttentionStates = images[len(images)-1]
+		}
 	}
 
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))

+ 24 - 13
runner/ollamarunner/runner.go

@@ -34,10 +34,14 @@ import (
 	_ "github.com/ollama/ollama/model/models"
 )
 
+type contextList struct {
+	list []ml.Context
+}
+
 type Sequence struct {
-	// ctx for allocating tensors that last the lifetime of the sequence, such as
+	// ctxs are used for allocating tensors that last the lifetime of the sequence, such as
 	// multimodal embeddings
-	ctx ml.Context
+	ctxs *contextList
 
 	// batch index
 	iBatch int
@@ -99,9 +103,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	s.ready.Wait()
 
 	startTime := time.Now()
-	ctx := s.model.Backend().NewContext()
 
-	inputs, err := s.inputs(ctx, prompt, images)
+	inputs, ctxs, err := s.inputs(prompt, images)
 	if err != nil {
 		return nil, fmt.Errorf("failed to process inputs: %w", err)
 	} else if len(inputs) == 0 {
@@ -127,7 +130,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 	// TODO(jessegross): Ingest cached history for grammar
 
 	return &Sequence{
-		ctx:                 ctx,
+		ctxs:                ctxs,
 		inputs:              inputs,
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
@@ -146,7 +149,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-<n>] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) ([]input.Input, error) {
+func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *contextList, error) {
 	var inputs []input.Input
 	var parts []string
 	var matches [][]string
@@ -161,12 +164,19 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 		parts = []string{prompt}
 	}
 
+	var contexts contextList
+	runtime.AddCleanup(&contexts, func(ctxs []ml.Context) {
+		for _, ctx := range ctxs {
+			ctx.Close()
+		}
+	}, contexts.list)
+
 	postTokenize := false
 	for i, part := range parts {
 		// text - tokenize
 		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 
 		for _, t := range tokens {
@@ -186,12 +196,14 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 			}
 
 			if imageIndex < 0 {
-				return nil, fmt.Errorf("invalid image index: %d", n)
+				return nil, nil, fmt.Errorf("invalid image index: %d", n)
 			}
 
+			ctx := s.model.Backend().NewContext()
+			contexts.list = append(contexts.list, ctx)
 			imageEmbeddings, err := multimodalProcessor.EncodeMultimodal(ctx, images[imageIndex].Data)
 			if err != nil {
-				return nil, err
+				return nil, nil, err
 			}
 
 			s.multimodalHash.Reset()
@@ -205,13 +217,13 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []llm.ImageData) (
 
 	if visionModel && postTokenize {
 		var err error
-		inputs, err = multimodalProcessor.PostTokenize(ctx, inputs)
+		inputs, err = multimodalProcessor.PostTokenize(inputs)
 		if err != nil {
-			return nil, err
+			return nil, nil, err
 		}
 	}
 
-	return inputs, nil
+	return inputs, &contexts, nil
 }
 
 type Server struct {
@@ -306,7 +318,6 @@ func (s *Server) removeSequence(seqIndex int, reason string) {
 	close(seq.responses)
 	close(seq.embedding)
 	seq.cache.InUse = false
-	seq.ctx.Close()
 	s.seqs[seqIndex] = nil
 	s.seqsSem.Release(1)
 }