소스 검색

ml: use input context for extracting outputs (#9875)

Jeffrey Morgan 1 개월 전
부모
커밋
da0e345200
4개의 변경된 파일4개의 추가작업 그리고 4개의 파일을 삭제
  1. 1 1
      model/models/gemma2/model.go
  2. 1 1
      model/models/gemma3/model.go
  3. 1 1
      model/models/llama/model.go
  4. 1 1
      model/models/mllama/model.go

+ 1 - 1
model/models/gemma2/model.go

@@ -179,7 +179,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/gemma3/model.go

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/llama/model.go

@@ -150,7 +150,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

+ 1 - 1
model/models/mllama/model.go

@@ -154,7 +154,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
+	outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}