@@ -198,9 +198,6 @@ func incompleteUnicode(token string) bool {
 }
 
 func (s *Server) run(ctx context.Context) {
-	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
-	defer batch.Free()
-
 	// build up stop sequences as we recognize them
 	// TODO (jmorganca): simplify this
 	pieces := make([][]string, s.parallel)
@@ -210,158 +207,166 @@ func (s *Server) run(ctx context.Context) {
 		case <-ctx.Done():
 			return
 		default:
-			slog.Debug("Processing batch", "seqs", len(s.seqs))
-			s.mu.Lock()
-			for s.allNil() {
-				s.cond.Wait() // Wait until an item is added
-			}
-			s.mu.Unlock()
-
-			for i, seq := range s.seqs {
-				if seq == nil {
-					continue
-				}
-
-				// if past the num predict limit
-				if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
-					seq.doneReason = "limit"
-					close(seq.responses)
-					s.lc.KvCacheSeqRm(i, 0, -1)
-					s.seqs[i] = nil
-					continue
-				}
-
-				if seq.nPast+len(seq.tokens) > s.numCtx {
-					s.shiftContext(i)
-				}
-
-				if seq.t_start_process_prompt.IsZero() {
-					seq.t_start_process_prompt = time.Now()
-				}
-
-				var numTokensProcessed int
-				for j, t := range seq.tokens {
-					// todo: make this n_batch
-					if j >= s.batchSize {
-						break
-					}
-					batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
-					seq.nPast++
-					numTokensProcessed++
-				}
-				seq.tokens = seq.tokens[numTokensProcessed:]
-				seq.iBatch = batch.NumTokens() - 1
-			}
+			pieces = s.processBatch(pieces)
+		}
+	}
+}
+
+func (s *Server) processBatch(pieces [][]string) [][]string {
+	batch := llama.NewBatch(s.batchSize*len(s.seqs), 0, len(s.seqs))
+	defer batch.Free()
+
+	s.mu.Lock()
+	for s.allNil() {
+		s.cond.Wait() // Wait until an item is added
+	}
+	defer s.mu.Unlock()
+
+	slog.Debug("Processing batch", "seqs", len(s.seqs))
+
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
 
-			if batch.NumTokens() == 0 {
-				continue
+		// if past the num predict limit
+		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
+			seq.doneReason = "limit"
+			close(seq.responses)
+			s.lc.KvCacheSeqRm(i, 0, -1)
+			s.seqs[i] = nil
+			continue
+		}
+
+		if seq.nPast+len(seq.tokens) > s.numCtx {
+			s.shiftContext(i)
+		}
+
+		if seq.t_start_process_prompt.IsZero() {
+			seq.t_start_process_prompt = time.Now()
+		}
+
+		var numTokensProcessed int
+		for j, t := range seq.tokens {
+			// todo: make this n_batch
+			if j >= s.batchSize {
+				break
 			}
+			batch.Add(t, seq.nPast, []int{i}, numTokensProcessed+1 == len(seq.tokens))
+			seq.nPast++
+			numTokensProcessed++
+		}
+		seq.tokens = seq.tokens[numTokensProcessed:]
+		seq.iBatch = batch.NumTokens() - 1
+	}
+
+	if batch.NumTokens() == 0 {
+		return pieces
+	}
 
-			err := s.lc.Decode(batch)
-			if err != nil {
-				slog.Error("failed to decode batch", "error", err)
-				panic("Failed to decode")
+	err := s.lc.Decode(batch)
+	if err != nil {
+		slog.Error("failed to decode batch", "error", err)
+		panic("Failed to decode")
+	}
+
+	for i, seq := range s.seqs {
+		if seq == nil {
+			continue
+		}
+
+		// don't sample prompt processing
+		if len(seq.tokens) != 0 {
+			continue
+		}
+
+		// if done processing the prompt, generating an embedding and return
+		if seq.embeddingOnly {
+			embd := s.lc.GetEmbeddingsSeq(i)
+			if embd == nil {
+				embd = s.lc.GetEmbeddingsIth(seq.iBatch)
 			}
 
-			for i, seq := range s.seqs {
-				if seq == nil {
-					continue
-				}
-
-				// don't sample prompt processing
-				if len(seq.tokens) != 0 {
-					continue
-				}
-
-				// if done processing the prompt, generating an embedding and return
-				if seq.embeddingOnly {
-					embd := s.lc.GetEmbeddingsSeq(i)
-					if embd == nil {
-						embd = s.lc.GetEmbeddingsIth(seq.iBatch)
-					}
-
-					seq.embedding <- embd
-					close(seq.embedding)
-					s.lc.KvCacheSeqRm(i, 0, -1)
-					s.seqs[i] = nil
-					continue
-				}
-
-				// sample a token
-				// logits := s.lc.GetLogitsIth(ibatch[i])
-				// token := s.lc.SampleTokenGreedy(logits)
-				token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
-
-				seq.samplingCtx.Accept(s.lc, token, true)
-				seq.n_decoded += 1
-				if seq.n_decoded == 1 {
-					seq.t_start_genereration = time.Now()
-				}
-				piece := s.model.TokenToPiece(token)
-
-				seq.numPredicted++
-
-				slog.Debug("sampled", "piece", piece)
-
-				// if it's an end of sequence token, break
-				// TODO: just end this sequence
-				if s.model.TokenIsEog(token) {
-					// TODO: end the sequence instead of quitting the pool
-					s.lc.KvCacheSeqRm(i, 0, -1)
-
-					// TODO (jmorganca): we should send this back
-					// as it's important for the /api/generate context
-					// seq.responses <- piece
-
-					seq.doneReason = "stop"
-					close(seq.responses)
-					seq.samplingCtx.Free()
-					pieces[i] = []string{}
-					s.seqs[i] = nil
-					continue
-				}
-
-				seq.tokens = []int{token}
-
-				pieces[i] = append(pieces[i], piece)
-				sequence := strings.Join(pieces[i], "")
-
-				if incompleteUnicode(sequence) {
-					continue
-				}
-
-				if ok, stop := findStop(sequence, seq.stop); ok {
-					slog.Info("hit stop token", "stop", seq.stop)
-
-					truncated := truncateStop(pieces[i], stop)
-
-					for _, p := range truncated {
-						seq.responses <- p
-					}
-
-					s.lc.KvCacheSeqRm(i, 0, -1)
-					seq.doneReason = "stop"
-					close(seq.responses)
-					seq.samplingCtx.Free()
-					pieces[i] = []string{}
-					s.seqs[i] = nil
-					continue
-				}
-
-				if containsStopSuffix(sequence, seq.stop) {
-					continue
-				}
-
-				for _, p := range pieces[i] {
-					seq.responses <- p
-				}
-
-				pieces[i] = []string{}
+			seq.embedding <- embd
+			close(seq.embedding)
+			s.lc.KvCacheSeqRm(i, 0, -1)
+			s.seqs[i] = nil
+			continue
+		}
+
+		// sample a token
+		// logits := s.lc.GetLogitsIth(ibatch[i])
+		// token := s.lc.SampleTokenGreedy(logits)
+		token := seq.samplingCtx.Sample(s.lc, nil, seq.iBatch)
+
+		seq.samplingCtx.Accept(s.lc, token, true)
+		seq.n_decoded += 1
+		if seq.n_decoded == 1 {
+			seq.t_start_genereration = time.Now()
+		}
+		piece := s.model.TokenToPiece(token)
+
+		seq.numPredicted++
+
+		slog.Debug("sampled", "piece", piece)
+
+		// if it's an end of sequence token, break
+		// TODO: just end this sequence
+		if s.model.TokenIsEog(token) {
+			// TODO: end the sequence instead of quitting the pool
+			s.lc.KvCacheSeqRm(i, 0, -1)
+
+			// TODO (jmorganca): we should send this back
+			// as it's important for the /api/generate context
+			// seq.responses <- piece
+
+			seq.doneReason = "stop"
+			close(seq.responses)
+			seq.samplingCtx.Free()
+			pieces[i] = []string{}
+			s.seqs[i] = nil
+			continue
+		}
+
+		seq.tokens = []int{token}
+
+		pieces[i] = append(pieces[i], piece)
+		sequence := strings.Join(pieces[i], "")
+
+		if incompleteUnicode(sequence) {
+			continue
+		}
+
+		if ok, stop := findStop(sequence, seq.stop); ok {
+			slog.Info("hit stop token", "stop", seq.stop)
+
+			truncated := truncateStop(pieces[i], stop)
+
+			for _, p := range truncated {
+				seq.responses <- p
 			}
 
-			batch.Clear()
+			s.lc.KvCacheSeqRm(i, 0, -1)
+			seq.doneReason = "stop"
+			close(seq.responses)
+			seq.samplingCtx.Free()
+			pieces[i] = []string{}
+			s.seqs[i] = nil
+			continue
+		}
+
+		if containsStopSuffix(sequence, seq.stop) {
+			continue
 		}
+
+		for _, p := range pieces[i] {
+			seq.responses <- p
+		}
+
+		pieces[i] = []string{}
 	}
+
+	return pieces
 }
 
 type Options struct {