瀏覽代碼

model: Pass input tensor instead of raw data to models

Rather than directly giving the input data to models, we can
pass a tensor instead. In the short term, this saves some duplicated
code.

Longer term, we will want to overlap setting up the next batch with
processing of the current one. In this case, we will only have the
shape of tensor but it will not be loaded with data at the time of
graph generation. By passing only a tensor to models now, we set up
this possibility and prevent them from relying on data that they won't
have in the future.

Although the same could be done for Positions and Outputs, in some
cases we either need the raw input data or don't use them at all.
Therefore, for now we leave them as they are and allow models to
convert them to tensors as needed.
Jesse Gross 1 月之前
父節點
當前提交
0fbfcf3c9c

+ 3 - 1
model/input/input.go

@@ -1,5 +1,7 @@
 package input
 
+import "github.com/ollama/ollama/ml"
+
 // Input represents one token in the input stream
 type Input struct {
 	// Token is a single element of text.
@@ -36,7 +38,7 @@ type MultimodalIndex struct {
 // Batch contains the inputs for a model forward pass
 type Batch struct {
 	// Inputs is the input tokens, including placeholders for multimodal inputs.
-	Inputs []int32
+	Inputs ml.Tensor
 
 	// Multimodal is a set of multimodal embeddings previously created by
 	// EncodeMultimodal, along with an index into Inputs. Unused for text-only

+ 7 - 1
model/model.go

@@ -280,7 +280,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
 	if len(batch.Positions) != len(batch.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
 	}
@@ -289,6 +289,12 @@ func Forward(ctx ml.Context, m Model, batch input.Batch) (ml.Tensor, error) {
 		return nil, errors.New("batch size cannot be less than 1")
 	}
 
+	var err error
+	batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
+	if err != nil {
+		return nil, err
+	}
+
 	cache := m.Config().Cache
 	if cache != nil {
 		err := cache.StartForward(ctx, batch)

+ 1 - 6
model/models/gemma2/model.go

@@ -169,11 +169,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -184,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
 
 	if len(m.Layers) == gemma27BLayerCount {

+ 1 - 6
model/models/gemma3/model.go

@@ -140,11 +140,6 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -155,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, batch, m.Cache), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
 }
 
 func init() {

+ 1 - 6
model/models/llama/model.go

@@ -140,11 +140,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 }
 
 func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -155,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
 
 	for i, layer := range m.Layers {
 		m.Cache.SetLayer(i)

+ 1 - 6
model/models/mllama/model.go

@@ -144,11 +144,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 		}
 	}
 
-	inputs, err := ctx.Input().FromIntSlice(batch.Inputs, len(batch.Inputs))
-	if err != nil {
-		return nil, err
-	}
-
 	positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
 	if err != nil {
 		return nil, err
@@ -160,7 +155,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 	}
 
 	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
 }
 
 func init() {

+ 6 - 5
runner/ollamarunner/runner.go

@@ -348,6 +348,7 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
+	var batchInputs []int32
 	var batch input.Batch
 
 	for i, seq := range s.seqs {
@@ -395,9 +396,9 @@ func (s *Server) processBatch() error {
 				}
 			}
 
-			batch.Inputs = append(batch.Inputs, inp.Token)
+			batchInputs = append(batchInputs, inp.Token)
 			if inp.Multimodal != nil {
-				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
+				batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
 			}
 
 			batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
@@ -405,7 +406,7 @@ func (s *Server) processBatch() error {
 
 			seq.iBatch = len(batch.Outputs)
 			if j+1 == len(seq.inputs) {
-				batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
+				batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
 			}
 			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
@@ -413,14 +414,14 @@ func (s *Server) processBatch() error {
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
 	}
 
-	if len(batch.Inputs) == 0 {
+	if len(batchInputs) == 0 {
 		return nil
 	}
 
 	ctx := s.model.Backend().NewContext()
 	defer ctx.Close()
 
-	modelOutput, err := model.Forward(ctx, s.model, batch)
+	modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
 	if err != nil {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}