소스 검색

runner.go: Implement RepeatLastN to penalize repeated tokens

RepeatLastN is a user-facing parameter that is exposed that is exposed
through the APIs but is not currently plumbed through.
Jesse Gross 8 달 전
부모
커밋
477f529d26
4개의 변경된 파일5개의 추가작업 그리고 0개의 파일을 삭제
  1. 2 0
      llama/llama.go
  2. 1 0
      llama/runner/runner.go
  3. 1 0
      llama/sampling_ext.cpp
  4. 1 0
      llama/sampling_ext.h

+ 2 - 0
llama/llama.go

@@ -390,6 +390,7 @@ type SamplingParams struct {
 	TfsZ           float32
 	TfsZ           float32
 	TypicalP       float32
 	TypicalP       float32
 	Temp           float32
 	Temp           float32
+	RepeatLastN    int
 	PenaltyRepeat  float32
 	PenaltyRepeat  float32
 	PenaltyFreq    float32
 	PenaltyFreq    float32
 	PenaltyPresent float32
 	PenaltyPresent float32
@@ -408,6 +409,7 @@ func NewSamplingContext(params SamplingParams) *SamplingContext {
 	cparams.tfs_z = C.float(params.TfsZ)
 	cparams.tfs_z = C.float(params.TfsZ)
 	cparams.typical_p = C.float(params.TypicalP)
 	cparams.typical_p = C.float(params.TypicalP)
 	cparams.temp = C.float(params.Temp)
 	cparams.temp = C.float(params.Temp)
+	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
 	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
 	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
 	cparams.penalty_freq = C.float(params.PenaltyFreq)
 	cparams.penalty_freq = C.float(params.PenaltyFreq)
 	cparams.penalty_present = C.float(params.PenaltyFreq)
 	cparams.penalty_present = C.float(params.PenaltyFreq)

+ 1 - 0
llama/runner/runner.go

@@ -402,6 +402,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.TfsZ = req.TFSZ
 	samplingParams.TfsZ = req.TFSZ
 	samplingParams.TypicalP = req.TypicalP
 	samplingParams.TypicalP = req.TypicalP
 	samplingParams.Temp = req.Temperature
 	samplingParams.Temp = req.Temperature
+	samplingParams.RepeatLastN = req.RepeatLastN
 	samplingParams.PenaltyRepeat = req.RepeatPenalty
 	samplingParams.PenaltyRepeat = req.RepeatPenalty
 	samplingParams.PenaltyFreq = req.FrequencyPenalty
 	samplingParams.PenaltyFreq = req.FrequencyPenalty
 	samplingParams.PenaltyPresent = req.PresencePenalty
 	samplingParams.PenaltyPresent = req.PresencePenalty

+ 1 - 0
llama/sampling_ext.cpp

@@ -10,6 +10,7 @@ struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparam
     sparams.tfs_z = params->tfs_z;
     sparams.tfs_z = params->tfs_z;
     sparams.typical_p = params->typical_p;
     sparams.typical_p = params->typical_p;
     sparams.temp = params->temp;
     sparams.temp = params->temp;
+    sparams.penalty_last_n = params->penalty_last_n;
     sparams.penalty_repeat = params->penalty_repeat;
     sparams.penalty_repeat = params->penalty_repeat;
     sparams.penalty_freq = params->penalty_freq;
     sparams.penalty_freq = params->penalty_freq;
     sparams.penalty_present = params->penalty_present;
     sparams.penalty_present = params->penalty_present;

+ 1 - 0
llama/sampling_ext.h

@@ -16,6 +16,7 @@ extern "C"
         float tfs_z;
         float tfs_z;
         float typical_p;
         float typical_p;
         float temp;
         float temp;
+        int32_t penalty_last_n;
         float penalty_repeat;
         float penalty_repeat;
         float penalty_freq;
         float penalty_freq;
         float penalty_present;
         float penalty_present;