Browse Source

runner: enable returning more info from runner processing

Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.

A follow up change will add logprobs to the response returned from the
runner using this structure.
Bruce MacDonald 2 months ago
parent
commit
905da35468
5 changed files with 176 additions and 100 deletions
  1. 26 23
      runner/common/stop.go
  2. 69 23
      runner/common/stop_test.go
  3. 23 0
      runner/common/types.go
  4. 29 27
      runner/llamarunner/runner.go
  5. 29 27
      runner/ollamarunner/runner.go

+ 26 - 23
runner/common/stop.go

@@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
-
-	index := strings.Index(joined, stop)
-	if index == -1 {
-		return pieces, false
+func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+	var sequence string
+	for _, resp := range resps {
+		sequence += resp.Content
 	}
 
-	joined = joined[:index]
+	idx := strings.Index(sequence, stop)
+	if idx < 0 {
+		return resps, false
+	}
 
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
+	truncated := sequence[:idx]
+	if len(truncated) == 0 {
+		return nil, true
 	}
 
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
+	result := make([]CompletionResponse, 0, len(resps))
+
+	// Track position in truncated sequence
+	pos := 0
+	truncationHappened := false
+	for _, resp := range resps {
+		if pos >= len(truncated) {
 			break
 		}
 
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
+		chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
+		if len(chunk) < len(resp.Content) {
+			truncationHappened = true
+		}
+		if len(chunk) > 0 {
+			result = append(result, CompletionResponse{Content: chunk})
 		}
-		result = append(result, joined[start:end])
-		start = end
+		pos += len(resp.Content)
 	}
 
-	return result, tokenTruncated
+	return result, truncationHappened
 }
 
 func IncompleteUnicode(token string) bool {

+ 69 - 23
runner/common/stop_test.go

@@ -1,6 +1,7 @@
 package common
 
 import (
+	"fmt"
 	"reflect"
 	"testing"
 )
@@ -8,44 +9,74 @@ import (
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []string
+		pieces        []CompletionResponse
 		stop          string
-		expected      []string
+		expected      []CompletionResponse
 		expectedTrunc bool
 	}{
 		{
-			name:          "Single word",
-			pieces:        []string{"hello", "world"},
-			stop:          "world",
-			expected:      []string{"hello"},
+			name: "Single word",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: "world"},
+			},
+			stop: "world",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+			},
 			expectedTrunc: false,
 		},
 		{
-			name:          "Partial",
-			pieces:        []string{"hello", "wor"},
-			stop:          "or",
-			expected:      []string{"hello", "w"},
+			name: "Partial",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " wor"},
+			},
+			stop: "or",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " w"},
+			},
 			expectedTrunc: true,
 		},
 		{
-			name:          "Suffix",
-			pieces:        []string{"Hello", " there", "!"},
-			stop:          "!",
-			expected:      []string{"Hello", " there"},
+			name: "Suffix",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+				{Content: "!"},
+			},
+			stop: "!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+			},
 			expectedTrunc: false,
 		},
 		{
-			name:          "Suffix partial",
-			pieces:        []string{"Hello", " the", "re!"},
-			stop:          "there!",
-			expected:      []string{"Hello", " "},
+			name: "Suffix partial",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " the"},
+				{Content: "re!"},
+			},
+			stop: "there!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " "},
+			},
 			expectedTrunc: true,
 		},
 		{
-			name:          "Middle",
-			pieces:        []string{"hello", " wor"},
-			stop:          "llo w",
-			expected:      []string{"he"},
+			name: "Middle",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " wo"},
+			},
+			stop: "llo w",
+			expected: []CompletionResponse{
+				{Content: "He"},
+			},
 			expectedTrunc: true,
 		},
 	}
@@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
 			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+				t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
+					tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
 			}
 		})
 	}
 }
 
+func formatContentDiff(result, expected []CompletionResponse) string {
+	var s string
+	for i := 0; i < len(result) || i < len(expected); i++ {
+		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
+			s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
+		} else if i < len(result) && i >= len(expected) {
+			s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
+		} else if i >= len(result) && i < len(expected) {
+			s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
+		}
+	}
+	return s
+}
+
 func TestIncompleteUnicode(t *testing.T) {
 	tests := []struct {
 		name     string

+ 23 - 0
runner/common/types.go

@@ -0,0 +1,23 @@
+package common
+
+type CompletionResponse struct {
+	Content string `json:"content"`
+	Stop    bool   `json:"stop"`
+
+	Model        string  `json:"model,omitempty"`
+	Prompt       string  `json:"prompt,omitempty"`
+	StoppedLimit bool    `json:"stopped_limit,omitempty"`
+	PredictedN   int     `json:"predicted_n,omitempty"`
+	PredictedMS  float64 `json:"predicted_ms,omitempty"`
+	PromptN      int     `json:"prompt_n,omitempty"`
+	PromptMS     float64 `json:"prompt_ms,omitempty"`
+
+	Timings Timings `json:"timings"`
+}
+
+type Timings struct {
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
+}

+ 29 - 27
runner/llamarunner/runner.go

@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []common.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan string
+	responses chan common.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]common.CompletionResponse, 0),
+		responses:           make(chan common.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
+	pending := seq.pendingResponses
+	seq.pendingResponses = []common.CompletionResponse{}
+
+	for i, r := range pending {
+		if i == len(pending)-1 {
+			// Check and trim any trailing partial UTF-8 characters
+			content := r.Content
+			for !utf8.ValidString(content) {
+				content = content[:len(content)-1]
+			}
+			r.Content = content
+		}
 
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+		select {
+		case seq.responses <- r:
+			return true
+		case <-seq.quit:
+			return false
+		}
 	}
+	// no pending responses to send
+	return true
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		sequence := ""
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

+ 29 - 27
runner/ollamarunner/runner.go

@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []common.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
 	// channel to send responses over
-	responses chan string
+	responses chan common.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]common.CompletionResponse, 0),
+		responses:           make(chan common.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
@@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
 }
 
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
+	pending := seq.pendingResponses
+	seq.pendingResponses = []common.CompletionResponse{}
+
+	for i, r := range pending {
+		if i == len(pending)-1 {
+			// Check and trim any trailing partial UTF-8 characters
+			content := r.Content
+			for !utf8.ValidString(content) {
+				content = content[:len(content)-1]
+			}
+			r.Content = content
+		}
 
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+		select {
+		case seq.responses <- r:
+			return true
+		case <-seq.quit:
+			return false
+		}
 	}
+	// no pending responses to send
+	return true
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
 
 		seq.inputs = []input.Input{{Token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		sequence := ""
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)