Procházet zdrojové kódy

input: Rename Options to Batch

Options is no longer very descriptive of this struct.
Jesse Gross před 1 měsícem
rodič
revize
0c220935bd

+ 1 - 1
kvcache/cache.go

@@ -52,7 +52,7 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
 	// entry in positions and seqs.
-	StartForward(ctx ml.Context, opts input.Options) error
+	StartForward(ctx ml.Context, batch input.Batch) error
 
 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)

+ 6 - 6
kvcache/causal.go

@@ -140,10 +140,10 @@ func (c *Causal) Close() {
 	}
 }
 
-func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
-	c.curBatchSize = len(opts.Positions)
-	c.curSequences = opts.Sequences
-	c.curPositions = opts.Positions
+func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
+	c.curBatchSize = len(batch.Positions)
+	c.curSequences = batch.Sequences
+	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
 	var err error
@@ -157,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
 	}
 
 	c.curCellRange = newRange()
-	for i, pos := range opts.Positions {
-		seq := opts.Sequences[i]
+	for i, pos := range batch.Positions {
+		seq := batch.Sequences[i]
 
 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
 

+ 1 - 1
kvcache/causal_test.go

@@ -270,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			context := backend.NewContext()
 			defer context.Close()
 
-			err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
+			err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
 			if err != nil {
 				panic(err)
 			}

+ 3 - 3
kvcache/encoder.go

@@ -79,10 +79,10 @@ func (c *EncoderCache) Close() {
 	}
 }
 
-func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	// We work with the most recent image
-	if len(opts.Multimodal) > 0 {
-		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+	if len(batch.Multimodal) > 0 {
+		c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
 	}
 
 	return nil

+ 4 - 4
kvcache/wrapper.go

@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
 	}
 }
 
-func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range opts.Positions {
-					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
+				for k := range batch.Positions {
+					_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
 				}
 			}
 			return err

+ 19 - 6
model/input/input.go

@@ -33,11 +33,24 @@ type MultimodalIndex struct {
 	Multimodal any
 }
 
-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs     []int32
+// Batch contains the inputs for a model forward pass
+type Batch struct {
+	// Inputs is the input tokens, including placeholders for multimodal inputs.
+	Inputs []int32
+
+	// Multimodal is a set of multimodal embeddings previously created by
+	// EncodeMultimodal, along with an index into Inputs. Unused for text-only
+	// models or for batches without multimodal elements.
 	Multimodal []MultimodalIndex
-	Positions  []int32
-	Sequences  []int
-	Outputs    []int32
+
+	// Positions is the position for each Input, relative to its sequence. Equal
+	// in length to Inputs.
+	Positions []int32
+
+	// Sequences is the sequence for each Input. Equal in length to Inputs.
+	Sequences []int
+
+	// Outputs are the set of indicies into Inputs for which output data should
+	// be returned.
+	Outputs []int32
 }

+ 7 - 7
model/model.go

@@ -26,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
 
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, input.Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Batch) (ml.Tensor, error)
 
 	Backend() ml.Backend
 	Config() config
@@ -280,24 +280,24 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
-	if len(opts.Positions) != len(opts.Sequences) {
-		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
+func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
+	if len(batch.Positions) != len(batch.Sequences) {
+		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}
 
-	if len(opts.Positions) < 1 {
+	if len(batch.Positions) < 1 {
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts)
+		err := cache.StartForward(ctx, batch)
 		if err != nil {
 			return nil, err
 		}
 	}
 
-	t, err := m.Forward(ctx, opts)
+	t, err := m.Forward(ctx, batch)
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/model_test.go

@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
 	panic("unimplemented")
 }
 

+ 4 - 4
model/models/gemma2/model.go

@@ -168,18 +168,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 5 - 5
model/models/gemma3/model.go

@@ -139,23 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return result, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
 }
 
 func init() {

+ 2 - 2
model/models/gemma3/model_text.go

@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
 	// set image embeddings
 	var except []int
-	for _, image := range opts.Multimodal {
+	for _, image := range batch.Multimodal {
 		visionOutputs := image.Multimodal.(ml.Tensor)
 		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
 

+ 4 - 4
model/models/llama/model.go

@@ -139,18 +139,18 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 6 - 6
model/models/mllama/model.go

@@ -135,26 +135,26 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 	return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if len(opts.Multimodal) > 0 {
-		images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
+	if len(batch.Multimodal) > 0 {
+		images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
 		if len(images) > 0 {
 			crossAttentionStates = images[len(images)-1]
 		}
 	}
 
-	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
+	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
 	if err != nil {
 		return nil, err
 	}
 
-	positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
+	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
 	}
 
-	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 10 - 10
runner/ollamarunner/runner.go

@@ -348,7 +348,7 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
-	var options input.Options
+	var batch input.Batch
 
 	for i, seq := range s.seqs {
 		if seq == nil {
@@ -395,17 +395,17 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			options.Inputs = append(options.Inputs, inp.Token)
+			batch.Inputs = append(batch.Inputs, inp.Token)
 			if inp.Multimodal != nil {
-				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
 			}
 
-			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
-			options.Sequences = append(options.Sequences, seq.cache.Id)
+			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
+			batch.Sequences = append(batch.Sequences, seq.cache.Id)
 
-			seq.iBatch = len(options.Outputs)
+			seq.iBatch = len(batch.Outputs)
 			if j+1 == len(seq.inputs) {
-				options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
+				batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
@@ -413,14 +413,14 @@ func (s *Server) processBatch() error {
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
-	if len(options.Inputs) == 0 {
+	if len(batch.Inputs) == 0 {
 		return nil
 	}
 
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
-	modelOutput, err := model.Forward(ctx, s.model, options)
+	modelOutput, err := model.Forward(ctx, s.model, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
@@ -460,7 +460,7 @@ func (s *Server) processBatch() error {
 		}
 
 		// sample a token
-		vocabSize := len(logits) / len(options.Outputs)
+		vocabSize := len(logits) / len(batch.Outputs)
 
 		token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
 		if err != nil {