瀏覽代碼

runner: enable returning more info from runner processing

Currently we return only the text predicted from the LLM. This was nice in
that it was simple, but there may be other info we want to know from the
processing. This change adds the ability to return more information from the
runner than just the text predicted.

A follow up change will add logprobs to the response returned from the
runner using this structure.
Bruce MacDonald 2 月之前
父節點
當前提交
905da35468
共有 5 個文件被更改,包括 176 次插入100 次删除
  1. 26 23
      runner/common/stop.go
  2. 69 23
      runner/common/stop_test.go
  3. 23 0
      runner/common/types.go
  4. 29 27
      runner/llamarunner/runner.go
  5. 29 27
      runner/ollamarunner/runner.go

+ 26 - 23
runner/common/stop.go

@@ -29,40 +29,43 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(pieces []string, stop string) ([]string, bool) {
-	joined := strings.Join(pieces, "")
-
-	index := strings.Index(joined, stop)
-	if index == -1 {
-		return pieces, false
+func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+	var sequence string
+	for _, resp := range resps {
+		sequence += resp.Content
 	}
 	}
 
 
-	joined = joined[:index]
+	idx := strings.Index(sequence, stop)
+	if idx < 0 {
+		return resps, false
+	}
 
 
-	// Split truncated string back into pieces of original lengths
-	lengths := make([]int, len(pieces))
-	for i, piece := range pieces {
-		lengths[i] = len(piece)
+	truncated := sequence[:idx]
+	if len(truncated) == 0 {
+		return nil, true
 	}
 	}
 
 
-	var result []string
-	tokenTruncated := false
-	start := 0
-	for _, length := range lengths {
-		if start >= len(joined) {
+	result := make([]CompletionResponse, 0, len(resps))
+
+	// Track position in truncated sequence
+	pos := 0
+	truncationHappened := false
+	for _, resp := range resps {
+		if pos >= len(truncated) {
 			break
 			break
 		}
 		}
 
 
-		end := start + length
-		if end > len(joined) {
-			end = len(joined)
-			tokenTruncated = true
+		chunk := truncated[pos:min(pos+len(resp.Content), len(truncated))]
+		if len(chunk) < len(resp.Content) {
+			truncationHappened = true
+		}
+		if len(chunk) > 0 {
+			result = append(result, CompletionResponse{Content: chunk})
 		}
 		}
-		result = append(result, joined[start:end])
-		start = end
+		pos += len(resp.Content)
 	}
 	}
 
 
-	return result, tokenTruncated
+	return result, truncationHappened
 }
 }
 
 
 func IncompleteUnicode(token string) bool {
 func IncompleteUnicode(token string) bool {

+ 69 - 23
runner/common/stop_test.go

@@ -1,6 +1,7 @@
 package common
 package common
 
 
 import (
 import (
+	"fmt"
 	"reflect"
 	"reflect"
 	"testing"
 	"testing"
 )
 )
@@ -8,44 +9,74 @@ import (
 func TestTruncateStop(t *testing.T) {
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name          string
 		name          string
-		pieces        []string
+		pieces        []CompletionResponse
 		stop          string
 		stop          string
-		expected      []string
+		expected      []CompletionResponse
 		expectedTrunc bool
 		expectedTrunc bool
 	}{
 	}{
 		{
 		{
-			name:          "Single word",
-			pieces:        []string{"hello", "world"},
-			stop:          "world",
-			expected:      []string{"hello"},
+			name: "Single word",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: "world"},
+			},
+			stop: "world",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+			},
 			expectedTrunc: false,
 			expectedTrunc: false,
 		},
 		},
 		{
 		{
-			name:          "Partial",
-			pieces:        []string{"hello", "wor"},
-			stop:          "or",
-			expected:      []string{"hello", "w"},
+			name: "Partial",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " wor"},
+			},
+			stop: "or",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " w"},
+			},
 			expectedTrunc: true,
 			expectedTrunc: true,
 		},
 		},
 		{
 		{
-			name:          "Suffix",
-			pieces:        []string{"Hello", " there", "!"},
-			stop:          "!",
-			expected:      []string{"Hello", " there"},
+			name: "Suffix",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+				{Content: "!"},
+			},
+			stop: "!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " there"},
+			},
 			expectedTrunc: false,
 			expectedTrunc: false,
 		},
 		},
 		{
 		{
-			name:          "Suffix partial",
-			pieces:        []string{"Hello", " the", "re!"},
-			stop:          "there!",
-			expected:      []string{"Hello", " "},
+			name: "Suffix partial",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " the"},
+				{Content: "re!"},
+			},
+			stop: "there!",
+			expected: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " "},
+			},
 			expectedTrunc: true,
 			expectedTrunc: true,
 		},
 		},
 		{
 		{
-			name:          "Middle",
-			pieces:        []string{"hello", " wor"},
-			stop:          "llo w",
-			expected:      []string{"he"},
+			name: "Middle",
+			pieces: []CompletionResponse{
+				{Content: "Hello"},
+				{Content: " wo"},
+			},
+			stop: "llo w",
+			expected: []CompletionResponse{
+				{Content: "He"},
+			},
 			expectedTrunc: true,
 			expectedTrunc: true,
 		},
 		},
 	}
 	}
@@ -54,12 +85,27 @@ func TestTruncateStop(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 		t.Run(tt.name, func(t *testing.T) {
 			result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
 			result, resultTrunc := TruncateStop(tt.pieces, tt.stop)
 			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
 			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+				t.Errorf("truncateStop(%v, %v):\n%shave truncated %v\nwant truncated %v",
+					tt.pieces, tt.stop, formatContentDiff(result, tt.expected), resultTrunc, tt.expectedTrunc)
 			}
 			}
 		})
 		})
 	}
 	}
 }
 }
 
 
+func formatContentDiff(result, expected []CompletionResponse) string {
+	var s string
+	for i := 0; i < len(result) || i < len(expected); i++ {
+		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {
+			s += fmt.Sprintf("[%d] %q vs %q\n", i, result[i].Content, expected[i].Content)
+		} else if i < len(result) && i >= len(expected) {
+			s += fmt.Sprintf("[%d] extra %q\n", i, result[i].Content)
+		} else if i >= len(result) && i < len(expected) {
+			s += fmt.Sprintf("[%d] missing %q\n", i, expected[i].Content)
+		}
+	}
+	return s
+}
+
 func TestIncompleteUnicode(t *testing.T) {
 func TestIncompleteUnicode(t *testing.T) {
 	tests := []struct {
 	tests := []struct {
 		name     string
 		name     string

+ 23 - 0
runner/common/types.go

@@ -0,0 +1,23 @@
+package common
+
+type CompletionResponse struct {
+	Content string `json:"content"`
+	Stop    bool   `json:"stop"`
+
+	Model        string  `json:"model,omitempty"`
+	Prompt       string  `json:"prompt,omitempty"`
+	StoppedLimit bool    `json:"stopped_limit,omitempty"`
+	PredictedN   int     `json:"predicted_n,omitempty"`
+	PredictedMS  float64 `json:"predicted_ms,omitempty"`
+	PromptN      int     `json:"prompt_n,omitempty"`
+	PromptMS     float64 `json:"prompt_ms,omitempty"`
+
+	Timings Timings `json:"timings"`
+}
+
+type Timings struct {
+	PredictedN  int     `json:"predicted_n"`
+	PredictedMS float64 `json:"predicted_ms"`
+	PromptN     int     `json:"prompt_n"`
+	PromptMS    float64 `json:"prompt_ms"`
+}

+ 29 - 27
runner/llamarunner/runner.go

@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 	pendingInputs []input
 
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []common.CompletionResponse
 
 
 	// input cache being used by this sequence
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 	crossAttention bool
 
 
 	// channel to send responses over
 	// channel to send responses over
-	responses chan string
+	responses chan common.CompletionResponse
 
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
 	quit chan bool
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]common.CompletionResponse, 0),
+		responses:           make(chan common.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		samplingCtx:         sc,
@@ -276,29 +276,28 @@ func (s *Server) allNil() bool {
 }
 }
 
 
 func flushPending(seq *Sequence) bool {
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
+	pending := seq.pendingResponses
+	seq.pendingResponses = []common.CompletionResponse{}
+
+	for i, r := range pending {
+		if i == len(pending)-1 {
+			// Check and trim any trailing partial UTF-8 characters
+			content := r.Content
+			for !utf8.ValidString(content) {
+				content = content[:len(content)-1]
+			}
+			r.Content = content
+		}
 
 
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+		select {
+		case seq.responses <- r:
+			return true
+		case <-seq.quit:
+			return false
+		}
 	}
 	}
+	// no pending responses to send
+	return true
 }
 }
 
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -497,8 +496,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 
 		seq.inputs = []input{{token: token}}
 		seq.inputs = []input{{token: token}}
 
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		sequence := ""
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)

+ 29 - 27
runner/ollamarunner/runner.go

@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 	pendingInputs []input.Input
 
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []string
+	pendingResponses []common.CompletionResponse
 
 
 	// input cache being used by this sequence
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 	cache *InputCacheSlot
 
 
 	// channel to send responses over
 	// channel to send responses over
-	responses chan string
+	responses chan common.CompletionResponse
 
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		pendingResponses:    make([]common.CompletionResponse, 0),
+		responses:           make(chan common.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
 		sampler:             params.sampler,
@@ -288,29 +288,28 @@ func (s *Server) allNil() bool {
 }
 }
 
 
 func flushPending(seq *Sequence) bool {
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
-
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
-	}
-
-	if len(joined) == 0 {
-		return true
-	}
+	pending := seq.pendingResponses
+	seq.pendingResponses = []common.CompletionResponse{}
+
+	for i, r := range pending {
+		if i == len(pending)-1 {
+			// Check and trim any trailing partial UTF-8 characters
+			content := r.Content
+			for !utf8.ValidString(content) {
+				content = content[:len(content)-1]
+			}
+			r.Content = content
+		}
 
 
-	select {
-	case seq.responses <- joined:
-		return true
-	case <-seq.quit:
-		return false
+		select {
+		case seq.responses <- r:
+			return true
+		case <-seq.quit:
+			return false
+		}
 	}
 	}
+	// no pending responses to send
+	return true
 }
 }
 
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -484,8 +483,11 @@ func (s *Server) processBatch() error {
 
 
 		seq.inputs = []input.Input{{Token: token}}
 		seq.inputs = []input.Input{{Token: token}}
 
 
-		seq.pendingResponses = append(seq.pendingResponses, piece)
-		sequence := strings.Join(seq.pendingResponses, "")
+		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		sequence := ""
+		for _, r := range seq.pendingResponses {
+			sequence += r.Content
+		}
 
 
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 		if ok, stop := common.FindStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)