瀏覽代碼

wip stop tokens

jmorganca 11 月之前
父節點
當前提交
240d4cf0aa
共有 1 個文件被更改,包括 85 次插入6 次删除
  1. 85 6
      llama/runner/runner.go

+ 85 - 6
llama/runner/runner.go

@@ -10,6 +10,7 @@ import (
 	"net"
 	"net/http"
 	"strconv"
+	"strings"
 	"sync"
 
 	"github.com/ollama/ollama/api"
@@ -31,6 +32,9 @@ type Sequence struct {
 	// channel to send back the embedding if embedding only
 	embedding chan []float32
 
+	// stop sequences
+	stop []string
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 }
@@ -40,7 +44,7 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func (s *Server) NewSequence(prompt string, params *llama.SamplingParams, embedding bool) *Sequence {
+func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
 	if err != nil {
 		panic(err)
@@ -60,6 +64,7 @@ func (s *Server) NewSequence(prompt string, params *llama.SamplingParams, embedd
 		embedding:     make(chan []float32, 1),
 		samplingCtx:   sc,
 		embeddingOnly: embedding,
+		stop:          stop,
 	}
 }
 
@@ -72,6 +77,7 @@ type Server struct {
 	parallel int
 
 	// seqs is the list of parallel sequences being evaluated
+	// TODO (jmorganca): this can probably be moved into run()
 	seqs []*Sequence
 
 	mu sync.Mutex
@@ -88,10 +94,36 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func contains(sequence string, stops []string) (bool, string) {
+	for _, stop := range stops {
+		if strings.Contains(sequence, stop) {
+			return true, stop
+		}
+	}
+
+	return false, ""
+}
+
+func overlap(sequence string, stops []string) bool {
+	for _, stop := range stops {
+		for i := 1; i < len(stop); i++ {
+			if strings.HasSuffix(sequence, stop[:i]) {
+				return true
+			}
+		}
+	}
+
+	return false
+}
+
 func (s *Server) run(ctx context.Context) {
 	batch := llama.NewBatch(512, 0, s.parallel)
 	defer batch.Free()
 
+	// build up stop sequences as we recognize them
+	// TODO (jmorganca): simplify this
+	sofar := make([][]string, s.parallel)
+
 	for {
 		select {
 		case <-ctx.Done():
@@ -165,21 +197,67 @@ func (s *Server) run(ctx context.Context) {
 				// logits := s.lc.GetLogitsIth(ibatch[i])
 				// token := s.lc.SampleTokenGreedy(logits)
 				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
-				seq.samplingCtx.Accept(s.lc, token, true)
 
-				seq.responses <- s.model.TokenToPiece(token)
-				seq.tokens = []int{token}
+				seq.samplingCtx.Accept(s.lc, token, true)
+				piece := s.model.TokenToPiece(token)
+				slog.Info("sampled", "piece", piece)
 
 				// if it's an end of sequence token, break
 				// TODO: just end this sequence
 				if s.model.TokenIsEog(token) {
 					// TODO: end the sequence instead of quitting the pool
 					s.lc.KvCacheSeqRm(i, 0, -1)
+
+					// TODO (jmorganca): we should send this back
+					// as it's important for the /api/generate context
+					// seq.responses <- piece
+
 					close(seq.responses)
 					seq.samplingCtx.Free()
+					sofar[i] = []string{}
 					s.seqs[i] = nil
 					continue
 				}
+
+				seq.tokens = []int{token}
+
+				// recognize stop sequences
+				// TODO (jmorganca): add tests around this
+				// TODO (jmorganca): send back parital piece
+
+				sequence := strings.Join(append(sofar[i], piece), "")
+				if ok, stop := contains(sequence, seq.stop); ok {
+					slog.Info("hit stop token", "stop", seq.stop)
+					for _, p := range sofar[i] {
+						seq.responses <- p
+					}
+
+					piece, _, _ := strings.Cut(piece, stop)
+					seq.responses <- piece
+
+					s.lc.KvCacheSeqRm(i, 0, -1)
+					close(seq.responses)
+					seq.samplingCtx.Free()
+					sofar[i] = []string{}
+					s.seqs[i] = nil
+					continue
+				}
+
+				if overlap(sequence, seq.stop) {
+					slog.Info("overlap", "sequence", sequence)
+					// partial stop, don't send
+					continue
+				}
+
+				slog.Info("sending", "sofar", sofar[i])
+
+				sofar[i] = append(sofar[i], piece)
+
+				for _, p := range sofar[i] {
+					seq.responses <- p
+				}
+
+				sofar[i] = []string{}
 			}
 
 			batch.Clear()
@@ -191,6 +269,7 @@ type CompletionRequest struct {
 	Prompt  string   `json:"prompt"`
 	Images  []string `json:"images"`
 	Grammar string   `json:"grammar"`
+	Stop    []string `json:"stop"`
 
 	api.Options
 }
@@ -228,7 +307,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -279,7 +358,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Type", "application/json")
 
-	seq := s.NewSequence(req.Prompt, nil, true)
+	seq := s.NewSequence(req.Prompt, nil, nil, true)
 
 	s.mu.Lock()
 	for i, sq := range s.seqs {