瀏覽代碼

ml: use input context for extracting outputs (#9875)

Jeffrey Morgan 1 月之前
父節點
當前提交
da0e345200
共有 4 個文件被更改,包括 4 次插入4 次删除
  1. 1 1
      model/models/gemma2/model.go
  2. 1 1
      model/models/gemma3/model.go
  3. 1 1
      model/models/llama/model.go
  4. 1 1
      model/models/mllama/model.go

+ 1 - 1
model/models/gemma2/model.go

@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/gemma3/model.go

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/llama/model.go

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/mllama/model.go

@@ -154,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}