ParthSareen 4 mesiacov pred
rodič
commit
d7e7e6a01e
2 zmenil súbory, kde vykonal 84 pridanie a 0 odobranie
  1. 11 0
      llama/llama.go
  2. 73 0
      llama/runner/runner.go

+ 11 - 0
llama/llama.go

@@ -737,3 +737,14 @@ func SchemaToGrammar(schema []byte) []byte {
 	}
 	}
 	return buf[:n]
 	return buf[:n]
 }
 }
+
+// GetLogits returns the logits from the last decode operation.
+// The returned slice has length equal to the vocabulary size.
+func (c *Context) GetLogits() []float32 {
+	logits := unsafe.Pointer(C.llama_get_logits(c.c))
+	if logits == nil {
+		return nil
+	}
+
+	// Get the number of vocabulary tokens to determine array size
+	vocabSize := c.Model().NumVocab()

+ 73 - 0
llama/runner/runner.go

@@ -1003,3 +1003,76 @@ func Execute(args []string) error {
 	cancel()
 	cancel()
 	return nil
 	return nil
 }
 }
+
+// Helper function to get top K logits and convert to log probabilities
+func getTopLogits(logits []float32, k int, model *llama.Model) []api.LogProbs {
+	if k <= 0 {
+		return nil
+	}
+
+	// Convert logits to probabilities using softmax
+	probs := softmax(logits)
+
+	// Create slice of index/probability pairs
+	pairs := make([]struct {
+		token int
+		prob  float32
+	}, len(probs))
+
+	for i, p := range probs {
+		pairs[i] = struct {
+			token int
+			prob  float32
+		}{i, p}
+	}
+
+	// Sort by probability (descending)
+	sort.Slice(pairs, func(i, j int) bool {
+		return pairs[i].prob > pairs[j].prob
+	})
+
+	// Take top K
+	k = min(k, len(pairs))
+	result := make([]api.LogProbs, k)
+
+	for i := 0; i < k; i++ {
+		result[i] = api.LogProbs{
+			TopLogprobs: []api.TokenLogprob{
+				{
+					Token:   model.TokenToPiece(pairs[i].token),
+					Logprob: float32(math.Log(float64(pairs[i].prob))),
+				},
+			},
+		}
+	}
+
+	return result
+}
+
+// Helper function to compute softmax
+func softmax(logits []float32) []float32 {
+	probs := make([]float32, len(logits))
+
+	// Find max for numerical stability
+	max := float32(math.Inf(-1))
+	for _, l := range logits {
+		if l > max {
+			max = l
+		}
+	}
+
+	// Compute exp(x - max) and sum
+	sum := float32(0)
+	for i, l := range logits {
+		ex := float32(math.Exp(float64(l - max)))
+		probs[i] = ex
+		sum += ex
+	}
+
+	// Normalize
+	for i := range probs {
+		probs[i] /= sum
+	}
+
+	return probs
+}