@@ -65,8 +65,8 @@ type Sequence struct {
 	// number of tokens to predict
 	numPredict int
 
-	// set of samplers to run on generated logits
-	samplers []sample.Sampler
+	// sampler with transforms to run on generated logits
+	sampler sample.Sampler
 
 	// channel to send back the embedding if embedding only
 	embedding chan []float32
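This hunk (and the two wiring hunks that follow) collapse the per-sequence slice of samplers into a single `sample.Sampler` that carries its transforms internally. Judging only by how the field is used later in this patch, the interface shape is roughly the following sketch; the authoritative definition lives in the `sample` package and may differ in detail:

```go
package sample

// Sketch of the interface shape implied by this patch; the real
// definition lives in the sample package and may differ in detail.
type Sampler interface {
	// Sample takes one sequence's logits (one float32 score per
	// vocabulary entry), applies the configured transforms such as
	// temperature, top-k, top-p, and min-p, and returns the id of
	// the chosen token.
	Sample(logits []float32) (int32, error)
}
```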
@@ -93,7 +93,7 @@ type NewSequenceParams struct {
 	numPredict int
 	stop       []string
 	numKeep    int32
-	samplers   []sample.Sampler
+	sampler    sample.Sampler
 	embedding  bool
 }
 
@@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		responses:     make(chan string, 100),
 		quit:          make(chan bool, 1),
 		embedding:     make(chan []float32, 1),
-		samplers:      params.samplers,
+		sampler:       params.sampler,
 		embeddingOnly: params.embedding,
 		stop:          params.stop,
 		numKeep:       params.numKeep,
@@ -393,13 +393,7 @@ func (s *Server) processBatch() error {
 		return fmt.Errorf("failed to decode batch: %w", err)
 	}
 
-	f32s := modelOutput.Floats()
-
-	// TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s
-	logits := make([]float64, len(f32s))
-	for i, f32 := range f32s {
-		logits[i] = float64(f32)
-	}
+	logits := modelOutput.Floats()
 
 	for i, seq := range s.seqs {
 		if seq == nil {
@@ -433,15 +427,13 @@ func (s *Server) processBatch() error {
 			}
 
 			// sample a token
-			vocabSize := len(f32s) / len(options.Outputs)
-			tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...)
+			vocabSize := len(logits) / len(options.Outputs)
+
+			token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
 			if err != nil {
-				return err
+				return fmt.Errorf("failed to sample token: %w", err)
 			}
 
-			// TODO(jessegross): Sampler will output a single int32 in the future
-			token := int32(tokens[0])
-
 			// if it's an end of sequence token, break
 			if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
 				// TODO (jmorganca): we should send this back
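Two things change in `processBatch`: the float64 conversion loop disappears because the sampler now consumes float32 logits directly, and each sequence samples from its own vocab-sized slice of the flat output buffer, indexed by `seq.iBatch`. A toy illustration of that slicing (all names and numbers hypothetical):

```go
package main

import "fmt"

func main() {
	// Flat logits for a batch with 2 outputs over a toy 3-token
	// vocabulary: output 0 occupies indices [0,3), output 1 [3,6).
	logits := []float32{0.1, 0.7, 0.2, 0.9, 0.05, 0.05}
	outputs := 2
	vocabSize := len(logits) / outputs // 3

	for iBatch := 0; iBatch < outputs; iBatch++ {
		seqLogits := logits[iBatch*vocabSize : (iBatch+1)*vocabSize]
		fmt.Printf("sequence %d samples from %v\n", iBatch, seqLogits)
	}
}
```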
@@ -565,27 +557,6 @@ type CompletionResponse struct {
 	Timings Timings `json:"timings"`
 }
 
-func getSamplers(_ CompletionRequest) []sample.Sampler {
-	// TODO(jessegross): Waiting for sampling code
-
-	/*samplingParams.TopK = req.TopK
-	samplingParams.TopP = req.TopP
-	samplingParams.MinP = req.MinP
-	samplingParams.TypicalP = req.TypicalP
-	samplingParams.Temp = req.Temperature
-	samplingParams.RepeatLastN = req.RepeatLastN
-	samplingParams.PenaltyRepeat = req.RepeatPenalty
-	samplingParams.PenaltyFreq = req.FrequencyPenalty
-	samplingParams.PenaltyPresent = req.PresencePenalty
-	samplingParams.Mirostat = req.Mirostat
-	samplingParams.MirostatTau = req.MirostatTau
-	samplingParams.MirostatEta = req.MirostatEta
-	samplingParams.Seed = uint32(req.Seed)
-	samplingParams.Grammar = req.Grammar*/
-
-	return []sample.Sampler{sample.Greedy()}
-}
-
 func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	var req CompletionRequest
 	req.Options = Options(api.DefaultOptions())
@@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
+	sampler, err := sample.NewSampler(
+		req.Temperature,
+		req.TopK,
+		req.TopP,
+		req.MinP,
+		req.Seed,
+	)
+	if err != nil {
+		http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
+		return
+	}
+
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict: req.NumPredict,
 		stop:       req.Stop,
 		numKeep:    int32(req.NumKeep),
-		samplers:   getSamplers(req),
+		sampler:    sampler,
 		embedding:  false,
 	})
 	if err != nil {
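The removed `getSamplers` stub ignored the request entirely and always returned `sample.Greedy()`; the handler now threads temperature, top-k, top-p, min-p, and seed through `sample.NewSampler`, failing the request with a 500 if construction fails. For reference, greedy decoding is just an argmax over the logits; a hypothetical standalone illustration of what the old stub effectively did:

```go
package main

import "fmt"

// argmax mirrors the behavior of the removed greedy-only path:
// pick the highest-scoring token. Hypothetical helper, for
// illustration only.
func argmax(logits []float32) int32 {
	best := 0
	for i, v := range logits {
		if v > logits[best] {
			best = i
		}
	}
	return int32(best)
}

func main() {
	fmt.Println(argmax([]float32{0.1, 0.7, 0.2})) // prints 1
}
```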