Explorar el Código

llama3 conversion

Patrick Devine hace 1 año
padre
commit
c8cf0d94ed
Se han modificado 3 ficheros con 56 adiciones y 16 borrados
  1. 1 0
      convert/convert.go
  2. 54 16
      convert/llama.go
  3. 1 0
      llm/gguf.go

+ 1 - 0
convert/convert.go

@@ -93,6 +93,7 @@ type Vocab struct {
 	Tokens []string
 	Scores []float32
 	Types  []int32
+	Merges []string
 }
 
 func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {

+ 54 - 16
convert/llama.go

@@ -5,6 +5,8 @@ import (
 	"fmt"
 	"io"
 	"log/slog"
+	"os"
+	"path/filepath"
 	"regexp"
 	"strings"
 
@@ -105,12 +107,12 @@ func (m *LlamaModel) GetTensors() error {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
 			slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
-			switch l.WriterTo.(type) {
-			case torchWriterTo:
+			switch m.Format.(type) {
+			case *TorchFormat:
 				wt := l.WriterTo.(torchWriterTo)
 				wt.handler = llamaTorchLayerHandler
 				l.WriterTo = wt
-			case safetensorWriterTo:
+			case *SafetensorFormat:
 				wt := l.WriterTo.(safetensorWriterTo)
 				wt.handler = mistralLayerHandler
 				l.WriterTo = wt
@@ -123,18 +125,46 @@ func (m *LlamaModel) GetTensors() error {
 }
 
 func (m *LlamaModel) LoadVocab() error {
-	var v *Vocab
-	var err error
-
-	slog.Debug("loading vocab")
-	v, err = LoadSentencePieceTokens(m.Path, m.Params)
-	if err != nil {
-		return err
+	v := &Vocab{
+		Tokens: []string{},
+		Types:  []int32{},
+		Merges: []string{},
 	}
 
-	slog.Debug("vocab loaded")
+	tokpath := filepath.Join(m.Path, "tokenizer.json")
+	slog.Debug(fmt.Sprintf("looking for %s", tokpath))
+	if _, err := os.Stat(tokpath); !os.IsNotExist(err) {
+		t, err := newTokenizer(tokpath)
+		if err != nil {
+			return err
+		}
+
+		for _, tok := range t.Model.Tokens {
+			v.Tokens = append(v.Tokens, tok.Content)
+			var tokType int32
+			switch {
+			case tok.Special:
+				tokType = 3
+			case tok.UserDefined:
+				tokType = 4
+			default:
+				tokType = 1
+			}
+			v.Types = append(v.Types, tokType)
+		}
+		v.Merges = t.Model.Merges
+	} else {
+		slog.Debug("loading sentence piece vocab")
+		v, err = LoadSentencePieceTokens(m.Path, m.Params)
+		if err != nil {
+			return err
+		}
+
+		slog.Debug("vocab loaded")
 
+	}
 	m.Vocab = v
+
 	return nil
 }
 
@@ -147,22 +177,30 @@ func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
 		"llama.embedding_length":                 uint32(m.Params.HiddenSize),
 		"llama.block_count":                      uint32(m.Params.HiddenLayers),
 		"llama.feed_forward_length":              uint32(m.Params.IntermediateSize),
+		"llama.rope.freq_base":                   float32(m.Params.RopeFrequencyBase),
 		"llama.rope.dimension_count":             uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
 		"llama.attention.head_count":             uint32(m.Params.AttentionHeads),
 		"llama.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
 		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
-		"general.file_type":                      uint32(1),
-		"tokenizer.ggml.model":                   "llama",
+		//"general.file_type":                      uint32(1),
+		"general.file_type": uint32(2),
+		//"tokenizer.ggml.model":                   "llama",
+		"tokenizer.ggml.model": "gpt2",
 
 		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
-		"tokenizer.ggml.scores":     m.Vocab.Scores,
 		"tokenizer.ggml.token_type": m.Vocab.Types,
 
 		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
 		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
 		"tokenizer.ggml.unknown_token_id": uint32(0),
-		"tokenizer.ggml.add_bos_token":    true,
-		"tokenizer.ggml.add_eos_token":    false,
+		//"tokenizer.ggml.add_bos_token":    true,
+		//"tokenizer.ggml.add_eos_token":    false,
+	}
+
+	if len(m.Vocab.Merges) > 0 {
+		kv["tokenizer.ggml.merges"] = m.Vocab.Merges
+	} else {
+		kv["tokenizer.ggml.scores"] = m.Vocab.Scores
 	}
 
 	return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)

+ 1 - 0
llm/gguf.go

@@ -483,6 +483,7 @@ var ggufKVOrder = map[string][]string{
 		"tokenizer.ggml.model",
 		"tokenizer.ggml.tokens",
 		"tokenizer.ggml.scores",
+		"tokenizer.ggml.merges",
 		"tokenizer.ggml.token_type",
 		"tokenizer.ggml.bos_token_id",
 		"tokenizer.ggml.eos_token_id",