@@ -55,7 +55,7 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, false, true)
 	if err != nil {
 		panic(err)
@@ -148,8 +148,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
+
 				// if past the num predict limit
-				if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
+				if hitLimit || seq.nPast > s.numCtx {
 					seq.doneReason = "limit"
 					close(seq.responses)
 					s.lc.KvCacheSeqRm(i, 0, -1)
@@ -317,7 +319,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
@@ -368,7 +370,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Type", "application/json")
 
-	seq := s.NewSequence(req.Prompt, nil, nil, true)
+	seq := s.NewSequence(req.Prompt, 0, nil, nil, true)
 
 	s.mu.Lock()
 	for i, sq := range s.seqs {
@@ -413,7 +415,7 @@ func main() {
 	ppath := flag.String("projector", "", "Path to projector binary file")
 	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
 	batchSize := flag.Int("batch-size", 512, "Batch size")
-	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
+	nGpuLayers := flag.Int("num-gpu", 0, "Number of layers to offload to GPU")
 	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
 	flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
 	numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")