Ver código fonte

ml: use input context for extracting outputs (#9875)

Jeffrey Morgan 1 mês atrás
pai
commit
da0e345200

+ 1 - 1
model/models/gemma2/model.go

@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/gemma3/model.go

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/llama/model.go

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/mllama/model.go

@@ -154,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}