@@ -49,6 +49,9 @@ type Sequence struct {
 	// stop sequences
 	stop []string
 
+	// number of tokens to keep at the beginning when shifting context window
+	numKeep int
+
 	// true if an embedding are to be returned instead of text generation
 	embeddingOnly bool
 
@@ -61,22 +64,38 @@ type Sequence struct {
 	n_prompt_tokens int
 }
 
-func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+type NewSequenceParams struct {
+	numPredict     int
+	stop           []string
+	numKeep        int
+	samplingParams *llama.SamplingParams
+	embedding      bool
+}
+
+func (s *Server) NewSequence(prompt string, params NewSequenceParams) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, true, true)
 	if err != nil {
 		panic(err)
 	}
 
-	// truncate to last n tokens
-	// TODO: this shouldn't happen and will severely impact generation
-	// quality. instead we should ensure to cut prompt in the API.
+	if params.numKeep < 0 {
+		params.numKeep = len(tokens)
+	}
+	// Subtracting 4 ensures that at least 1 token can be discarded during shift
+	params.numKeep = min(params.numKeep, s.numCtx-4)
+	params.numKeep += s.bosToken
+
+	// truncate to fit in context window
 	if len(tokens) > s.numCtx {
-		tokens = tokens[:s.numCtx]
+		slog.Warn("truncating input prompt", "limit", s.numCtx, "prompt", len(tokens), "numKeep", params.numKeep)
+		newTokens := tokens[:params.numKeep]
+		newTokens = append(newTokens, tokens[len(tokens)-s.numCtx+params.numKeep:]...)
+		tokens = newTokens
 	}
 
 	var sc *llama.SamplingContext
-	if params != nil {
-		sc = llama.NewSamplingContext(*params)
+	if params.samplingParams != nil {
+		sc = llama.NewSamplingContext(*params.samplingParams)
 		for _, t := range tokens {
 			sc.Accept(s.lc, t, false)
 		}
@@ -85,12 +104,13 @@ func (s *Server) NewSequence(prompt string, numPredict int, stop []string, param
 	return &Sequence{
 		tokens:          tokens,
 		n_prompt_tokens: len(tokens),
-		numPredict:      numPredict,
+		numPredict:      params.numPredict,
 		responses:       make(chan string, 1),
 		embedding:       make(chan []float32, 1),
 		samplingCtx:     sc,
-		embeddingOnly:   embedding,
-		stop:            stop,
+		embeddingOnly:   params.embedding,
+		stop:            params.stop,
+		numKeep:         params.numKeep,
 	}
 }
 
@@ -111,6 +131,9 @@ type Server struct {
 	// context window size
 	numCtx int
 
+	// does this model require a beginning of sequence token?
+	bosToken int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
@@ -129,6 +152,21 @@ func (s *Server) allNil() bool {
 	return true
 }
 
+func (s *Server) shiftContext(seqIndex int) {
+	seq := s.seqs[seqIndex]
+
+	numLeft := seq.nPast - seq.numKeep
+	numDiscard := numLeft / 2
+
+	slog.Debug("context limit hit - shifting", "limit", s.numCtx, "nPast", seq.nPast,
+		"numKeep", seq.numKeep, "numLeft", numLeft, "numDiscard", numDiscard)
+
+	s.lc.KvCacheSeqRm(seqIndex, seq.numKeep, seq.numKeep+numDiscard)
+	s.lc.KvCacheSeqAdd(seqIndex, seq.numKeep+numDiscard, seq.nPast, -numDiscard)
+
+	seq.nPast -= numDiscard
+}
+
 func (s *Server) run(ctx context.Context) {
 	// TODO - should this be n_ctx / parallel like the old server.cpp setup?
 	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
@@ -155,10 +193,8 @@ func (s *Server) run(ctx context.Context) {
 			continue
 		}
 
-		hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
-
 		// if past the num predict limit
-		if hitLimit || seq.nPast > s.numCtx {
+		if seq.numPredict > 0 && seq.numPredicted > seq.numPredict {
 			seq.doneReason = "limit"
 			close(seq.responses)
 			s.lc.KvCacheSeqRm(i, 0, -1)
@@ -166,6 +202,10 @@ func (s *Server) run(ctx context.Context) {
 			continue
 		}
 
+		if seq.nPast+len(seq.tokens) > s.numCtx {
+			s.shiftContext(i)
+		}
+
 		if seq.t_start_process_prompt.IsZero() {
 			seq.t_start_process_prompt = time.Now()
 		}
@@ -350,7 +390,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, NewSequenceParams{
+		numPredict:     req.NumPredict,
+		stop:           req.Stop,
+		numKeep:        req.NumKeep,
+		samplingParams: &samplingParams,
+		embedding:      false,
+	})
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -428,7 +474,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	embeddings := make([][]float32, len(req.Content))
 	var processed int
 	for i, content := range req.Content {
-		seqs[i] = s.NewSequence(content, 0, nil, nil, true)
+		seqs[i] = s.NewSequence(content, NewSequenceParams{embedding: true})
 	}
 
 	// TODO - refactor to go routines to add seq's and drain the responses
@@ -563,6 +609,10 @@ func main() {
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
+	if server.model.ShouldAddBOSToken() {
+		server.bosToken = 1
+	}
+
 	if *ppath != "" {
 		server.cc = llama.NewClipContext(*ppath)
 	}