Pārlūkot izejas kodu

WIP but got logits n stuff

ParthSareen 4 mēneši atpakaļ
vecāks
revīzija
c92d418a7c
5 mainītis faili ar 114 papildinājumiem un 47 dzēšanām
  1. 16 0
      api/types.go
  2. 13 11
      llama/llama.go
  3. 56 19
      llama/runner/runner.go
  4. 15 9
      llm/server.go
  5. 14 8
      server/routes.go

+ 16 - 0
api/types.go

@@ -80,6 +80,8 @@ type GenerateRequest struct {
 	// Options lists model-specific options. For example, temperature can be
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }
 }
 
 
 // ChatRequest describes a request sent by [Client.Chat].
 // ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +107,8 @@ type ChatRequest struct {
 
 
 	// Options lists model-specific options.
 	// Options lists model-specific options.
 	Options map[string]interface{} `json:"options"`
 	Options map[string]interface{} `json:"options"`
+
+	ReturnLogits bool `json:"return_logits,omitempty"`
 }
 }
 
 
 type Tools []Tool
 type Tools []Tool
@@ -189,6 +193,7 @@ type ChatResponse struct {
 	CreatedAt  time.Time `json:"created_at"`
 	CreatedAt  time.Time `json:"created_at"`
 	Message    Message   `json:"message"`
 	Message    Message   `json:"message"`
 	DoneReason string    `json:"done_reason,omitempty"`
 	DoneReason string    `json:"done_reason,omitempty"`
+	Logits     []float32 `json:"logits"`
 
 
 	Done bool `json:"done"`
 	Done bool `json:"done"`
 
 
@@ -204,6 +209,15 @@ type Metrics struct {
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 	EvalDuration       time.Duration `json:"eval_duration,omitempty"`
 }
 }
 
 
+type TokenLogprob struct {
+	Token   string  `json:"token"`
+	Logprob float32 `json:"logprob"`
+}
+
+type LogProbs struct {
+	TopLogprobs []TokenLogprob `json:"top_logprobs"`
+}
+
 // Options specified in [GenerateRequest].  If you add a new option here, also
 // Options specified in [GenerateRequest].  If you add a new option here, also
 // add it to the API docs.
 // add it to the API docs.
 type Options struct {
 type Options struct {
@@ -450,6 +464,8 @@ type GenerateResponse struct {
 	Context []int `json:"context,omitempty"`
 	Context []int `json:"context,omitempty"`
 
 
 	Metrics
 	Metrics
+
+	Logits []float32 `json:"logits"`
 }
 }
 
 
 // ModelDetails provides details about a model.
 // ModelDetails provides details about a model.

+ 13 - 11
llama/llama.go

@@ -260,6 +260,19 @@ func (c *Context) GetEmbeddingsIth(i int) []float32 {
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
 }
 }
 
 
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()
+	return unsafe.Slice((*float32)(logits), vocabSize)
+}
+
 type ModelParams struct {
 type ModelParams struct {
 	NumGpuLayers int
 	NumGpuLayers int
 	MainGpu      int
 	MainGpu      int
@@ -737,14 +750,3 @@ func SchemaToGrammar(schema []byte) []byte {
 	}
 	}
 	return buf[:n]
 	return buf[:n]
 }
 }
-
-// GetLogits returns the logits from the last decode operation.
-// The returned slice has length equal to the vocabulary size.
-func (c *Context) GetLogits() []float32 {
-	logits := unsafe.Pointer(C.llama_get_logits(c.c))
-	if logits == nil {
-		return nil
-	}
-
-	// Get the number of vocabulary tokens to determine array size
-	vocabSize := c.Model().NumVocab()

+ 56 - 19
llama/runner/runner.go

@@ -8,12 +8,14 @@ import (
 	"fmt"
 	"fmt"
 	"log"
 	"log"
 	"log/slog"
 	"log/slog"
+	"math"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"regexp"
 	"regexp"
 	"runtime"
 	"runtime"
+	"sort"
 	"strconv"
 	"strconv"
 	"strings"
 	"strings"
 	"sync"
 	"sync"
@@ -59,7 +61,7 @@ type Sequence struct {
 	crossAttention bool
 	crossAttention bool
 
 
 	// channel to send responses over
 	// channel to send responses over
-	responses chan string
+	responses chan CompletionResponse
 
 
 	// channel to stop decoding (such as if the remote connection is closed)
 	// channel to stop decoding (such as if the remote connection is closed)
 	quit chan bool
 	quit chan bool
@@ -88,6 +90,15 @@ type Sequence struct {
 	startGenerationTime time.Time
 	startGenerationTime time.Time
 	numDecoded          int
 	numDecoded          int
 	numPromptInputs     int
 	numPromptInputs     int
+
+	// New flag we need to add to Sequence struct
+	returnLogits bool
+
+	// Using our new GetLogits() method
+	logits []float32
+
+	// Add new channel for logits
+	logitsOut chan []float32
 }
 }
 
 
 type NewSequenceParams struct {
 type NewSequenceParams struct {
@@ -96,6 +107,7 @@ type NewSequenceParams struct {
 	numKeep        int
 	numKeep        int
 	samplingParams *llama.SamplingParams
 	samplingParams *llama.SamplingParams
 	embedding      bool
 	embedding      bool
+	returnLogits   bool
 }
 }
 
 
 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
 func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequenceParams) (*Sequence, error) {
@@ -149,13 +161,15 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 		startProcessingTime: startTime,
 		startProcessingTime: startTime,
 		numPredict:          params.numPredict,
 		numPredict:          params.numPredict,
 		pendingResponses:    make([]string, 0),
 		pendingResponses:    make([]string, 0),
-		responses:           make(chan string, 100),
+		responses:           make(chan CompletionResponse, 100),
 		quit:                make(chan bool, 1),
 		quit:                make(chan bool, 1),
 		embedding:           make(chan []float32, 1),
 		embedding:           make(chan []float32, 1),
 		samplingCtx:         sc,
 		samplingCtx:         sc,
 		embeddingOnly:       params.embedding,
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
 		numKeep:             params.numKeep,
+		returnLogits:        params.returnLogits,
+		logitsOut:           make(chan []float32, 100),
 	}, nil
 	}, nil
 }
 }
 
 
@@ -274,25 +288,36 @@ func (s *Server) allNil() bool {
 }
 }
 
 
 func flushPending(seq *Sequence) bool {
 func flushPending(seq *Sequence) bool {
-	joined := strings.Join(seq.pendingResponses, "")
-	seq.pendingResponses = []string{}
+	if len(seq.pendingResponses) == 0 {
+		return true
+	}
 
 
+	content := strings.Join(seq.pendingResponses, "")
 	// Check if there are any partial UTF-8 characters remaining.
 	// Check if there are any partial UTF-8 characters remaining.
 	// We already check and queue as we are generating but some may
 	// We already check and queue as we are generating but some may
 	// still make it here:
 	// still make it here:
 	// - Sequence is ending, e.g. generation limit has been hit
 	// - Sequence is ending, e.g. generation limit has been hit
 	// - Invalid characters in the middle of a string
 	// - Invalid characters in the middle of a string
 	// This is a stricter check to ensure we never output invalid Unicode.
 	// This is a stricter check to ensure we never output invalid Unicode.
-	for !utf8.ValidString(joined) {
-		joined = joined[:len(joined)-1]
+	for !utf8.ValidString(content) {
+		content = content[:len(content)-1]
 	}
 	}
+	seq.pendingResponses = nil
 
 
-	if len(joined) == 0 {
-		return true
+	resp := CompletionResponse{
+		Content: content,
 	}
 	}
 
 
+	// Add logits if requested and available
+	if seq.returnLogits && seq.logits != nil {
+		slog.Info("returning logits - flushPending")
+		resp.Logits = seq.logits
+		seq.logits = nil
+	}
+
+	slog.Info("returning logits - flushPending", "logits", resp.Logits[0])
 	select {
 	select {
-	case seq.responses <- joined:
+	case seq.responses <- resp:
 		return true
 		return true
 	case <-seq.quit:
 	case <-seq.quit:
 		return false
 		return false
@@ -476,7 +501,14 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
 			continue
 			continue
 		}
 		}
 
 
-		// sample a token
+		// Before sampling:
+		if seq.returnLogits { // New flag we need to add to Sequence struct
+			slog.Info("returning logits")
+			seq.logits = s.lc.GetLogits() // Using our new GetLogits() method
+
+		}
+
+		// Then sample token
 		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
 		token := seq.samplingCtx.Sample(s.lc, seq.iBatch)
 		seq.samplingCtx.Accept(token, true)
 		seq.samplingCtx.Accept(token, true)
 		piece := s.model.TokenToPiece(token)
 		piece := s.model.TokenToPiece(token)
@@ -572,10 +604,11 @@ type ImageData struct {
 }
 }
 
 
 type CompletionRequest struct {
 type CompletionRequest struct {
-	Prompt      string      `json:"prompt"`
-	Images      []ImageData `json:"image_data"`
-	Grammar     string      `json:"grammar"`
-	CachePrompt bool        `json:"cache_prompt"`
+	Prompt       string      `json:"prompt"`
+	Images       []ImageData `json:"image_data"`
+	Grammar      string      `json:"grammar"`
+	CachePrompt  bool        `json:"cache_prompt"`
+	ReturnLogits bool        `json:"return_logits"`
 
 
 	Options
 	Options
 }
 }
@@ -588,8 +621,10 @@ type Timings struct {
 }
 }
 
 
 type CompletionResponse struct {
 type CompletionResponse struct {
-	Content string `json:"content"`
-	Stop    bool   `json:"stop"`
+	Content string    `json:"content"`
+	Logits  []float32 `json:"logits,omitempty"`
+	Tokens  []string  `json:"tokens,omitempty"`
+	Stop    bool      `json:"stop"`
 
 
 	Model        string  `json:"model,omitempty"`
 	Model        string  `json:"model,omitempty"`
 	Prompt       string  `json:"prompt,omitempty"`
 	Prompt       string  `json:"prompt,omitempty"`
@@ -637,12 +672,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Seed = uint32(req.Seed)
 	samplingParams.Grammar = req.Grammar
 	samplingParams.Grammar = req.Grammar
 
 
+	slog.Info("completion request", "return_logits", req.ReturnLogits)
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 	seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
 		numPredict:     req.NumPredict,
 		numPredict:     req.NumPredict,
 		stop:           req.Stop,
 		stop:           req.Stop,
 		numKeep:        req.NumKeep,
 		numKeep:        req.NumKeep,
 		samplingParams: &samplingParams,
 		samplingParams: &samplingParams,
 		embedding:      false,
 		embedding:      false,
+		returnLogits:   req.ReturnLogits,
 	})
 	})
 	if err != nil {
 	if err != nil {
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
 		http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
@@ -691,10 +728,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			close(seq.quit)
 			close(seq.quit)
 			return
 			return
 		case content, ok := <-seq.responses:
 		case content, ok := <-seq.responses:
+			slog.Info("logits in last chan", "content", content.Logits[0])
 			if ok {
 			if ok {
-				if err := json.NewEncoder(w).Encode(&CompletionResponse{
-					Content: content,
-				}); err != nil {
+				slog.Info("content", "content", content.Content)
+				if err := json.NewEncoder(w).Encode(&content); err != nil {
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
 					close(seq.quit)
 					close(seq.quit)
 					return
 					return

+ 15 - 9
llm/server.go

@@ -642,11 +642,12 @@ type ImageData struct {
 }
 }
 
 
 type completion struct {
 type completion struct {
-	Content      string `json:"content"`
-	Model        string `json:"model"`
-	Prompt       string `json:"prompt"`
-	Stop         bool   `json:"stop"`
-	StoppedLimit bool   `json:"stopped_limit"`
+	Content      string    `json:"content"`
+	Model        string    `json:"model"`
+	Prompt       string    `json:"prompt"`
+	Stop         bool      `json:"stop"`
+	StoppedLimit bool      `json:"stopped_limit"`
+	Logits       []float32 `json:"logits,omitempty"`
 
 
 	Timings struct {
 	Timings struct {
 		PredictedN  int     `json:"predicted_n"`
 		PredictedN  int     `json:"predicted_n"`
@@ -657,10 +658,11 @@ type completion struct {
 }
 }
 
 
 type CompletionRequest struct {
 type CompletionRequest struct {
-	Prompt  string
-	Format  json.RawMessage
-	Images  []ImageData
-	Options *api.Options
+	Prompt       string
+	Format       json.RawMessage
+	Images       []ImageData
+	Options      *api.Options
+	ReturnLogits bool
 }
 }
 
 
 type CompletionResponse struct {
 type CompletionResponse struct {
@@ -671,6 +673,7 @@ type CompletionResponse struct {
 	PromptEvalDuration time.Duration
 	PromptEvalDuration time.Duration
 	EvalCount          int
 	EvalCount          int
 	EvalDuration       time.Duration
 	EvalDuration       time.Duration
+	Logits             []float32
 }
 }
 
 
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
 func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -696,6 +699,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 		"seed":              req.Options.Seed,
 		"seed":              req.Options.Seed,
 		"stop":              req.Options.Stop,
 		"stop":              req.Options.Stop,
 		"image_data":        req.Images,
 		"image_data":        req.Images,
+		"return_logits":     req.ReturnLogits,
 		"cache_prompt":      true,
 		"cache_prompt":      true,
 	}
 	}
 
 
@@ -821,6 +825,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 			if c.Content != "" {
 			if c.Content != "" {
 				fn(CompletionResponse{
 				fn(CompletionResponse{
 					Content: c.Content,
 					Content: c.Content,
+					Logits:  c.Logits,
 				})
 				})
 			}
 			}
 
 
@@ -837,6 +842,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
 					EvalCount:          c.Timings.PredictedN,
 					EvalCount:          c.Timings.PredictedN,
 					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
 					EvalDuration:       parseDurationMs(c.Timings.PredictedMS),
+					Logits:             c.Logits,
 				})
 				})
 				return nil
 				return nil
 			}
 			}

+ 14 - 8
server/routes.go

@@ -295,10 +295,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 		var sb strings.Builder
 		var sb strings.Builder
 		defer close(ch)
 		defer close(ch)
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:       prompt,
+			Images:       images,
+			Format:       req.Format,
+			Options:      opts,
+			ReturnLogits: req.ReturnLogits,
 		}, func(cr llm.CompletionResponse) {
 		}, func(cr llm.CompletionResponse) {
 			res := api.GenerateResponse{
 			res := api.GenerateResponse{
 				Model:      req.Model,
 				Model:      req.Model,
@@ -312,6 +313,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
 					EvalCount:          cr.EvalCount,
 					EvalCount:          cr.EvalCount,
 					EvalDuration:       cr.EvalDuration,
 					EvalDuration:       cr.EvalDuration,
 				},
 				},
+				Logits: cr.Logits,
 			}
 			}
 
 
 			if _, err := sb.WriteString(cr.Content); err != nil {
 			if _, err := sb.WriteString(cr.Content); err != nil {
@@ -1541,16 +1543,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
 
 
 	slog.Debug("chat request", "images", len(images), "prompt", prompt)
 	slog.Debug("chat request", "images", len(images), "prompt", prompt)
 
 
+	slog.Info("chat request", "return_logits", req.ReturnLogits)
+
 	ch := make(chan any)
 	ch := make(chan any)
 	go func() {
 	go func() {
 		defer close(ch)
 		defer close(ch)
 		var sb strings.Builder
 		var sb strings.Builder
 		var toolCallIndex int = 0
 		var toolCallIndex int = 0
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
 		if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
-			Prompt:  prompt,
-			Images:  images,
-			Format:  req.Format,
-			Options: opts,
+			Prompt:       prompt,
+			Images:       images,
+			Format:       req.Format,
+			Options:      opts,
+			ReturnLogits: true,
 		}, func(r llm.CompletionResponse) {
 		}, func(r llm.CompletionResponse) {
 			res := api.ChatResponse{
 			res := api.ChatResponse{
 				Model:      req.Model,
 				Model:      req.Model,
@@ -1558,6 +1563,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
 				Message:    api.Message{Role: "assistant", Content: r.Content},
 				Message:    api.Message{Role: "assistant", Content: r.Content},
 				Done:       r.Done,
 				Done:       r.Done,
 				DoneReason: r.DoneReason,
 				DoneReason: r.DoneReason,
+				Logits:     r.Logits,
 				Metrics: api.Metrics{
 				Metrics: api.Metrics{
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalCount:    r.PromptEvalCount,
 					PromptEvalDuration: r.PromptEvalDuration,
 					PromptEvalDuration: r.PromptEvalDuration,