Browse Source

update completion responses

Bruce MacDonald 1 month ago
parent
commit
946fdd5388

+ 5 - 3
runner/common/stop.go

@@ -2,6 +2,8 @@ package common
 
 import (
 	"strings"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func FindStop(sequence string, stops []string) (bool, string) {
@@ -29,7 +31,7 @@ func ContainsStopSuffix(sequence string, stops []string) bool {
 // truncateStop removes the provided stop string from pieces,
 // returning the partial pieces with stop removed, including truncating
 // the last piece if required (and signalling if this was the case)
-func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse, bool) {
+func TruncateStop(resps []llm.CompletionResponse, stop string) ([]llm.CompletionResponse, bool) {
 	var sequence string
 	for _, resp := range resps {
 		sequence += resp.Content
@@ -45,7 +47,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse
 		return nil, true
 	}
 
-	result := make([]CompletionResponse, 0, len(resps))
+	result := make([]llm.CompletionResponse, 0, len(resps))
 
 	// Track position in truncated sequence
 	pos := 0
@@ -60,7 +62,7 @@ func TruncateStop(resps []CompletionResponse, stop string) ([]CompletionResponse
 			truncationHappened = true
 		}
 		if len(chunk) > 0 {
-			result = append(result, CompletionResponse{Content: chunk})
+			result = append(result, llm.CompletionResponse{Content: chunk})
 		}
 		pos += len(resp.Content)
 	}

+ 15 - 13
runner/common/stop_test.go

@@ -4,36 +4,38 @@ import (
 	"fmt"
 	"reflect"
 	"testing"
+
+	"github.com/ollama/ollama/llm"
 )
 
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []CompletionResponse
+		pieces        []llm.CompletionResponse
 		stop          string
-		expected      []CompletionResponse
+		expected      []llm.CompletionResponse
 		expectedTrunc bool
 	}{
 		{
 			name: "Single word",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: "world"},
 			},
 			stop: "world",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 			},
 			expectedTrunc: false,
 		},
 		{
 			name: "Partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wor"},
 			},
 			stop: "or",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " w"},
 			},
@@ -41,13 +43,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 				{Content: "!"},
 			},
 			stop: "!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " there"},
 			},
@@ -55,13 +57,13 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Suffix partial",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " the"},
 				{Content: "re!"},
 			},
 			stop: "there!",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " "},
 			},
@@ -69,12 +71,12 @@ func TestTruncateStop(t *testing.T) {
 		},
 		{
 			name: "Middle",
-			pieces: []CompletionResponse{
+			pieces: []llm.CompletionResponse{
 				{Content: "Hello"},
 				{Content: " wo"},
 			},
 			stop: "llo w",
-			expected: []CompletionResponse{
+			expected: []llm.CompletionResponse{
 				{Content: "He"},
 			},
 			expectedTrunc: true,
@@ -92,7 +94,7 @@ func TestTruncateStop(t *testing.T) {
 	}
 }
 
-func formatContentDiff(result, expected []CompletionResponse) string {
+func formatContentDiff(result, expected []llm.CompletionResponse) string {
 	var s string
 	for i := 0; i < len(result) || i < len(expected); i++ {
 		if i < len(result) && i < len(expected) && result[i].Content != expected[i].Content {

+ 0 - 23
runner/common/types.go

@@ -1,23 +0,0 @@
-package common
-
-type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
-
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
-	StoppedLimit bool    `json:"stopped_limit,omitempty"`
-	PredictedN   int     `json:"predicted_n,omitempty"`
-	PredictedMS  float64 `json:"predicted_ms,omitempty"`
-	PromptN      int     `json:"prompt_n,omitempty"`
-	PromptMS     float64 `json:"prompt_ms,omitempty"`
-
-	Timings Timings `json:"timings"`
-}
-
-type Timings struct {
-	PredictedN  int     `json:"predicted_n"`
-	PredictedMS float64 `json:"predicted_ms"`
-	PromptN     int     `json:"prompt_n"`
-	PromptMS    float64 `json:"prompt_ms"`
-}

+ 7 - 9
runner/llamarunner/runner.go

@@ -51,7 +51,7 @@ type Sequence struct {
 	pendingInputs []input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
@@ -61,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -150,8 +150,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
@@ -277,7 +277,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -496,7 +496,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.inputs = []input{{token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 		sequence := ""
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -639,9 +639,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return

+ 7 - 9
runner/ollamarunner/runner.go

@@ -53,13 +53,13 @@ type Sequence struct {
 	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
-	pendingResponses []common.CompletionResponse
+	pendingResponses []llm.CompletionResponse
 
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
 	// channel to send responses over
-	responses chan common.CompletionResponse
+	responses chan llm.CompletionResponse
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
@@ -138,8 +138,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		numPromptInputs:     len(inputs),
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
-		pendingResponses:    make([]common.CompletionResponse, 0),
-		responses:           make(chan common.CompletionResponse, 100),
+		pendingResponses:    make([]llm.CompletionResponse, 0),
+		responses:           make(chan llm.CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		sampler:             params.sampler,
@@ -289,7 +289,7 @@ func (s *Server) allNil() bool {
 
 func flushPending(seq *Sequence) bool {
 	pending := seq.pendingResponses
-	seq.pendingResponses = []common.CompletionResponse{}
+	seq.pendingResponses = []llm.CompletionResponse{}
 
 	for i, r := range pending {
 		if i == len(pending)-1 {
@@ -483,7 +483,7 @@ func (s *Server) processBatch() error {
 
 		seq.inputs = []input.Input{{Token: token}}
 
-		seq.pendingResponses = append(seq.pendingResponses, common.CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, llm.CompletionResponse{Content: piece})
 		sequence := ""
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -625,9 +625,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
-				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
-					Content: content,
-				}); err != nil {
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					return