jmorganca 11 月之前
父节点
当前提交
b39fca7088
共有 4 个文件被更改,包括 200 次插入11 次删除
  1. 64 0
      llama/llama.go
  2. 44 11
      llama/runner/runner.go
  3. 45 0
      llama/sampling_ext.cpp
  4. 47 0
      llama/sampling_ext.h

+ 64 - 0
llama/llama.go

@@ -28,6 +28,7 @@ package llama
 // #include "llama.h"
 // #include "clip.h"
 // #include "llava.h"
+// #include "sampling_ext.h"
 import "C"
 import (
 	"fmt"
@@ -244,6 +245,7 @@ func Quantize(infile, outfile string, ftype llm.FileType) error {
 	return nil
 }
 
+// llava
 type ClipContext struct {
 	c *C.struct_clip_ctx
 }
@@ -270,3 +272,65 @@ func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed
 func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
 	C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast)))
 }
+
+// sampling
+// TODO: this is a temporary wrapper to allow calling C++ code from CGo
+type SamplingContext struct {
+	c *C.struct_llama_sampling_context
+}
+
+type SamplingParams struct {
+	TopK           int
+	TopP           float32
+	TfsZ           float32
+	TypicalP       float32
+	Temp           float32
+	PenaltyRepeat  float32
+	PenaltyFreq    float32
+	PenaltyPresent float32
+	Mirostat       int
+	MirostatTau    float32
+	MirostatEta    float32
+	PenalizeNl     bool
+	Seed           uint32
+}
+
+func NewSamplingContext(params SamplingParams) *SamplingContext {
+	var cparams C.struct_llama_sampling_cparams
+	cparams.top_k = C.int32_t(params.TopK)
+	cparams.top_p = C.float(params.TopP)
+	cparams.tfs_z = C.float(params.TfsZ)
+	cparams.typical_p = C.float(params.TypicalP)
+	cparams.temp = C.float(params.Temp)
+	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
+	cparams.penalty_freq = C.float(params.PenaltyFreq)
+	cparams.penalty_present = C.float(params.PenaltyFreq)
+	cparams.mirostat = C.int32_t(params.Mirostat)
+	cparams.mirostat_tau = C.float(params.MirostatTau)
+	cparams.mirostat_eta = C.float(params.MirostatEta)
+	cparams.penalize_nl = C.bool(params.PenalizeNl)
+	cparams.seed = C.uint32_t(params.Seed)
+	return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
+}
+
+func (s *SamplingContext) Free() {
+	C.llama_sampling_cfree(s.c)
+}
+
+func (s *SamplingContext) Reset() {
+	C.llama_sampling_creset(s.c)
+}
+
+func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
+	// TODO (jmorganca): handle nil for all args
+	if ctxConfig == nil {
+		return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
+	}
+
+	return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
+
+}
+
+func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
+	C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
+}

+ 44 - 11
llama/runner/runner.go

@@ -24,6 +24,8 @@ type Sequence struct {
 	tokens []int
 
 	responses chan string
+
+	samplingCtx *llama.SamplingContext
 }
 
 // prompt returns true if the prompt is still being processed
@@ -31,15 +33,41 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func (s *Server) NewSequence(text string, w http.ResponseWriter) *Sequence {
-	tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
+func DefaultParams() llama.SamplingParams {
+	return llama.SamplingParams{}
+}
+
+func (s *Server) NewSequence(r Request, w http.ResponseWriter) *Sequence {
+	var samplingParams llama.SamplingParams
+	samplingParams.TopK = r.TopK
+	samplingParams.TopP = r.TopP
+	samplingParams.TfsZ = r.TFSZ
+	samplingParams.TypicalP = r.TypicalP
+	samplingParams.Temp = r.Temperature
+	samplingParams.PenaltyRepeat = r.RepeatPenalty
+	samplingParams.PenaltyFreq = r.FrequencyPenalty
+	samplingParams.PenaltyPresent = r.PresencePenalty
+	samplingParams.Mirostat = r.Mirostat
+	samplingParams.MirostatTau = r.MirostatTau
+	samplingParams.MirostatEta = r.MirostatEta
+	samplingParams.PenalizeNl = r.PenalizeNewline
+	samplingParams.Seed = uint32(r.Seed)
+
+	tokens, err := s.lc.Model().Tokenize(r.Prompt, 2048, false, true)
 	if err != nil {
 		panic(err)
 	}
 
+	sc := llama.NewSamplingContext(samplingParams)
+
+	for _, t := range tokens {
+		sc.Accept(s.lc, t, false)
+	}
+
 	return &Sequence{
-		tokens:    tokens,
-		responses: make(chan string, 1),
+		tokens:      tokens,
+		responses:   make(chan string, 1),
+		samplingCtx: sc,
 	}
 }
 
@@ -80,7 +108,6 @@ func (s *Server) run(ctx context.Context) {
 			slog.Info("Processing batch", "seqs", len(s.seqs))
 			s.mu.Lock()
 			for s.allNil() {
-				fmt.Println("wait")
 				s.cond.Wait() // Wait until an item is added
 			}
 			s.mu.Unlock()
@@ -133,8 +160,16 @@ func (s *Server) run(ctx context.Context) {
 				// sample a token
 				// TODO: sample based on the sequence
 				fmt.Println("Sampling token", i, ibatch[i])
-				logits := s.lc.GetLogitsIth(ibatch[i])
-				token := s.lc.SampleTokenGreedy(logits)
+				fmt.Println("calling sample", s.lc, nil, ibatch[i])
+				token := seq.samplingCtx.Sample(s.lc, nil, ibatch[i])
+				seq.samplingCtx.Accept(s.lc, token, true)
+
+				// logits := s.lc.GetLogitsIth(ibatch[i])
+				// token := s.lc.SampleTokenGreedy(logits)
+				fmt.Println("sampled", token, s.model.TokenToPiece(token))
+
+				seq.responses <- s.model.TokenToPiece(token)
+				seq.tokens = []int{token}
 
 				// if it's an end of sequence token, break
 				// TODO: just end this sequence
@@ -145,9 +180,6 @@ func (s *Server) run(ctx context.Context) {
 					s.seqs[i] = nil
 					continue
 				}
-
-				seq.responses <- s.model.TokenToPiece(token)
-				seq.tokens = []int{token}
 			}
 
 			batch.Clear()
@@ -168,6 +200,7 @@ type Response struct {
 
 func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
 	var request Request
+	request.Options = api.DefaultOptions()
 	if err := json.NewDecoder(r.Body).Decode(&request); err != nil {
 		http.Error(w, "Bad request", http.StatusBadRequest)
 		return
@@ -178,7 +211,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
 	w.Header().Set("Transfer-Encoding", "chunked")
 	w.WriteHeader(http.StatusOK)
 
-	seq := s.NewSequence(request.Prompt, w)
+	seq := s.NewSequence(request, w)
 
 	s.mu.Lock()
 	for i, sq := range s.seqs {

+ 45 - 0
llama/sampling_ext.cpp

@@ -0,0 +1,45 @@
+// TODO: this is a temporary wrapper to allow calling C++ code from CGo
+#include "sampling.h"
+#include "sampling_ext.h"
+
+struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params) {
+    llama_sampling_params sparams;
+    sparams.top_k = params->top_k;
+    sparams.top_p = params->top_p;
+    sparams.tfs_z = params->tfs_z;
+    sparams.typical_p = params->typical_p;
+    sparams.temp = params->temp;
+    sparams.penalty_repeat = params->penalty_repeat;
+    sparams.penalty_freq = params->penalty_freq;
+    sparams.penalty_present = params->penalty_present;
+    sparams.mirostat = params->mirostat;
+    sparams.mirostat_tau = params->mirostat_tau;
+    sparams.mirostat_eta = params->mirostat_eta;
+    sparams.penalize_nl = params->penalize_nl;
+    sparams.seed = params->seed;
+    return llama_sampling_init(sparams);
+}
+
+void llama_sampling_cfree(struct llama_sampling_context * ctx){
+    llama_sampling_free(ctx);
+}
+
+void llama_sampling_creset(struct llama_sampling_context * ctx){
+    llama_sampling_reset(ctx);
+}
+
+llama_token llama_sampling_csample(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx) {
+    return llama_sampling_sample(ctx_sampling, ctx_main, ctx_cfg, idx);
+}
+
+void llama_sampling_caccept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id,
+        bool apply_grammar) {
+    llama_sampling_accept(ctx_sampling, ctx_main, id, apply_grammar);
+}

+ 47 - 0
llama/sampling_ext.h

@@ -0,0 +1,47 @@
+// TODO: this is a temporary wrapper to allow calling C++ code from CGo
+#ifndef LLAMA_SAMPLING_EXT_H
+#define LLAMA_SAMPLING_EXT_H
+
+#include "llama.h"
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct llama_sampling_cparams {
+    int32_t     top_k;
+    float       top_p;
+    float       tfs_z;
+    float       typical_p;
+    float       temp;
+    float       penalty_repeat;
+    float       penalty_freq;
+    float       penalty_present;
+    int32_t     mirostat;
+    float       mirostat_tau;
+    float       mirostat_eta;
+    bool        penalize_nl;
+    uint32_t    seed;
+};
+
+struct llama_sampling_context* llama_sampling_cinit(struct llama_sampling_cparams *params);
+void llama_sampling_cfree(struct llama_sampling_context * ctx);
+void llama_sampling_creset(struct llama_sampling_context * ctx);
+
+llama_token llama_sampling_csample(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        struct llama_context * ctx_cfg,
+        int idx);
+
+void llama_sampling_caccept(
+        struct llama_sampling_context * ctx_sampling,
+        struct llama_context * ctx_main,
+        llama_token id,
+        bool apply_grammar);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif // LLAMA_SAMPLING_EXT_H