|
@@ -33,10 +33,6 @@ func (s *Sequence) prompt() bool {
|
|
|
return s.nPast < len(s.tokens)-1
|
|
|
}
|
|
|
|
|
|
-func DefaultParams() llama.SamplingParams {
|
|
|
- return llama.SamplingParams{}
|
|
|
-}
|
|
|
-
|
|
|
func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
|
|
var samplingParams llama.SamplingParams
|
|
|
samplingParams.TopK = r.TopK
|
|
@@ -52,6 +48,7 @@ func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
|
|
|
samplingParams.MirostatEta = r.MirostatEta
|
|
|
samplingParams.PenalizeNl = r.PenalizeNewline
|
|
|
samplingParams.Seed = uint32(r.Seed)
|
|
|
+ samplingParams.Grammar = r.Grammar
|
|
|
|
|
|
tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
|
|
|
if err != nil {
|
|
@@ -112,8 +109,6 @@ func (s *Server) run(ctx context.Context) {
|
|
|
}
|
|
|
s.mu.Unlock()
|
|
|
|
|
|
- fmt.Println("seqs", s.seqs, len(s.seqs))
|
|
|
-
|
|
|
// prepare the batch
|
|
|
ibatch := make([]int, s.parallel)
|
|
|
for i, seq := range s.seqs {
|
|
@@ -158,15 +153,10 @@ func (s *Server) run(ctx context.Context) {
|
|
|
}
|
|
|
|
|
|
// sample a token
|
|
|
- // TODO: sample based on the sequence
|
|
|
- fmt.Println("Sampling token", i, ibatch[i])
|
|
|
- fmt.Println("calling sample", s.lc, nil, ibatch[i])
|
|
|
- token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
|
|
- seq.samplingCtx.Accept(s.lc, token, true)
|
|
|
-
|
|
|
// logits := s.lc.GetLogitsIth(ibatch[i])
|
|
|
// token := s.lc.SampleTokenGreedy(logits)
|
|
|
- fmt.Println("sampled", token, s.model.TokenToPiece(token))
|
|
|
+ token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
|
|
|
+ seq.samplingCtx.Accept(s.lc, token, true)
|
|
|
|
|
|
seq.responses <- s.model.TokenToPiece(token)
|
|
|
seq.tokens = []int{token}
|
|
@@ -177,6 +167,7 @@ func (s *Server) run(ctx context.Context) {
|
|
|
// TODO: end the sequence instead of quitting the pool
|
|
|
s.lc.KvCacheSeqRm(i, 0, -1)
|
|
|
close(seq.responses)
|
|
|
+ seq.samplingCtx.Free()
|
|
|
s.seqs[i] = nil
|
|
|
continue
|
|
|
}
|
|
@@ -188,8 +179,9 @@ func (s *Server) run(ctx context.Context) {
|
|
|
}
|
|
|
|
|
|
type Request struct {
|
|
|
- Prompt string `json:"prompt"`
|
|
|
- Images []string `json:"images"`
|
|
|
+ Prompt string `json:"prompt"`
|
|
|
+ Images []string `json:"images"`
|
|
|
+ Grammar string `json:"grammar"`
|
|
|
|
|
|
api.Options
|
|
|
}
|