jmorganca 11 mesi fa
parent
commit
c0b94376b2
4 ha cambiato i file con 15 aggiunte e 15 eliminazioni
  1. 6 0
      llama/llama.go
  2. 7 15
      llama/runner/runner.go
  3. 1 0
      llama/sampling_ext.cpp
  4. 1 0
      llama/sampling_ext.h

+ 6 - 0
llama/llama.go

@@ -293,6 +293,7 @@ type SamplingParams struct {
 	MirostatEta    float32
 	PenalizeNl     bool
 	Seed           uint32
+	Grammar        string
 }
 
 func NewSamplingContext(params SamplingParams) *SamplingContext {
@@ -310,6 +311,11 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 	cparams.mirostat_eta = C.float(params.MirostatEta)
 	cparams.penalize_nl = C.bool(params.PenalizeNl)
 	cparams.seed = C.uint32_t(params.Seed)
+
+	grammar := C.CString(params.Grammar)
+	defer C.free(unsafe.Pointer(grammar))
+
+	cparams.grammar = grammar
 	return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
 }
 

+ 7 - 15
llama/runner/runner.go

@@ -33,10 +33,6 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func DefaultParams() llama.SamplingParams {
-	return llama.SamplingParams{}
-}
-
 func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
 	var samplingParams llama.SamplingParams
 	samplingParams.TopK = r.TopK
@@ -52,6 +48,7 @@ func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
 	samplingParams.MirostatEta = r.MirostatEta
 	samplingParams.PenalizeNl = r.PenalizeNewline
 	samplingParams.Seed = uint32(r.Seed)
+	samplingParams.Grammar = r.Grammar
 
 	tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
 	if err != nil {
@@ -112,8 +109,6 @@ func (s *Server) run(ctx context.Context) {
 			}
 			s.mu.Unlock()
 
-			fmt.Println("seqs", s.seqs, len(s.seqs))
-
 			// prepare the batch
 			ibatch := make([]int, s.parallel)
 			for i, seq := range s.seqs {
@@ -158,15 +153,10 @@ func (s *Server) run(ctx context.Context) {
 				}
 
 				// sample a token
-				// TODO: sample based on the sequence
-				fmt.Println("Sampling token", i, ibatch[i])
-				fmt.Println("calling sample", s.lc, nil, ibatch[i])
-				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
-				seq.samplingCtx.Accept(s.lc, token, true)
-
 				// logits := s.lc.GetLogitsIth(ibatch[i])
 				// token := s.lc.SampleTokenGreedy(logits)
-				fmt.Println("sampled", token, s.model.TokenToPiece(token))
+				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+				seq.samplingCtx.Accept(s.lc, token, true)
 
 				seq.responses <- s.model.TokenToPiece(token)
 				seq.tokens = []int{token}
@@ -177,6 +167,7 @@ func (s *Server) run(ctx context.Context) {
 					// TODO: end the sequence instead of quitting the pool
 					s.lc.KvCacheSeqRm(i, 0, -1)
 					close(seq.responses)
+					seq.samplingCtx.Free()
 					s.seqs[i] = nil
 					continue
 				}
@@ -188,8 +179,9 @@ func (s *Server) run(ctx context.Context) {
 }
 
 type Request struct {
-	Prompt string   `json:"prompt"`
-	Images []string `json:"images"`
+	Prompt  string   `json:"prompt"`
+	Images  []string `json:"images"`
+	Grammar string   `json:"grammar"`
 
 	api.Options
 }

+ 1 - 0
llama/sampling_ext.cpp

@@ -17,6 +17,7 @@ struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparam
     sparams.mirostat_eta = params->mirostat_eta;
     sparams.penalize_nl = params->penalize_nl;
     sparams.seed = params->seed;
+    sparams.grammar = std::string(params->grammar);
     return llama_sampling_init(sparams);
 }
 

+ 1 - 0
llama/sampling_ext.h

@@ -22,6 +22,7 @@ struct llama_sampling_cparams {
     float       mirostat_eta;
     bool        penalize_nl;
     uint32_t    seed;
+    char*       grammar;
 };
 
 struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);