
fix issues with runner

jmorganca, 11 months ago
parent commit ac090b6b71
2 changed files with 18 additions and 5 deletions
  1. llama/runner/README.md (+11, -0)
  2. llama/runner/runner.go (+7, -5)

+ 11 - 0
llama/runner/README.md

@@ -1,5 +1,7 @@
 # `runner`
 
+> Note: this is a work in progress
+
 A minimal runner for loading a model and running inference via an HTTP web server.
 
 ```
@@ -13,3 +15,12 @@ curl -X POST -H "Content-Type: application/json" -d '{"prompt": "hi"}' http://lo
 ```
 
 ### Embeddings
+
+```
+curl -X POST -H "Content-Type: application/json" -d '{"prompt": "turn me into an embedding"}' http://localhost:8080/embeddings
+```
+
+### TODO
+
+- [ ] Parallelization
+- [ ] More tests

+ 7 - 5
llama/runner/runner.go

@@ -55,7 +55,7 @@ func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
 
-func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
+func (s *Server) NewSequence(prompt string, numPredict int, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
 	tokens, err := s.lc.Model().Tokenize(prompt, false, true)
 	if err != nil {
 		panic(err)
@@ -148,8 +148,10 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
+				hitLimit := seq.numPredict > 0 && seq.numPredicted > seq.numPredict
+
 				// if past the num predict limit
-				if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
+				if hitLimit || seq.nPast > s.numCtx {
 					seq.doneReason = "limit"
 					close(seq.responses)
 					s.lc.KvCacheSeqRm(i, 0, -1)
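This hunk is the substantive fix. The old predicate `seq.numPredicted > seq.numPredict` becomes true after the first decoded token whenever no limit was set (numPredict presumably left at its Go zero value of 0 before this change), so guarding on `seq.numPredict > 0` makes zero mean "no limit". A minimal sketch contrasting the two checks, with the Sequence fields inlined as plain ints:

```go
package main

import "fmt"

// oldCheck reproduces the predicate removed in this commit; newCheck is
// the fixed one. numPredict == 0 models a request with no limit set.
func oldCheck(numPredict, numPredicted int) bool {
	return numPredicted > numPredict
}

func newCheck(numPredict, numPredicted int) bool {
	return numPredict > 0 && numPredicted > numPredict
}

func main() {
	// Unset limit: the old check stops the sequence after its first token.
	fmt.Println(oldCheck(0, 1)) // true  (sequence wrongly cut off)
	fmt.Println(newCheck(0, 1)) // false (generation continues)

	// An explicit limit of 4 behaves the same under both checks.
	fmt.Println(oldCheck(4, 5)) // true
	fmt.Println(newCheck(4, 5)) // true
}
```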
@@ -317,7 +319,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 
-	seq := s.NewSequence(req.Prompt, req.Stop, &samplingParams, false)
+	seq := s.NewSequence(req.Prompt, req.NumPredict, req.Stop, &samplingParams, false)
 
 	// TODO (jmorganca): add to sequence queue instead of
 	// failing if a slot isn't available
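For illustration, a client request exercising the new per-request limit might look like the sketch below. Both the `/completion` path and the `num_predict` field name are assumptions: neither the route registration nor the request struct's JSON tags appear in this diff.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"strings"
)

func main() {
	// Hypothetical request; "/completion" and "num_predict" are assumed
	// names mirroring req.Prompt / req.NumPredict in the handler above.
	resp, err := http.Post("http://localhost:8080/completion",
		"application/json",
		strings.NewReader(`{"prompt": "hi", "num_predict": 8}`))
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(body)) // raw response body
}
```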
@@ -368,7 +370,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 
 	w.Header().Set("Content-Type", "application/json")
 
-	seq := s.NewSequence(req.Prompt, nil, nil, true)
+	seq := s.NewSequence(req.Prompt, 0, nil, nil, true)
 
 	s.mu.Lock()
 	for i, sq := range s.seqs {
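The embeddings path passes 0 for numPredict, relying on the new guard above, and then claims a free sequence slot under `s.mu`, which is what the earlier TODO is about: if every slot is taken, the request fails instead of queueing. A self-contained sketch of that slot-claim idiom, with the server state reduced to a slice and a mutex (names here are illustrative, not from the source):

```go
package main

import (
	"fmt"
	"sync"
)

// claim takes the first free (nil) slot under mu and returns its index,
// or -1 if every slot is busy: the "failing if a slot isn't available"
// case the TODO in the diff refers to.
func claim(mu *sync.Mutex, slots []*int, seq *int) int {
	mu.Lock()
	defer mu.Unlock()
	for i, s := range slots {
		if s == nil {
			slots[i] = seq
			return i
		}
	}
	return -1
}

func main() {
	var mu sync.Mutex
	slots := make([]*int, 2) // e.g. started with -parallel 2
	a, b, c := 1, 2, 3
	fmt.Println(claim(&mu, slots, &a)) // 0
	fmt.Println(claim(&mu, slots, &b)) // 1
	fmt.Println(claim(&mu, slots, &c)) // -1: no free slot
}
```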
@@ -413,7 +415,7 @@ func main() {
 	ppath := flag.String("projector", "", "Path to projector binary file")
 	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
 	batchSize := flag.Int("batch-size", 512, "Batch size")
-	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
+	nGpuLayers := flag.Int("num-gpu", 0, "Number of layers to offload to GPU")
 	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
 	flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
 	numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")