Bruce MacDonald 2 months ago
parent
commit
64f95067ba
3 changed files with 105 additions and 62 deletions
  1. 26 16
      llama/runner/runner.go
  2. 27 24
      llama/runner/stop.go
  3. 52 22
      llama/runner/stop_test.go

+ 26 - 16
llama/runner/runner.go

@@ -50,8 +50,9 @@ type Sequence struct {
 	// inputs that have been added to a batch but not yet submitted to Decode
 	// inputs that have been added to a batch but not yet submitted to Decode
 	pendingInputs []input
 	pendingInputs []input
 
 
+	// TODO: update this comment
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []CompletionResponse
 
 
 	// input cache being used by this sequence
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 	cache *InputCacheSlot
@@ -87,6 +88,9 @@ type Sequence struct {
 
 
 	logits []float32
 	logits []float32
 
 
+	// number of logprobs to return with the completion response
+	logprobs int
+
 	// Metrics
 	// Metrics
 	startProcessingTime time.Time
 	startProcessingTime time.Time
 	startGenerationTime time.Time
 	startGenerationTime time.Time
@@ -152,7 +156,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		numPromptInputs:     len(inputs),
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
+		pendingResponses:    make([]CompletionResponse, 0),
 		responses:           make(chan CompletionResponse, 100),
 		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		embedding:           make(chan []float32, 1),
@@ -281,8 +285,11 @@ func flushPending(seq *Sequence) bool {
 	if len(seq.pendingResponses) == 0 {
 	if len(seq.pendingResponses) == 0 {
 		return true
 		return true
 	}
 	}
-	content := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	content := ""
+	for _, resp := range seq.pendingResponses {
+		content += resp.Content
+	}
+	seq.pendingResponses = []CompletionResponse{}
 
 
 	// Check if there are any partial UTF-8 characters remaining.
 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
 	// We already check and queue as we are generating but some may
@@ -362,27 +369,27 @@ func (s *Server) run(ctx context.Context) {
 	}
 	}
 }
 }
 
 
-// TokenData represents probability information for a token
-type TokenData struct {
+// TokenProbs represents probability information for a token
+type TokenProbs struct {
 	TokenID int
 	TokenID int
 	Logit   float32
 	Logit   float32
 	Prob    float32
 	Prob    float32
 	LogProb float32
 	LogProb float32
 }
 }
 
 
-// getTokenProbabilities returns sorted token probabilities for a specific token index
-func (s *Server) getTokenProbabilities(seq *Sequence) []TokenData {
+// probs returns sorted token probabilities for a specific token index
+func (s *Server) probs(seq *Sequence) []TokenProbs {
 	// Get logits for the specific token index
 	// Get logits for the specific token index
 	logits := s.lc.GetLogits()
 	logits := s.lc.GetLogits()
 	seq.logits = make([]float32, len(logits))
 	seq.logits = make([]float32, len(logits))
 	copy(seq.logits, logits)
 	copy(seq.logits, logits)
 
 
 	vocabSize := s.model.NumVocab()
 	vocabSize := s.model.NumVocab()
-	probs := make([]TokenData, vocabSize)
+	probs := make([]TokenProbs, vocabSize)
 
 
 	// Initialize token data with logits
 	// Initialize token data with logits
 	for i := 0; i < vocabSize; i++ {
 	for i := 0; i < vocabSize; i++ {
-		probs[i] = TokenData{
+		probs[i] = TokenProbs{
 			TokenID: i,
 			TokenID: i,
 			Logit:   logits[i],
 			Logit:   logits[i],
 		}
 		}
@@ -546,10 +553,9 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 
 		seq.numPredicted++
 		seq.numPredicted++
 
 
-		// TODO: only do this when flag specified
-		probs := s.getTokenProbabilities(seq)
-		for i := range 10 {
-			slog.Debug("top 10 tokens", "token", probs[i].TokenID, "prob", probs[i].Prob, "logit", probs[i].Logit, "piece", s.model.TokenToPiece(probs[i].TokenID))
+		if seq.logprobs > 0 {
+			// TODO: return selected token in logprobs always
+			// probs := s.probs(seq)
 		}
 		}
 
 
 		// if it's an end of sequence token, break
 		// if it's an end of sequence token, break
@@ -564,8 +570,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 
 		seq.inputs = []input{{token: token}}
 		seq.inputs = []input{{token: token}}
 
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		// TODO: add probs here
+		seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
+		var sequence string
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 
 		if ok, stop := findStop(sequence, seq.stop); ok {
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

+ 27 - 24
llama/runner/stop.go

@@ -29,40 +29,43 @@ func containsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
 // the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
+func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+	// Build complete string and find stop position
+	var completeStr string
+	for _, piece := range pieces {
+		completeStr += piece.Content
+	}
 
 
-	index := strings.Index(joined, stop)
-	if index == -1 {
+	stopStart := strings.Index(completeStr, stop)
+	if stopStart == -1 {
 		return pieces, false
 		return pieces, false
 	}
 	}
 
 
-	joined = joined[:index]
-
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
-	}
+	// Build result up to stop position
+	result := make([]CompletionResponse, 0)
+	accumulated := 0
 
 
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
-			break
+	truncated := false
+	for _, piece := range pieces {
+		if accumulated+len(piece.Content) <= stopStart {
+			result = append(result, piece)
+			accumulated += len(piece.Content)
+			continue
 		}
 		}
 
 
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
+		if accumulated < stopStart {
+			truncPiece := piece
+			truncPiece.Content = piece.Content[:stopStart-accumulated]
+			if len(truncPiece.Content) > 0 {
+				result = append(result, truncPiece)
+				truncated = true
+			}
 		}
 		}
-		result = append(result, joined[start:end])
-		start = end
+		break
 	}
 	}
 
 
-	return result, tokenTruncated
+	// Signal if we had to truncate the last piece
+	return result, truncated
 }
 }
 
 
 func incompleteUnicode(token string) bool {
 func incompleteUnicode(token string) bool {

+ 52 - 22
llama/runner/stop_test.go

@@ -8,44 +8,74 @@ import (
 func TestTruncateStop(t *testing.T) {
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name          string
 		name          string
-		pieces        []string
+		pieces        []CompletionResponse
 		stop          string
 		stop          string
-		expected      []string
+		expected      []CompletionResponse
 		expectedTrunc bool
 		expectedTrunc bool
 	}{
 	}{
 		{
 		{
-			name:          "Single word",
-			pieces:        []string{"hello", "world"},
-			stop:          "world",
-			expected:      []string{"hello"},
+			name: "Single word",
+			pieces: []CompletionResponse{
+				{Content: "hello"},
+				{Content: "world"},
+			},
+			stop: "world",
+			expected: []CompletionResponse{
+				{Content: "hello"},
+			},
 			expectedTrunc: false,
 			expectedTrunc: false,
 		},
 		},
 		{
 		{
-			name:          "Partial",
-			pieces:        []string{"hello", "wor"},
-			stop:          "or",
-			expected:      []string{"hello", "w"},
+			name: "Partial",
+			pieces: []CompletionResponse{
+				{Content: "hello"},
+				{Content: "wor"},
+			},
+			stop: "or",
+			expected: []CompletionResponse{
+				{Content: "hello"},
+				{Content: "w"},
+			},
 			expectedTrunc: true,
 			expectedTrunc: true,
 		},
 		},
 		{
 		{
-			name:          "Suffix",
-			pieces:        []string{"Hello", " there", "!"},
-			stop:          "!",
-			expected:      []string{"Hello", " there"},
+			name: "Suffix",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+				{Content: "!"},
+			},
+			stop: "!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+			},
 			expectedTrunc: false,
 			expectedTrunc: false,
 		},
 		},
 		{
 		{
-			name:          "Suffix partial",
-			pieces:        []string{"Hello", " the", "re!"},
-			stop:          "there!",
-			expected:      []string{"Hello", " "},
+			name: "Suffix partial",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " the"},
+				{Content: "re!"},
+			},
+			stop: "there!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " "},
+			},
 			expectedTrunc: true,
 			expectedTrunc: true,
 		},
 		},
 		{
 		{
-			name:          "Middle",
-			pieces:        []string{"hello", " wor"},
-			stop:          "llo w",
-			expected:      []string{"he"},
+			name: "Middle",
+			pieces: []CompletionResponse{
+				{Content: "hello"},
+				{Content: " wor"},
+			},
+			stop: "llo w",
+			expected: []CompletionResponse{
+				{Content: "he"},
+			},
 			expectedTrunc: true,
 			expectedTrunc: true,
 		},
 		},
 	}
 	}