Patrick Devine 2 months ago
parent
commit
035e69799e
2 changed files with 13 additions and 24 deletions
  1. 4 4
      ml/backend/ggml/ggml.go
  2. 9 20
      model/process_text_spm.go

+ 4 - 4
ml/backend/ggml/ggml.go

@@ -596,9 +596,9 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
 }
 
 const (
-	ropeTypeNorm C.int = 0
-	ropeTypeNeox C.int = 2
-	ropeTypeMrope C.int = 8
+	ropeTypeNorm   C.int = 0
+	ropeTypeNeox   C.int = 2
+	ropeTypeMrope  C.int = 8
 	ropeTypeVision C.int = 24
 )
 
@@ -617,7 +617,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
 			ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
 			C.int(ropeDim),
 			C.int(ropeType),
-			131072,       // YaRN n_ctx_train
+			131072, // YaRN n_ctx_train
 			C.float(ropeBase),
 			C.float(ropeScale),
 			0.,  // YaRN ext_factor

+ 9 - 20
model/process_text_spm.go

@@ -1,11 +1,9 @@
 package model
 
 import (
-	"fmt"
 	"iter"
 	"log/slog"
 	"strings"
-	//"unicode/utf8"
 
 	"github.com/dlclark/regexp2"
 	queue "github.com/emirpasic/gods/queues/priorityqueue"
@@ -24,9 +22,7 @@ type SentencePieceModel struct {
 }
 
 func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
-	fmt.Printf("Tokens (%d): %5s %5s %5s ...\n", len(vocab.Values), vocab.Values[0], vocab.Values[1], vocab.Values[2])
-	fmt.Printf("Scores (%d): %0.3f %0.3f %0.3f ...\n", len(vocab.Scores), vocab.Scores[0], vocab.Scores[1], vocab.Scores[2])
-	fmt.Printf("Types  (%d): %5d %5d %5d ...\n", len(vocab.Types), vocab.Types[0], vocab.Types[1], vocab.Types[2])
+	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:3], "scores", vocab.Scores[:3], "types", vocab.Types[:3])
 
 	counter := map[int]int{}
 	var maxTokenLen int
@@ -41,13 +37,9 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
 		}
 	}
 
-	fmt.Printf("Normal: %d\n", counter[TOKEN_TYPE_NORMAL])
-	fmt.Printf("Unknown: %d\n", counter[TOKEN_TYPE_UNKNOWN])
-	fmt.Printf("Control: %d\n", counter[TOKEN_TYPE_CONTROL])
-	fmt.Printf("User Defined: %d\n", counter[TOKEN_TYPE_USER_DEFINED])
-	fmt.Printf("Unused: %d\n", counter[TOKEN_TYPE_UNUSED])
-	fmt.Printf("Byte: %d\n", counter[TOKEN_TYPE_BYTE])
-	fmt.Printf("Max token len: %d\n", maxTokenLen)
+	slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
+		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
+		"max token len", maxTokenLen)
 
 	return SentencePieceModel{
 		maxTokenLen: maxTokenLen,
@@ -98,7 +90,7 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
-	fmt.Printf("frags = %#v\n", fragments)
+	slog.Debug("fragments", "frags", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {
@@ -135,8 +127,6 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
 					runes: []rune{runes[r]},
 				}
 			}
-			fmt.Printf("remaining runes = %#v\n", runes)
-			fmt.Printf("merges = %#v\n", merges)
 
 			pairwise := func(a, b int) *candidate {
 				if a < 0 || b >= len(runes) {
@@ -144,7 +134,6 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
 				}
 
 				left, right := string(merges[a].runes), string(merges[b].runes)
-				fmt.Printf("looking up '%s'\n", left+right)
 				if id := spm.vocab.Encode(left + right); id >= 0 {
 					return &candidate{
 						a:      a,
@@ -165,7 +154,7 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
 			pqv := pq.Values()
 			for _, v := range pqv {
 				e := v.(*candidate)
-				fmt.Printf("candidate = %#v\n", e)
+				slog.Debug("candidate", "candidate", e)
 			}
 
 			for !pq.Empty() {
@@ -193,21 +182,21 @@ func (spm SentencePieceModel) Encode(s string) ([]int32, error) {
 				}
 			}
 
-			fmt.Printf("merges = %#v\n", merges)
+			slog.Debug("merges", "merges", merges)
 
 			for _, merge := range merges {
 				if len(merge.runes) > 0 {
 					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
 						ids = append(ids, id)
 					} else {
-						fmt.Printf("!!! missing token for '%s'\n", string(merge.runes))
+						slog.Debug("missing token", "token", string(merge.runes))
 					}
 				}
 			}
 		}
 
 	}
-	fmt.Printf("tokens = %#v\n", ids)
+	slog.Debug("encoded", "ids", ids)
 
 	return ids, nil
 }