|
@@ -348,7 +348,7 @@ func (s *Server) processBatch() error {
|
|
}
|
|
}
|
|
defer s.mu.Unlock()
|
|
defer s.mu.Unlock()
|
|
|
|
|
|
- var options input.Options
|
|
|
|
|
|
+ var batch input.Batch
|
|
|
|
|
|
for i, seq := range s.seqs {
|
|
for i, seq := range s.seqs {
|
|
if seq == nil {
|
|
if seq == nil {
|
|
@@ -395,17 +395,17 @@ func (s *Server) processBatch() error {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
- options.Inputs = append(options.Inputs, inp.Token)
|
|
|
|
|
|
+ batch.Inputs = append(batch.Inputs, inp.Token)
|
|
if inp.Multimodal != nil {
|
|
if inp.Multimodal != nil {
|
|
- options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
|
|
|
|
|
|
+ batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batch.Inputs) - 1, Multimodal: inp.Multimodal})
|
|
}
|
|
}
|
|
|
|
|
|
- options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
|
|
|
- options.Sequences = append(options.Sequences, seq.cache.Id)
|
|
|
|
|
|
+ batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
|
|
|
|
+ batch.Sequences = append(batch.Sequences, seq.cache.Id)
|
|
|
|
|
|
- seq.iBatch = len(options.Outputs)
|
|
|
|
|
|
+ seq.iBatch = len(batch.Outputs)
|
|
if j+1 == len(seq.inputs) {
|
|
if j+1 == len(seq.inputs) {
|
|
- options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
|
|
|
|
|
|
+ batch.Outputs = append(batch.Outputs, int32(len(batch.Inputs)-1))
|
|
}
|
|
}
|
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
|
seq.pendingInputs = append(seq.pendingInputs, inp)
|
|
}
|
|
}
|
|
@@ -413,14 +413,14 @@ func (s *Server) processBatch() error {
|
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
|
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
|
}
|
|
}
|
|
|
|
|
|
- if len(options.Inputs) == 0 {
|
|
|
|
|
|
+ if len(batch.Inputs) == 0 {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
ctx := s.model.Backend().NewContext()
|
|
ctx := s.model.Backend().NewContext()
|
|
defer ctx.Close()
|
|
defer ctx.Close()
|
|
|
|
|
|
- modelOutput, err := model.Forward(ctx, s.model, options)
|
|
|
|
|
|
+ modelOutput, err := model.Forward(ctx, s.model, batch)
|
|
if err != nil {
|
|
if err != nil {
|
|
return fmt.Errorf("failed to decode batch: %w", err)
|
|
return fmt.Errorf("failed to decode batch: %w", err)
|
|
}
|
|
}
|
|
@@ -460,7 +460,7 @@ func (s *Server) processBatch() error {
|
|
}
|
|
}
|
|
|
|
|
|
// sample a token
|
|
// sample a token
|
|
- vocabSize := len(logits) / len(options.Outputs)
|
|
|
|
|
|
+ vocabSize := len(logits) / len(batch.Outputs)
|
|
|
|
|
|
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
|
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
|
|
if err != nil {
|
|
if err != nil {
|