Sfoglia il codice sorgente

Simplify model conversion (#3422)

Patrick Devine 1 anno fa
parent
commit
3b6a9154dd
4 ha cambiato i file con 364 aggiunte e 249 eliminazioni
  1. 48 235
      convert/convert.go
  2. 136 0
      convert/gemma.go
  3. 174 0
      convert/mistral.go
  4. 6 14
      server/images.go

+ 48 - 235
convert/convert.go

@@ -12,12 +12,9 @@ import (
 	"path/filepath"
 	"regexp"
 	"slices"
-	"strings"
 
 	"github.com/d4l3k/go-bfloat16"
 	"github.com/mitchellh/mapstructure"
-	"github.com/pdevine/tensor"
-	"github.com/pdevine/tensor/native"
 	"github.com/x448/float16"
 	"google.golang.org/protobuf/proto"
 
@@ -55,6 +52,20 @@ type MetaData struct {
 	Offsets []int  `mapstructure:"data_offsets"`
 }
 
+type ModelArch interface {
+	GetTensors() error
+	LoadVocab() error
+	WriteGGUF() (string, error)
+}
+
+type ModelData struct {
+	Path    string
+	Name    string
+	Params  *Params
+	Vocab   *Vocab
+	Tensors []llm.Tensor
+}
+
 func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
 	f, err := os.Open(fn)
 	if err != nil {
@@ -132,15 +143,13 @@ func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, ui
 		}
 
 		t.WriterTo = safetensorWriterTo{
-			t:           &t,
-			params:      params,
-			bo:          params.ByteOrder,
-			headCount:   uint32(params.AttentionHeads),
-			headCountKV: uint32(params.KeyValHeads),
-			filename:    fn,
-			start:       uint64(data.Offsets[0]),
-			end:         uint64(data.Offsets[1]),
-			padding:     8 + jsonSize,
+			t:        &t,
+			params:   params,
+			bo:       params.ByteOrder,
+			filename: fn,
+			start:    uint64(data.Offsets[0]),
+			end:      uint64(data.Offsets[1]),
+			padding:  8 + jsonSize,
 		}
 
 		slog.Debug(fmt.Sprintf("%v", t))
@@ -198,7 +207,7 @@ type Vocab struct {
 	Types  []int32
 }
 
-func LoadTokens(dirpath string, params *Params) (*Vocab, error) {
+func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
 	slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
 	in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
 	if err != nil {
@@ -278,8 +287,8 @@ func LoadTokens(dirpath string, params *Params) (*Vocab, error) {
 	}
 	slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
 
-	if params.VocabSize > len(v.Tokens) {
-		missingTokens := params.VocabSize - len(v.Tokens)
+	if vocabSize > len(v.Tokens) {
+		missingTokens := vocabSize - len(v.Tokens)
 		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
 		for cnt := 0; cnt < missingTokens; cnt++ {
 			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
@@ -327,77 +336,16 @@ func GetTensorName(n string) (string, error) {
 type safetensorWriterTo struct {
 	t *llm.Tensor
 
-	params      *Params
-	bo          ByteOrder
-	headCount   uint32
-	headCountKV uint32
+	params *Params
+	bo     ByteOrder
 
 	filename string
 
 	start, end, padding uint64
-}
-
-func (r safetensorWriterTo) addOnes(data []float32) ([]float32, error) {
-	n := tensor.New(tensor.WithShape(int(r.t.Shape[0])), tensor.WithBacking(data))
-	ones := tensor.Ones(tensor.Float32, int(r.t.Shape[0]))
-
-	var err error
-	n, err = n.Add(ones)
-	if err != nil {
-		return []float32{}, err
-	}
-
-	newN, err := native.SelectF32(n, 0)
-	if err != nil {
-		return []float32{}, err
-	}
-
-	var fullTensor []float32
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
-	}
-
-	return fullTensor, nil
-}
-
-func (r safetensorWriterTo) repack(data []uint16, heads int) ([]uint16, error) {
-	n := tensor.New(tensor.WithShape(int(r.t.Shape[0]), int(r.t.Shape[1])), tensor.WithBacking(data))
-	origShape := n.Shape().Clone()
-
-	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
-	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
-		return nil, err
-	}
-
-	if err := n.T(0, 2, 1, 3); err != nil {
-		return nil, err
-	}
-
-	if err := n.Reshape(origShape...); err != nil {
-		return nil, err
-	}
-
-	if err := n.Transpose(); err != nil {
-		return nil, err
-	}
-	newN, err := native.SelectU16(n, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	var fullTensor []uint16
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
-	}
-	return fullTensor, nil
+	handler             func(w io.Writer, r safetensorWriterTo, f *os.File) error
 }
 
 func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
-	arch, err := getArchFromParams(r.params)
-	if err != nil {
-		return 0, err
-	}
-
 	f, err := os.Open(r.filename)
 	if err != nil {
 		return 0, err
@@ -408,83 +356,9 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
 		return 0, err
 	}
 
-	switch arch {
-	case "llama":
-
-		pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
-		re, err := regexp.Compile(pattern)
-		if err != nil {
-			return 0, err
-		}
-
-		matches := re.FindAllStringSubmatch(r.t.Name, -1)
-		if len(matches) > 0 {
-			layerSize := r.end - r.start
-
-			var err error
-			tData := make([]uint16, layerSize/2)
-			if err = binary.Read(f, r.bo, tData); err != nil {
-				return 0, err
-			}
-
-			layerType := matches[0][re.SubexpIndex("layer")]
-			var heads uint32
-			switch layerType {
-			case "q":
-				heads = r.headCount
-			case "k":
-				heads = r.headCountKV
-				if heads == 0 {
-					heads = r.headCount
-				}
-			}
-
-			tData, err = r.repack(tData, int(heads))
-			if err != nil {
-				return 0, err
-			}
-
-			var buf []byte
-			for _, n := range tData {
-				buf = r.bo.AppendUint16(buf, n)
-			}
-
-			tempBuf := make([]uint16, len(tData))
-			tDataF32 := bfloat16.DecodeFloat32(buf)
-			for cnt, v := range tDataF32 {
-				tDataF16 := float16.Fromfloat32(v)
-				tempBuf[cnt] = uint16(tDataF16)
-			}
-
-			if err = binary.Write(w, r.bo, tempBuf); err != nil {
-				return 0, err
-			}
-
-			return 0, nil
-		}
-
-	case "gemma":
-		if strings.HasSuffix(r.t.Name, "norm.weight") {
-			slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
-
-			data := make([]byte, r.end-r.start)
-			if err = binary.Read(f, r.bo, data); err != nil {
-				return 0, err
-			}
-
-			tDataF32 := bfloat16.DecodeFloat32(data)
-
-			var err error
-			tDataF32, err = r.addOnes(tDataF32)
-			if err != nil {
-				return 0, err
-			}
-
-			if err := binary.Write(w, r.bo, tDataF32); err != nil {
-				return 0, err
-			}
-			return 0, nil
-		}
+	// use the handler if one is present
+	if r.handler != nil {
+		return 0, r.handler(w, r, f)
 	}
 
 	remaining := r.end - r.start
@@ -529,93 +403,32 @@ func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
 	return 0, nil
 }
 
-func getArchFromParams(params *Params) (string, error) {
-	var arch string
+func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
 	switch len(params.Architectures) {
 	case 0:
-		return "", fmt.Errorf("No architecture specified to convert")
+		return nil, fmt.Errorf("No architecture specified to convert")
 	case 1:
 		switch params.Architectures[0] {
 		case "MistralForCausalLM":
-			arch = "llama"
+			return &MistralModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+				},
+			}, nil
 		case "GemmaForCausalLM":
-			arch = "gemma"
+			return &GemmaModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+				},
+			}, nil
 		default:
-			return "", fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
+			return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
 		}
-	default:
-		return "", fmt.Errorf("Multimodal models are not yet supported")
-	}
-
-	return arch, nil
-}
-
-func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
-	arch, err := getArchFromParams(params)
-	if err != nil {
-		return "", err
-	}
-
-	kv := llm.KV{
-		"general.architecture": arch,
-		"general.name":         name,
-	}
-
-	switch arch {
-	case "llama":
-		kv["llama.context_length"] = uint32(params.ContextSize)
-		kv["llama.embedding_length"] = uint32(params.HiddenSize)
-		kv["llama.block_count"] = uint32(params.HiddenLayers)
-		kv["llama.feed_forward_length"] = uint32(params.IntermediateSize)
-		kv["llama.rope.dimension_count"] = uint32(params.HiddenSize / params.AttentionHeads)
-		slog.Debug(fmt.Sprintf("rope dim count = %d", kv["llama.rope.dimension_count"]))
-		kv["llama.attention.head_count"] = uint32(params.AttentionHeads)
-		kv["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
-		kv["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
-		kv["llama.rope.freq_base"] = float32(params.RopeFreqBase)
-	case "gemma":
-		kv["gemma.context_length"] = uint32(params.ContextSize)
-		kv["gemma.embedding_length"] = uint32(params.HiddenSize)
-		kv["gemma.block_count"] = uint32(params.HiddenLayers)
-		kv["gemma.feed_forward_length"] = uint32(params.IntermediateSize)
-		kv["gemma.attention.head_count"] = uint32(params.AttentionHeads)
-		kv["gemma.attention.head_count_kv"] = uint32(params.KeyValHeads)
-		kv["gemma.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
-		kv["gemma.attention.key_length"] = uint32(params.HeadDimension)
-		kv["gemma.attention.value_length"] = uint32(params.HeadDimension)
-	}
-
-	kv["general.file_type"] = uint32(1)
-	kv["tokenizer.ggml.model"] = "llama"
-
-	kv["tokenizer.ggml.tokens"] = vocab.Tokens
-	kv["tokenizer.ggml.scores"] = vocab.Scores
-	kv["tokenizer.ggml.token_type"] = vocab.Types
-
-	kv["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
-	kv["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
-
-	switch arch {
-	case "llama":
-		kv["tokenizer.ggml.unknown_token_id"] = uint32(0)
-	case "gemma":
-		kv["tokenizer.ggml.padding_token_id"] = uint32(params.PaddingTokenID)
-		kv["tokenizer.ggml.unknown_token_id"] = uint32(3)
-	}
-
-	kv["tokenizer.ggml.add_bos_token"] = true
-	kv["tokenizer.ggml.add_eos_token"] = false
-
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	m := llm.NewGGUFV3(params.ByteOrder)
-	if err := m.Encode(f, kv, tensors); err != nil {
-		return "", err
 	}
 
-	return f.Name(), nil
+	return nil, fmt.Errorf("Unknown error")
 }

+ 136 - 0
convert/gemma.go

@@ -0,0 +1,136 @@
+package convert
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"log/slog"
+	"os"
+	"strings"
+
+	"github.com/d4l3k/go-bfloat16"
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+
+	"github.com/ollama/ollama/llm"
+)
+
+type GemmaModel struct {
+	ModelData
+}
+
+func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
+	slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
+
+	data := make([]byte, r.end-r.start)
+	if err := binary.Read(f, r.bo, data); err != nil {
+		return err
+	}
+
+	tDataF32 := bfloat16.DecodeFloat32(data)
+
+	var err error
+	tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
+	if err != nil {
+		return err
+	}
+
+	if err := binary.Write(w, r.bo, tDataF32); err != nil {
+		return err
+	}
+	return nil
+}
+
+func addOnes(data []float32, vectorSize int) ([]float32, error) {
+	n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
+	ones := tensor.Ones(tensor.Float32, vectorSize)
+
+	var err error
+	n, err = n.Add(ones)
+	if err != nil {
+		return []float32{}, err
+	}
+
+	newN, err := native.SelectF32(n, 0)
+	if err != nil {
+		return []float32{}, err
+	}
+
+	var fullTensor []float32
+	for _, v := range newN {
+		fullTensor = append(fullTensor, v...)
+	}
+
+	return fullTensor, nil
+}
+
+func (m *GemmaModel) GetTensors() error {
+	t, err := GetSafeTensors(m.Path, m.Params)
+	if err != nil {
+		return err
+	}
+
+	m.Tensors = []llm.Tensor{}
+
+	for _, l := range t {
+		if strings.HasSuffix(l.Name, "norm.weight") {
+			wt := l.WriterTo.(safetensorWriterTo)
+			wt.handler = gemmaLayerHandler
+			l.WriterTo = wt
+		}
+		m.Tensors = append(m.Tensors, l)
+	}
+
+	return nil
+}
+
+func (m *GemmaModel) LoadVocab() error {
+	v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
+	if err != nil {
+		return err
+	}
+	m.Vocab = v
+	return nil
+}
+
+func (m *GemmaModel) WriteGGUF() (string, error) {
+	kv := llm.KV{
+		"general.architecture":                   "gemma",
+		"general.name":                           m.Name,
+		"gemma.context_length":                   uint32(m.Params.ContextSize),
+		"gemma.embedding_length":                 uint32(m.Params.HiddenSize),
+		"gemma.block_count":                      uint32(m.Params.HiddenLayers),
+		"gemma.feed_forward_length":              uint32(m.Params.IntermediateSize),
+		"gemma.attention.head_count":             uint32(m.Params.AttentionHeads),
+		"gemma.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
+		"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
+		"gemma.attention.key_length":             uint32(m.Params.HeadDimension),
+		"gemma.attention.value_length":           uint32(m.Params.HeadDimension),
+		"general.file_type":                      uint32(1),
+		"tokenizer.ggml.model":                   "llama",
+
+		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
+		"tokenizer.ggml.scores":     m.Vocab.Scores,
+		"tokenizer.ggml.token_type": m.Vocab.Types,
+
+		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
+		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
+		"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
+		"tokenizer.ggml.unknown_token_id": uint32(3),
+		"tokenizer.ggml.add_bos_token":    true,
+		"tokenizer.ggml.add_eos_token":    false,
+	}
+
+	f, err := os.CreateTemp("", "ollama-gguf")
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	mod := llm.NewGGUFV3(m.Params.ByteOrder)
+	if err := mod.Encode(f, kv, m.Tensors); err != nil {
+		return "", err
+	}
+
+	return f.Name(), nil
+}

+ 174 - 0
convert/mistral.go

@@ -0,0 +1,174 @@
+package convert
+
+import (
+	"encoding/binary"
+	"fmt"
+	"io"
+	"os"
+	"regexp"
+	"strings"
+
+	"github.com/d4l3k/go-bfloat16"
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+	"github.com/x448/float16"
+
+	"github.com/ollama/ollama/llm"
+)
+
+type MistralModel struct {
+	ModelData
+}
+
+func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
+	layerSize := r.end - r.start
+
+	var err error
+	tData := make([]uint16, layerSize/2)
+	if err = binary.Read(f, r.bo, tData); err != nil {
+		return err
+	}
+
+	var heads uint32
+	if strings.Contains(r.t.Name, "attn_q") {
+		heads = uint32(r.params.AttentionHeads)
+	} else if strings.Contains(r.t.Name, "attn_k") {
+		heads = uint32(r.params.KeyValHeads)
+		if heads == 0 {
+			heads = uint32(r.params.AttentionHeads)
+		}
+	} else {
+		return fmt.Errorf("unknown layer type")
+	}
+
+	tData, err = repack(tData, int(heads), r.t.Shape)
+	if err != nil {
+		return err
+	}
+
+	var buf []byte
+	for _, n := range tData {
+		buf = r.bo.AppendUint16(buf, n)
+	}
+
+	tempBuf := make([]uint16, len(tData))
+	tDataF32 := bfloat16.DecodeFloat32(buf)
+	for cnt, v := range tDataF32 {
+		tDataF16 := float16.Fromfloat32(v)
+		tempBuf[cnt] = uint16(tDataF16)
+	}
+
+	if err = binary.Write(w, r.bo, tempBuf); err != nil {
+		return err
+	}
+	return nil
+}
+
+func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
+	n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
+	origShape := n.Shape().Clone()
+
+	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
+	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
+		return nil, err
+	}
+
+	if err := n.T(0, 2, 1, 3); err != nil {
+		return nil, err
+	}
+
+	if err := n.Reshape(origShape...); err != nil {
+		return nil, err
+	}
+
+	if err := n.Transpose(); err != nil {
+		return nil, err
+	}
+	newN, err := native.SelectU16(n, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	var fullTensor []uint16
+	for _, v := range newN {
+		fullTensor = append(fullTensor, v...)
+	}
+	return fullTensor, nil
+}
+
+func (m *MistralModel) GetTensors() error {
+	t, err := GetSafeTensors(m.Path, m.Params)
+	if err != nil {
+		return err
+	}
+
+	m.Tensors = []llm.Tensor{}
+
+	pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
+	re, err := regexp.Compile(pattern)
+	if err != nil {
+		return err
+	}
+
+	for _, l := range t {
+		matches := re.FindAllStringSubmatch(l.Name, -1)
+		if len(matches) > 0 {
+			wt := l.WriterTo.(safetensorWriterTo)
+			wt.handler = mistralLayerHandler
+			l.WriterTo = wt
+		}
+		m.Tensors = append(m.Tensors, l)
+	}
+
+	return nil
+}
+
+func (m *MistralModel) LoadVocab() error {
+	v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
+	if err != nil {
+		return err
+	}
+	m.Vocab = v
+	return nil
+}
+
+func (m *MistralModel) WriteGGUF() (string, error) {
+	kv := llm.KV{
+		"general.architecture":                   "llama",
+		"general.name":                           m.Name,
+		"llama.context_length":                   uint32(m.Params.ContextSize),
+		"llama.embedding_length":                 uint32(m.Params.HiddenSize),
+		"llama.block_count":                      uint32(m.Params.HiddenLayers),
+		"llama.feed_forward_length":              uint32(m.Params.IntermediateSize),
+		"llama.rope.dimension_count":             uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
+		"llama.attention.head_count":             uint32(m.Params.AttentionHeads),
+		"llama.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
+		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
+		"llama.rope.freq_base":                   float32(m.Params.RopeFreqBase),
+		"general.file_type":                      uint32(1),
+		"tokenizer.ggml.model":                   "llama",
+
+		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
+		"tokenizer.ggml.scores":     m.Vocab.Scores,
+		"tokenizer.ggml.token_type": m.Vocab.Types,
+
+		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
+		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
+		"tokenizer.ggml.add_bos_token":    true,
+		"tokenizer.ggml.add_eos_token":    false,
+		"tokenizer.ggml.unknown_token_id": uint32(0),
+	}
+
+	f, err := os.CreateTemp("", "ollama-gguf")
+	if err != nil {
+		return "", err
+	}
+	defer f.Close()
+
+	mod := llm.NewGGUFV3(m.Params.ByteOrder)
+	if err := mod.Encode(f, kv, m.Tensors); err != nil {
+		return "", err
+	}
+
+	return f.Name(), nil
+}

+ 6 - 14
server/images.go

@@ -654,30 +654,22 @@ func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (
 		return "", err
 	}
 
-	SupportedArchs := []string{
-		"MistralForCausalLM",
-		"GemmaForCausalLM",
-	}
-
-	for _, arch := range params.Architectures {
-		if !slices.Contains(SupportedArchs, arch) {
-			return "", fmt.Errorf("this safetensors model is not yet supported")
-		}
+	mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
+	if err != nil {
+		return "", err
 	}
 
 	fn(api.ProgressResponse{Status: "processing safetensors"})
-	t, err := convert.GetSafeTensors(tempDir, params)
-	if err != nil {
+	if err := mArch.GetTensors(); err != nil {
 		return "", err
 	}
 
-	vocab, err := convert.LoadTokens(tempDir, params)
-	if err != nil {
+	if err := mArch.LoadVocab(); err != nil {
 		return "", err
 	}
 
 	fn(api.ProgressResponse{Status: "converting model"})
-	path, err = convert.WriteGGUF(name, t, params, vocab)
+	path, err = mArch.WriteGGUF()
 	if err != nil {
 		return "", err
 	}