Jelajahi Sumber

logging: add a new customer logger and trace method

This change addresses over logging with debug in the SPM tokenizer by
adding a trace level to slog.
Patrick Devine 1 bulan lalu
induk
melakukan
73a1e99f8a
2 mengubah file dengan 57 tambahan dan 14 penghapusan
  1. 40 0
      logging/log.go
  2. 17 14
      model/process_text_spm.go

+ 40 - 0
logging/log.go

@@ -0,0 +1,40 @@
+package logging
+
+import (
+	"context"
+	"log/slog"
+	"os"
+)
+
+const LevelTrace slog.Level = slog.LevelDebug - 4
+
+type Logger struct {
+	logger *slog.Logger
+}
+
+func NewLogger() *Logger {
+	handler := slog.NewTextHandler(os.Stdout, nil)
+	return &Logger{
+		logger: slog.New(handler),
+	}
+}
+
+func (l *Logger) Trace(msg string, args ...any) {
+	l.logger.Log(context.Background(), LevelTrace, msg, args...)
+}
+
+func (l *Logger) Debug(msg string, args ...any) {
+	l.logger.Debug(msg, args...)
+}
+
+func (l *Logger) Info(msg string, args ...any) {
+	l.logger.Info(msg, args...)
+}
+
+func (l *Logger) Warn(msg string, args ...any) {
+	l.logger.Warn(msg, args...)
+}
+
+func (l *Logger) Error(msg string, args ...any) {
+	l.logger.Error(msg, args...)
+}

+ 17 - 14
model/process_text_spm.go

@@ -2,15 +2,18 @@ package model
 
 import (
 	"iter"
-	"log/slog"
 	"strings"
 
 	"github.com/dlclark/regexp2"
 	queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
+
+	"github.com/ollama/ollama/logging"
 )
 
 const spmWhitespaceSep = "▁"
 
+var log = logging.NewLogger()
+
 func replaceWhitespaceBySeperator(s string) string {
 	return strings.ReplaceAll(s, " ", spmWhitespaceSep)
 }
@@ -24,7 +27,7 @@ type SentencePieceModel struct {
 var _ TextProcessor = (*SentencePieceModel)(nil)
 
 func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
-	slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
+	log.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
 
 	counter := map[int]int{}
 	var maxTokenLen int
@@ -38,7 +41,7 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
 		}
 	}
 
-	slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
+	log.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
 		"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
 		"max token len", maxTokenLen)
 
@@ -91,7 +94,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
 		}
 	}
-	slog.Debug("fragments", "frags", fragments)
+	log.Trace("fragments", "frags", fragments)
 
 	var ids []int32
 	for _, frag := range fragments {
@@ -129,7 +132,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 				}
 			}
 
-			slog.Debug("tokenizer", "merges", merges)
+			log.Trace("tokenizer", "merges", merges)
 
 			pairwise := func(a, b int) *candidate {
 				if a < 0 || b >= len(runes) {
@@ -156,7 +159,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 			pqv := pq.Values()
 			for _, v := range pqv {
 				e := v.(*candidate)
-				slog.Debug("candidate", "candidate", e)
+				log.Trace("candidate", "candidate", e)
 			}
 
 			for !pq.Empty() {
@@ -164,7 +167,7 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 				pair := v.(*candidate)
 				left, right := merges[pair.a], merges[pair.b]
 
-				slog.Debug("pair", "left", left, "right", right)
+				log.Trace("pair", "left", left, "right", right)
 				if len(left.runes) == 0 || len(right.runes) == 0 {
 					continue
 				}
@@ -189,14 +192,14 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 				}
 			}
 
-			slog.Debug("merges", "merges", merges)
+			log.Trace("merges", "merges", merges)
 
 			for _, merge := range merges {
 				if len(merge.runes) > 0 {
 					if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
 						ids = append(ids, id)
 					} else {
-						slog.Debug("missing token", "token", string(merge.runes))
+						log.Error("missing token", "token", string(merge.runes))
 					}
 				}
 			}
@@ -206,19 +209,19 @@ func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error)
 	if addSpecial && len(ids) > 0 {
 		if spm.vocab.AddBOS {
 			if ids[0] == spm.vocab.BOS {
-				slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
+				log.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
 			}
 
-			slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
+			log.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
 			ids = append([]int32{spm.vocab.BOS}, ids...)
 		}
 
 		if spm.vocab.AddEOS {
 			if ids[len(ids)-1] == spm.vocab.EOS {
-				slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
+				log.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
 			}
 
-			slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
+			log.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
 			ids = append(ids, spm.vocab.EOS)
 		}
 	}
@@ -241,6 +244,6 @@ func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
 		}
 	}
 
-	slog.Debug("decoded", "ids", ids, "text", sb.String())
+	log.Debug("decoded", "ids", ids, "text", sb.String())
 	return sb.String(), nil
 }