|
@@ -348,6 +348,7 @@ func (s *Server) processBatch() error {
|
|
|
}
|
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
+ var batchInputs []int32
|
|
|
var batch input.Batch
|
|
|
|
|
|
for i, seq := range s.seqs {
|
|
@@ -395,9 +396,9 @@ func (s *Server) processBatch() error {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- batch.Inputs = append(batch.Inputs, inp.Token)
|
|
|
+ batchInputs = append(batchInputs, inp.Token)
|
|
|
if inp.Multimodal != nil {
|
|
|
- batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
|
|
|
+ batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
|
|
|
}
|
|
|
|
|
|
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
|
@@ -405,7 +406,7 @@ func (s *Server) processBatch() error {
|
|
|
|
|
|
seq.iBatch = len(batch.Outputs)
|
|
|
if j+1 == len(seq.inputs) {
|
|
|
- batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
|
|
|
+ batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
|
|
|
}
|
|
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
|
|
}
|
|
@@ -413,14 +414,14 @@ func (s *Server) processBatch() error {
|
|
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
|
|
}
|
|
|
|
|
|
- if len(batch.Inputs) == 0 {
|
|
|
+ if len(batchInputs) == 0 {
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
ctx := s.model.Backend().NewContext()
|
|
|
defer ctx.Close()
|
|
|
|
|
|
- modelOutput, err := model.Forward(ctx, s.model, batch)
|
|
|
+ modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
|
|
|
if err != nil {
|
|
|
return fmt.Errorf("failed to decode batch: %w", err)
|
|
|
}
|