@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// batch index
+	iBatch int
+
 	// number of tokens predicted so far
 	numPredicted int
 
@@ -122,6 +125,7 @@ func (s *Server) allNil() bool {
 }
 
 func (s *Server) run(ctx context.Context) {
+	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
 
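The TODO above refers to the old llama.cpp `server.cpp`, which divided the context window evenly across its parallel slots. A rough sketch of that sizing under stated assumptions: only `batchSize` and `parallel` appear in this diff, so the `numCtx` field below is hypothetical.

```go
package main

import "fmt"

// Simplified stand-in for the Server config. batchSize and parallel
// mirror fields used in this diff; numCtx (the total context window)
// is an assumed field added only for illustration.
type server struct {
	numCtx    int
	batchSize int
	parallel  int
}

func main() {
	s := server{numCtx: 4096, batchSize: 512, parallel: 4}

	// server.cpp-style sizing the TODO alludes to: each parallel
	// slot receives an equal share of the context window.
	perSlotCtx := s.numCtx / s.parallel
	fmt.Println(perSlotCtx) // 1024
}
```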
@@ -141,8 +145,6 @@ func (s *Server) run(ctx context.Context) {
 			}
 			s.mu.Unlock()
 
-			// prepare the batch
-			ibatch := make([]int, s.parallel)
 			for i, seq := range s.seqs {
 				if seq == nil {
 					continue
@@ -164,14 +166,10 @@ func (s *Server) run(ctx context.Context) {
 					if j > s.batchSize {
 						break
 					}
-
 					batch.Add(t, seq.nPast, []int{i}, !seq.prompt())
 					seq.nPast++
-
-					if seq.prompt() {
-						ibatch[i] = batch.NumTokens() + 1
-					}
 				}
+				seq.iBatch = batch.NumTokens() - 1
 			}
 
 			err := s.lc.Decode(batch)
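Net effect of this hunk: the throwaway `ibatch` slice is replaced by a field on the sequence itself, set unconditionally rather than only during prompt processing, and the index is corrected. The removed code stored `batch.NumTokens() + 1`, which points past the token just added; the new code stores `NumTokens() - 1`, its actual zero-based index. A minimal sketch of the bookkeeping, using simplified stand-in types rather than the real `llama.Batch` API:

```go
package main

import "fmt"

// batch is a simplified stand-in for llama.Batch; it only counts tokens.
type batch struct{ n int }

func (b *batch) Add(token, pos, seqID int) { b.n++ }
func (b *batch) NumTokens() int            { return b.n }

type sequence struct {
	tokens []int
	nPast  int
	iBatch int
}

func main() {
	b := &batch{}
	seqs := []*sequence{
		{tokens: []int{11, 12, 13}}, // fills batch rows 0, 1, 2
		{tokens: []int{21, 22}},     // fills batch rows 3, 4
	}

	for i, seq := range seqs {
		for _, t := range seq.tokens {
			b.Add(t, seq.nPast, i)
			seq.nPast++
		}
		// Zero-based index of this sequence's last token in the
		// batch: the row read back for it after Decode.
		seq.iBatch = b.NumTokens() - 1
	}

	fmt.Println(seqs[0].iBatch, seqs[1].iBatch) // prints: 2 4
}
```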
@@ -186,12 +184,6 @@ func (s *Server) run(ctx context.Context) {
 
 				// don't sample prompt processing
 				if seq.prompt() {
-					if len(seq.tokens) < s.batchSize {
-						seq.tokens = []int{}
-					} else {
-						seq.tokens = seq.tokens[s.batchSize:]
-					}
-
 					continue
 				}
 
@@ -199,7 +191,7 @@ func (s *Server) run(ctx context.Context) {
 				if seq.embeddingOnly {
 					embd := s.lc.GetEmbeddingsSeq(i)
 					if embd == nil {
-						embd = s.lc.GetEmbeddingsIth(ibatch[i])
+						embd = s.lc.GetEmbeddingsIth(seq.iBatch)
 					}
 
 					seq.embedding <- embd
@@ -212,7 +204,7 @@ func (s *Server) run(ctx context.Context) {
 				// sample a token
 				// logits := s.lc.GetLogitsIth(ibatch[i])
 				// token := s.lc.SampleTokenGreedy(logits)
-				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
 
 				seq.samplingCtx.Accept(s.lc, token, true)
 				piece := s.model.TokenToPiece(token)
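With `seq.iBatch` in place, the read-back side after `Decode` is uniform: embeddings fall back to `GetEmbeddingsIth(seq.iBatch)` when the per-sequence lookup returns nil, and sampling pulls logits from the same row. A hedged sketch of the pattern; the two Get* method names come from this diff, but the interfaces and signatures below are simplified assumptions, not the real llama bindings.

```go
package sketch

type sequence struct{ iBatch int }

// lctx stands in for the llama context; signatures are assumptions.
type lctx interface {
	GetEmbeddingsSeq(seqID int) []float32 // pooled per-sequence embedding; may be nil
	GetEmbeddingsIth(i int) []float32     // embedding at batch row i
}

// sampler stands in for the sampling context; idx selects the batch
// row whose logits are sampled.
type sampler interface {
	Sample(lc lctx, idx int) int
}

// readBack shows the indexing pattern: every per-sequence read is
// addressed by seq.iBatch, the row of that sequence's last token.
func readBack(lc lctx, smp sampler, seqs []*sequence, embeddingOnly bool) []int {
	var tokens []int
	for i, seq := range seqs {
		if seq == nil {
			continue
		}
		if embeddingOnly {
			embd := lc.GetEmbeddingsSeq(i)
			if embd == nil {
				// fall back to the row this sequence last wrote
				embd = lc.GetEmbeddingsIth(seq.iBatch)
			}
			_ = embd // delivered to the waiting request in the real code
			continue
		}
		tokens = append(tokens, smp.Sample(lc, seq.iBatch))
	}
	return tokens
}
```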