Browse Source

prototype

Bruce MacDonald 2 tháng trước cách đây
mục cha
commit
fdbb0b5cfe
7 tập tin đã thay đổi với 154 bổ sung171 xóa
  1. 17 4
      api/types.go
  2. 0 12
      llama/llama.go
  3. 49 39
      llama/runner/runner.go
  4. 7 38
      llama/runner/stop.go
  5. 29 59
      llama/runner/stop_test.go
  6. 27 11
      llm/server.go
  7. 25 8
      server/routes.go

+ 17 - 4
api/types.go

@@ -77,6 +77,8 @@ type GenerateRequest struct {
 	// request, for multimodal models.
 	Images []ImageData `json:"images,omitempty"`
 
+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
@@ -103,6 +105,8 @@ type ChatRequest struct {
 	// Tools is an optional list of tools the model has access to.
 	Tools `json:"tools,omitempty"`
 
+	LogProbs int `json:"logprobs,omitempty"`
+
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 }
@@ -182,13 +186,20 @@ func (t *ToolFunction) String() string {
 	return string(bts)
 }
 
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 // ChatResponse is the response returned by [Client.Chat]. Its fields are
 // similar to [GenerateResponse].
 type ChatResponse struct {
-	Model      string    `json:"model"`
-	CreatedAt  time.Time `json:"created_at"`
-	Message    Message   `json:"message"`
-	DoneReason string    `json:"done_reason,omitempty"`
+	Model      string       `json:"model"`
+	CreatedAt  time.Time    `json:"created_at"`
+	Message    Message      `json:"message"`
+	DoneReason string       `json:"done_reason,omitempty"`
+	LogProbs   []TokenProbs `json:"logprobs,omitempty"`
 
 	Done bool `json:"done"`
 
@@ -452,6 +463,8 @@ type GenerateResponse struct {
 	// can be sent in the next request to keep a conversational memory.
 	Context []int `json:"context,omitempty"`
 
+	LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
 	Metrics
 }
 

+ 0 - 12
llama/llama.go

@@ -233,18 +233,6 @@ func (c *Context) GetLogits() []float32 {
 	return unsafe.Slice((*float32)(logits), vocabSize)
 }
 
-func (m *Model) Detokenize(tokens []int) (string, error) {
-	var text string
-	for _, token := range tokens {
-		piece := m.TokenToPiece(token)
-		if piece == "" {
-			return "", fmt.Errorf("failed to convert token %d to piece", token)
-		}
-		text += piece
-	}
-	return text, nil
-}
-
 type ModelParams struct {
 	NumGpuLayers int
 	MainGpu      int

+ 49 - 39
llama/runner/runner.go

@@ -104,6 +104,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	embedding      bool
+	logprobs       int
 }
 
 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -164,6 +165,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		logprobs:            params.logprobs,
 	}, nil
 }
 
@@ -285,37 +287,34 @@ func flushPending(seq *Sequence) bool {
 	if len(seq.pendingResponses) == 0 {
 		return true
 	}
-	content := ""
+	resps := []CompletionResponse{}
 	for _, resp := range seq.pendingResponses {
-		content += resp.Content
+		resps = append(resps, resp)
 	}
 	seq.pendingResponses = []CompletionResponse{}
 
-	// Check if there are any partial UTF-8 characters remaining.
-	// We already check and queue as we are generating but some may
-	// still make it here:
-	// - Sequence is ending, e.g. generation limit has been hit
-	// - Invalid characters in the middle of a string
-	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(content) {
-		content = content[:len(content)-1]
-	}
+	// TODO: figure out this result logic
+	result := false
+	for _, resp := range resps {
+		// Check if there are any partial UTF-8 characters remaining.
+		// We already check and queue as we are generating but some may
+		// still make it here:
+		// - Sequence is ending, e.g. generation limit has been hit
+		// - Invalid characters in the middle of a string
+		// This is a stricter check to ensure we never output invalid Unicode.
+		for !utf8.ValidString(resp.Content) {
+			resp.Content = resp.Content[:len(resp.Content)-1]
+		}
 
-	// Add logits if requested and available
-	wantLogits := true
-	if wantLogits && seq.logits != nil {
-		// resp.Logits = seq.logits
-		seq.logits = nil
+		select {
+		case seq.responses <- resp:
+			result = true
+		case <-seq.quit:
+			result = false
+		}
 	}
 
-	select {
-	case seq.responses <- CompletionResponse{
-		Content: content,
-	}:
-		return true
-	case <-seq.quit:
-		return false
-	}
+	return result
 }
 
 func (s *Server) removeSequence(seqIndex int, reason string) {
@@ -371,10 +370,11 @@ func (s *Server) run(ctx context.Context) {
 
 // TokenProbs represents probability information for a token
 type TokenProbs struct {
-	TokenID int
-	Logit   float32
-	Prob    float32
-	LogProb float32
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
 }
 
 // probs returns sorted token probabilities for a specific token index
@@ -553,9 +553,17 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 
 		seq.numPredicted++
 
+		resp := CompletionResponse{Content: piece}
+
 		if seq.logprobs > 0 {
 			// TODO: return selected token in logprobs always
-			// probs := s.probs(seq)
+			resp.LogProbs = s.probs(seq)
+			// TODO: fix this logprobs limit
+			resp.LogProbs = resp.LogProbs[:min(len(resp.LogProbs), seq.logprobs)]
+			for i := range resp.LogProbs {
+				// decode the token id to a piece
+				resp.LogProbs[i].Token = s.model.TokenToPiece(resp.LogProbs[i].TokenID)
+			}
 		}
 
 		// if it's an end of sequence token, break
@@ -571,7 +579,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		seq.inputs = []input{{token: token}}
 
 		// TODO: add probs here
-		seq.pendingResponses = append(seq.pendingResponses, CompletionResponse{Content: piece})
+		seq.pendingResponses = append(seq.pendingResponses, resp)
 		var sequence string
 		for _, r := range seq.pendingResponses {
 			sequence += r.Content
@@ -580,10 +588,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 		if ok, stop := findStop(sequence, seq.stop); ok {
 			slog.Debug("hit stop token", "pending", seq.pendingResponses, "stop", stop)
 
+			// TODO: fix this stop sequence caching
 			var tokenTruncated bool
-			origLen := len(seq.pendingResponses)
-			seq.pendingResponses, tokenTruncated = truncateStop(seq.pendingResponses, stop)
-			newLen := len(seq.pendingResponses)
+			origLen := len(sequence)
+			sequence, tokenTruncated = truncateStop(sequence, stop)
+			newLen := len(sequence)
 
 			// Update the cache based on the tokens that will be returned:
 			// - We have 1 token more than is currently in the cache because
@@ -654,6 +663,7 @@ type CompletionRequest struct {
 	Images      []ImageData `json:"image_data"`
 	Grammar     string      `json:"grammar"`
 	CachePrompt bool        `json:"cache_prompt"`
+	Logprobs    int         `json:"logprobs,omitempty"`
 
 	Options
 }
@@ -669,8 +679,10 @@ type CompletionResponse struct {
 	Content string `json:"content"`
 	Stop    bool   `json:"stop"`
 
-	Model        string  `json:"model,omitempty"`
-	Prompt       string  `json:"prompt,omitempty"`
+	Model    string       `json:"model,omitempty"`
+	Prompt   string       `json:"prompt,omitempty"`
+	LogProbs []TokenProbs `json:"logprobs,omitempty"`
+
 	StoppedLimit bool    `json:"stopped_limit,omitempty"`
 	PredictedN   int     `json:"predicted_n,omitempty"`
 	PredictedMS  float64 `json:"predicted_ms,omitempty"`
@@ -688,10 +700,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	// Set the headers to indicate streaming
-	w.Header().Set("Content-Type", "application/json")
-	w.Header().Set("Transfer-Encoding", "chunked")
-
 	flusher, ok := w.(http.Flusher)
 	if !ok {
 		http.Error(w, "Streaming not supported", http.StatusInternalServerError)
@@ -720,6 +728,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		embedding:      false,
+		logprobs:       req.Logprobs,
 	})
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -769,6 +778,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case resp, ok := <-seq.responses:
 			if ok {
+				fmt.Println("response", resp)
 				if err := json.NewEncoder(w).Encode(&resp); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)

+ 7 - 38
llama/runner/stop.go

@@ -26,46 +26,15 @@ func containsStopSuffix(sequence string, stops []string) bool {
 	return false
 }
 
-// truncateStop removes the provided stop string from pieces,
-// returning the partial pieces with stop removed, including truncating
-// the last piece if required (and signalling if this was the case)
-func truncateStop(pieces []CompletionResponse, stop string) ([]CompletionResponse, bool) {
-	// Build complete string and find stop position
-	var completeStr string
-	for _, piece := range pieces {
-		completeStr += piece.Content
+// truncateStop removes the provided stop string from sequence,
+// returning both the truncated sequence and a bool indicating if truncation occurred
+func truncateStop(sequence string, stop string) (string, bool) {
+	index := strings.Index(sequence, stop)
+	if index == -1 {
+		return sequence, false
 	}
 
-	stopStart := strings.Index(completeStr, stop)
-	if stopStart == -1 {
-		return pieces, false
-	}
-
-	// Build result up to stop position
-	result := make([]CompletionResponse, 0)
-	accumulated := 0
-
-	truncated := false
-	for _, piece := range pieces {
-		if accumulated+len(piece.Content) <= stopStart {
-			result = append(result, piece)
-			accumulated += len(piece.Content)
-			continue
-		}
-
-		if accumulated < stopStart {
-			truncPiece := piece
-			truncPiece.Content = piece.Content[:stopStart-accumulated]
-			if len(truncPiece.Content) > 0 {
-				result = append(result, truncPiece)
-				truncated = true
-			}
-		}
-		break
-	}
-
-	// Signal if we had to truncate the last piece
-	return result, truncated
+	return sequence[:index], true
 }
 
 func incompleteUnicode(token string) bool {

+ 29 - 59
llama/runner/stop_test.go

@@ -1,90 +1,60 @@
 package runner
 
 import (
-	"reflect"
 	"testing"
 )
 
 func TestTruncateStop(t *testing.T) {
 	tests := []struct {
 		name          string
-		pieces        []CompletionResponse
+		sequence      string
 		stop          string
-		expected      []CompletionResponse
+		expected      string
 		expectedTrunc bool
 	}{
 		{
-			name: "Single word",
-			pieces: []CompletionResponse{
-				{Content: "hello"},
-				{Content: "world"},
-			},
-			stop: "world",
-			expected: []CompletionResponse{
-				{Content: "hello"},
-			},
-			expectedTrunc: false,
+			name:          "Single word",
+			sequence:      "helloworld",
+			stop:          "world",
+			expected:      "hello",
+			expectedTrunc: true,
 		},
 		{
-			name: "Partial",
-			pieces: []CompletionResponse{
-				{Content: "hello"},
-				{Content: "wor"},
-			},
-			stop: "or",
-			expected: []CompletionResponse{
-				{Content: "hello"},
-				{Content: "w"},
-			},
+			name:          "Partial",
+			sequence:      "hellowor",
+			stop:          "or",
+			expected:      "hellow",
 			expectedTrunc: true,
 		},
 		{
-			name: "Suffix",
-			pieces: []CompletionResponse{
-				{Content: "Hello"},
-				{Content: " there"},
-				{Content: "!"},
-			},
-			stop: "!",
-			expected: []CompletionResponse{
-				{Content: "Hello"},
-				{Content: " there"},
-			},
-			expectedTrunc: false,
+			name:          "Suffix",
+			sequence:      "Hello there!",
+			stop:          "!",
+			expected:      "Hello there",
+			expectedTrunc: true,
 		},
 		{
-			name: "Suffix partial",
-			pieces: []CompletionResponse{
-				{Content: "Hello"},
-				{Content: " the"},
-				{Content: "re!"},
-			},
-			stop: "there!",
-			expected: []CompletionResponse{
-				{Content: "Hello"},
-				{Content: " "},
-			},
+			name:          "Middle",
+			sequence:      "hello wor",
+			stop:          "llo w",
+			expected:      "he",
 			expectedTrunc: true,
 		},
 		{
-			name: "Middle",
-			pieces: []CompletionResponse{
-				{Content: "hello"},
-				{Content: " wor"},
-			},
-			stop: "llo w",
-			expected: []CompletionResponse{
-				{Content: "he"},
-			},
-			expectedTrunc: true,
+			name:          "No stop found",
+			sequence:      "hello world",
+			stop:          "xyz",
+			expected:      "hello world",
+			expectedTrunc: false,
 		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			result, resultTrunc := truncateStop(tt.pieces, tt.stop)
-			if !reflect.DeepEqual(result, tt.expected) || resultTrunc != tt.expectedTrunc {
-				t.Errorf("truncateStop(%v, %s): have %v (%v); want %v (%v)", tt.pieces, tt.stop, result, resultTrunc, tt.expected, tt.expectedTrunc)
+			result, truncated := truncateStop(tt.sequence, tt.stop)
+			if result != tt.expected || truncated != tt.expectedTrunc {
+				t.Errorf("truncateStop(%q, %q): have %q (%v); want %q (%v)",
+					tt.sequence, tt.stop, result, truncated, tt.expected, tt.expectedTrunc)
 			}
 		})
 	}

+ 27 - 11
llm/server.go

@@ -644,12 +644,22 @@ type ImageData struct {
 	AspectRatioID int    `json:"aspect_ratio_id"`
 }
 
+// TokenProbs represents probability information for a token
+type TokenProbs struct {
+	TokenID int     `json:"id"`
+	Logit   float32 `json:"logit"`
+	Prob    float32 `json:"prob"`
+	LogProb float32 `json:"logprob"`
+	Token   string  `json:"token"`
+}
+
 type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
+	Content      string       `json:"content"`
+	Model        string       `json:"model"`
+	Prompt       string       `json:"prompt"`
+	Stop         bool         `json:"stop"`
+	StoppedLimit bool         `json:"stopped_limit"`
+	LogProbs     []TokenProbs `json:"logprobs"`
 
 	Timings struct {
 		PredictedN  int     `json:"predicted_n"`
@@ -660,14 +670,16 @@ type completion struct {
 }
 
 type CompletionRequest struct {
-	Prompt  string
-	Format  json.RawMessage
-	Images  []ImageData
-	Options *api.Options
+	Prompt   string
+	Format   json.RawMessage
+	Images   []ImageData
+	LogProbs int
+	Options  *api.Options
 }
 
 type CompletionResponse struct {
 	Content            string
+	LogProbs           []TokenProbs
 	DoneReason         string
 	Done               bool
 	PromptEvalCount    int
@@ -698,9 +710,12 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":              req.Options.Seed,
 		"stop":              req.Options.Stop,
 		"image_data":        req.Images,
+		"logprobs":          req.LogProbs,
 		"cache_prompt":      true,
 	}
 
+	fmt.Println("completion request:", request)
+
 	if len(req.Format) > 0 {
 		switch string(req.Format) {
 		case `null`, `""`:
@@ -796,7 +811,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 				continue
 			}
 
-			// slog.Debug("got line", "line", string(line))
 			evt, ok := bytes.CutPrefix(line, []byte("data: "))
 			if !ok {
 				evt = line
@@ -822,7 +836,8 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 
 			if c.Content != "" {
 				fn(CompletionResponse{
-					Content: c.Content,
+					Content:  c.Content,
+					LogProbs: c.LogProbs,
 				})
 			}
 
@@ -839,6 +854,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount:          c.Timings.PredictedN,
 					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+					LogProbs:           c.LogProbs,
 				})
 				return nil
 			}

+ 25 - 8
server/routes.go

@@ -293,11 +293,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(cr llm.CompletionResponse) {
+			fmt.Printf("banana: %#v\n", cr)
 			res := api.GenerateResponse{
 				Model:      req.Model,
 				CreatedAt:  time.Now().UTC(),
@@ -311,6 +313,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalDuration:       cr.EvalDuration,
 				},
 			}
+			for _, p := range cr.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}
 
 			if _, err := sb.WriteString(cr.Content); err != nil {
 				ch <- gin.H{"error": err.Error()}
@@ -1466,10 +1475,11 @@ func (s *Server) ChatHandler(c *gin.Context) {
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:   prompt,
+			Images:   images,
+			Format:   req.Format,
+			LogProbs: req.LogProbs,
+			Options:  opts,
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 				Model:      req.Model,
@@ -1484,6 +1494,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
 					EvalDuration:       r.EvalDuration,
 				},
 			}
+			for _, p := range r.LogProbs {
+				res.LogProbs = append(res.LogProbs, api.TokenProbs{
+					TokenID: p.TokenID,
+					LogProb: p.LogProb,
+					Token:   p.Token,
+				})
+			}
 
 			if r.Done {
 				res.TotalDuration = time.Since(checkpointStart)