jmorganca 11 mēneši atpakaļ
vecāks
revīzija
7d0a452938
2 mainītis faili ar 11 papildinājumiem un 4 dzēšanām
  1. 2 2
      llama/README.md
  2. 9 2
      llama/runner/runner.go

+ 2 - 2
llama/README.md

@@ -10,9 +10,9 @@ Supported:
 - [x] Windows CUDA
 - [x] Windows ROCm
 - [x] Linux CUDA
-- [ ] Linux ROCm
+- [x] Linux ROCm
 - [x] Llava
-- [ ] Parallel Requests
+- [x] Parallel Requests
 
 Extra build steps are required for CUDA and ROCm on Windows since `nvcc` and `hipcc` both require using msvc as the host compiler. For these small dlls are created:
 

+ 9 - 2
llama/runner/runner.go

@@ -23,6 +23,9 @@ type Sequence struct {
 	// number of tokens evaluated
 	nPast int
 
+	// number of tokens predicted so far
+	numPredicted int
+
 	// tokens left to evaluate
 	tokens []int
 
@@ -47,6 +50,7 @@ type Sequence struct {
 }
 
 // prompt returns true if the prompt is still being processed
+// TODO (jmorganca): clean up this logic
 func (s *Sequence) prompt() bool {
 	return s.nPast < len(s.tokens)-1
 }
@@ -203,8 +207,8 @@ func (s *Server) run(ctx context.Context) {
 					continue
 				}
 
-				// we've reached the context limit
-				if seq.nPast > s.numCtx {
+				// if past the num predict limit
+				if seq.numPredicted > seq.numPredict || seq.nPast > s.numCtx {
 					seq.doneReason = "limit"
 					close(seq.responses)
 					s.lc.KvCacheSeqRm(i, 0, -1)
@@ -269,6 +273,9 @@ func (s *Server) run(ctx context.Context) {
 
 
 				seq.samplingCtx.Accept(s.lc, token, true)
 				seq.samplingCtx.Accept(s.lc, token, true)
 				piece := s.model.TokenToPiece(token)
 				piece := s.model.TokenToPiece(token)
+
+				seq.numPredicted++
+
 				slog.Info("sampled", "piece", piece)
 				slog.Info("sampled", "piece", piece)
 
 
 				// if it's an end of sequence token, break
 				// if it's an end of sequence token, break