Browse Source

Add gemma safetensors conversion (#3250)

Co-authored-by: Michael Yang <mxyng@pm.me>
Patrick Devine 1 year ago
parent
commit
5a5efee46b
11 changed files with 941 additions and 825 deletions
  1. 0 10
      .golangci.yaml
  2. 4 1
      cmd/cmd.go
  3. 6 2
      cmd/interactive.go
  4. 346 56
      convert/convert.go
  5. 1 1
      go.mod
  6. 25 23
      llm/ggla.go
  7. 82 3
      llm/ggml.go
  8. 446 718
      llm/gguf.go
  9. 3 1
      readline/history.go
  10. 16 8
      server/images.go
  11. 12 2
      server/routes_test.go

+ 0 - 10
.golangci.yaml

@@ -15,13 +15,3 @@ linters:
     - misspell
     - nilerr
     - unused
-linters-settings:
-  errcheck:
-    # exclude the following functions since we don't generally
-    # need to be concerned with the returned errors
-    exclude-functions:
-      - encoding/binary.Read
-      - (*os.File).Seek
-      - (*bufio.Writer).WriteString
-      - (*github.com/spf13/pflag.FlagSet).Set
-      - (*github.com/ollama/ollama/llm.readSeekOffset).Seek

+ 4 - 1
cmd/cmd.go

@@ -213,7 +213,10 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 	if _, err := io.Copy(hash, bin); err != nil {
 		return "", err
 	}
-	bin.Seek(0, io.SeekStart)
+
+	if _, err := bin.Seek(0, io.SeekStart); err != nil {
+		return "", err
+	}
 
 	digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
 	if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {

+ 6 - 2
cmd/interactive.go

@@ -295,10 +295,14 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
 					opts.WordWrap = false
 					fmt.Println("Set 'nowordwrap' mode.")
 				case "verbose":
-					cmd.Flags().Set("verbose", "true")
+					if err := cmd.Flags().Set("verbose", "true"); err != nil {
+						return err
+					}
 					fmt.Println("Set 'verbose' mode.")
 				case "quiet":
-					cmd.Flags().Set("verbose", "false")
+					if err := cmd.Flags().Set("verbose", "false"); err != nil {
+						return err
+					}
 					fmt.Println("Set 'quiet' mode.")
 				case "format":
 					if len(args) < 3 || args[2] != "json" {

+ 346 - 56
convert/convert.go

@@ -12,8 +12,13 @@ import (
 	"path/filepath"
 	"regexp"
 	"slices"
+	"strings"
 
+	"github.com/d4l3k/go-bfloat16"
 	"github.com/mitchellh/mapstructure"
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+	"github.com/x448/float16"
 	"google.golang.org/protobuf/proto"
 
 	"github.com/ollama/ollama/convert/sentencepiece"
@@ -33,6 +38,15 @@ type Params struct {
 	RopeFreqBase     float64  `json:"rope_theta"`
 	BoSTokenID       int      `json:"bos_token_id"`
 	EoSTokenID       int      `json:"eos_token_id"`
+	HeadDimension    int      `json:"head_dim"`
+	PaddingTokenID   int      `json:"pad_token_id"`
+
+	ByteOrder
+}
+
+// ByteOrder combines binary.ByteOrder and binary.AppendByteOrder so that a
+// single value can be passed to binary.Read/binary.Write and can also be used
+// for the Append* helpers such as AppendUint16.
+type ByteOrder interface {
+	binary.ByteOrder
+	binary.AppendByteOrder
 }
 
 type MetaData struct {
@@ -41,27 +55,29 @@ type MetaData struct {
 	Offsets []int  `mapstructure:"data_offsets"`
 }
 
-func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
+func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
 	f, err := os.Open(fn)
 	if err != nil {
-		return []llm.Tensor{}, 0, err
+		return nil, 0, err
 	}
 	defer f.Close()
 
 	var jsonSize uint64
-	binary.Read(f, binary.LittleEndian, &jsonSize)
+	if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
+		return nil, 0, err
+	}
 
 	buf := make([]byte, jsonSize)
 	_, err = io.ReadFull(f, buf)
 	if err != nil {
-		return []llm.Tensor{}, 0, err
+		return nil, 0, err
 	}
 
 	d := json.NewDecoder(bytes.NewBuffer(buf))
 	d.UseNumber()
 	var parsed map[string]interface{}
 	if err = d.Decode(&parsed); err != nil {
-		return []llm.Tensor{}, 0, err
+		return nil, 0, err
 	}
 
 	var keys []string
@@ -78,7 +94,7 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
 		vals := parsed[k].(map[string]interface{})
 		var data MetaData
 		if err = mapstructure.Decode(vals, &data); err != nil {
-			return []llm.Tensor{}, 0, err
+			return nil, 0, err
 		}
 
 		var size uint64
@@ -100,7 +116,7 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
 		ggufName, err := GetTensorName(k)
 		if err != nil {
 			slog.Error("%v", err)
-			return []llm.Tensor{}, 0, err
+			return nil, 0, err
 		}
 
 		shape := []uint64{0, 0, 0, 0}
@@ -109,14 +125,24 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
 		}
 
 		t := llm.Tensor{
-			Name:          ggufName,
-			Kind:          kind,
-			Offset:        offset,
-			Shape:         shape[:],
-			FileName:      fn,
-			OffsetPadding: 8 + jsonSize,
-			FileOffsets:   []uint64{uint64(data.Offsets[0]), uint64(data.Offsets[1])},
+			Name:   ggufName,
+			Kind:   kind,
+			Offset: offset,
+			Shape:  shape[:],
 		}
+
+		t.WriterTo = safetensorWriterTo{
+			t:           &t,
+			params:      params,
+			bo:          params.ByteOrder,
+			headCount:   uint32(params.AttentionHeads),
+			headCountKV: uint32(params.KeyValHeads),
+			filename:    fn,
+			start:       uint64(data.Offsets[0]),
+			end:         uint64(data.Offsets[1]),
+			padding:     8 + jsonSize,
+		}
+
 		slog.Debug(fmt.Sprintf("%v", t))
 		tensors = append(tensors, t)
 		offset += size
@@ -124,21 +150,21 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
 	return tensors, offset, nil
 }
 
-func GetSafeTensors(dirpath string) ([]llm.Tensor, error) {
+func GetSafeTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
 	var tensors []llm.Tensor
 	files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
 	if err != nil {
-		return []llm.Tensor{}, err
+		return nil, err
 	}
 
 	var offset uint64
 	for _, f := range files {
 		var t []llm.Tensor
 		var err error
-		t, offset, err = ReadSafeTensors(f, offset)
+		t, offset, err = ReadSafeTensors(f, offset, params)
 		if err != nil {
 			slog.Error("%v", err)
-			return []llm.Tensor{}, err
+			return nil, err
 		}
 		tensors = append(tensors, t...)
 	}
@@ -160,6 +186,7 @@ func GetParams(dirpath string) (*Params, error) {
 		return nil, err
 	}
 
+	params.ByteOrder = binary.LittleEndian
 	return &params, nil
 }
 
@@ -171,7 +198,7 @@ type Vocab struct {
 	Types  []int32
 }
 
-func LoadTokens(dirpath string) (*Vocab, error) {
+func LoadTokens(dirpath string, params *Params) (*Vocab, error) {
 	slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
 	in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
 	if err != nil {
@@ -196,6 +223,14 @@ func LoadTokens(dirpath string) (*Vocab, error) {
 		v.Tokens = append(v.Tokens, p.GetPiece())
 		v.Scores = append(v.Scores, p.GetScore())
 		t := p.GetType()
+		switch t {
+		case sentencepiece.ModelProto_SentencePiece_UNKNOWN:
+		case sentencepiece.ModelProto_SentencePiece_CONTROL:
+		case sentencepiece.ModelProto_SentencePiece_UNUSED:
+		case sentencepiece.ModelProto_SentencePiece_BYTE:
+		default:
+			t = sentencepiece.ModelProto_SentencePiece_NORMAL
+		}
 		v.Types = append(v.Types, int32(t))
 	}
 
@@ -243,6 +278,16 @@ func LoadTokens(dirpath string) (*Vocab, error) {
 	}
 	slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
 
+	if params.VocabSize > len(v.Tokens) {
+		missingTokens := params.VocabSize - len(v.Tokens)
+		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
+		for cnt := 0; cnt < missingTokens; cnt++ {
+			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
+			v.Scores = append(v.Scores, -1)
+			v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
+		}
+	}
+
 	return v, nil
 }
 
@@ -279,42 +324,287 @@ func GetTensorName(n string) (string, error) {
 	return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
 }
 
+// safetensorWriterTo implements io.WriterTo for a single tensor stored in a
+// safetensors file.
+//
+//   - t is the tensor being written; its Name, Kind and Shape are consulted.
+//   - params is used to determine the model architecture.
+//   - bo is the byte order used to read source values and write the output.
+//   - headCount and headCountKV are the head counts used when repacking
+//     attention q/k weights.
+//   - filename is the safetensors file that is opened on every WriteTo call.
+//   - start and end are the tensor's data offsets; padding is added to start
+//     to obtain the absolute seek position within the file.
+type safetensorWriterTo struct {
+	t *llm.Tensor
+
+	params      *Params
+	bo          ByteOrder
+	headCount   uint32
+	headCountKV uint32
+
+	filename string
+
+	start, end, padding uint64
+}
+
+// addOnes adds 1 to every element of data and returns the result as a flat
+// slice. data is treated as a one-dimensional tensor of length r.t.Shape[0].
+// On failure it returns a nil slice together with the error, matching repack.
+func (r safetensorWriterTo) addOnes(data []float32) ([]float32, error) {
+	n := tensor.New(tensor.WithShape(int(r.t.Shape[0])), tensor.WithBacking(data))
+	ones := tensor.Ones(tensor.Float32, int(r.t.Shape[0]))
+
+	var err error
+	n, err = n.Add(ones)
+	if err != nil {
+		return nil, err
+	}
+
+	newN, err := native.SelectF32(n, 0)
+	if err != nil {
+		return nil, err
+	}
+
+	// the output holds as many values as the input, so size it once up front
+	fullTensor := make([]float32, 0, len(data))
+	for _, v := range newN {
+		fullTensor = append(fullTensor, v...)
+	}
+
+	return fullTensor, nil
+}
+
+// repack rearranges a two-dimensional attention weight (r.t.Shape[0] by
+// r.t.Shape[1]) for gguf and returns the values as a flat slice.
+//
+// heads is the number of attention heads the first dimension is split into;
+// it must be positive, otherwise an error is returned instead of panicking
+// on the division below.
+func (r safetensorWriterTo) repack(data []uint16, heads int) ([]uint16, error) {
+	if heads <= 0 {
+		return nil, fmt.Errorf("cannot repack '%s' with %d heads", r.t.Name, heads)
+	}
+
+	n := tensor.New(tensor.WithShape(int(r.t.Shape[0]), int(r.t.Shape[1])), tensor.WithBacking(data))
+	origShape := n.Shape().Clone()
+
+	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
+	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
+		return nil, err
+	}
+
+	if err := n.T(0, 2, 1, 3); err != nil {
+		return nil, err
+	}
+
+	if err := n.Reshape(origShape...); err != nil {
+		return nil, err
+	}
+
+	if err := n.Transpose(); err != nil {
+		return nil, err
+	}
+	newN, err := native.SelectU16(n, 1)
+	if err != nil {
+		return nil, err
+	}
+
+	// the output holds as many values as the input, so size it once up front
+	fullTensor := make([]uint16, 0, len(data))
+	for _, v := range newN {
+		fullTensor = append(fullTensor, v...)
+	}
+	return fullTensor, nil
+}
+
+// attnQKWeightRe matches the names of the per-block attention query and key
+// weight tensors; the "layer" group captures which of the two matched.
+var attnQKWeightRe = regexp.MustCompile(`^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`)
+
+// writeCounter wraps an io.Writer and records how many bytes it has accepted.
+type writeCounter struct {
+	w io.Writer
+	n int64
+}
+
+// Write forwards p to the wrapped writer and adds the bytes written to n.
+func (c *writeCounter) Write(p []byte) (int, error) {
+	written, err := c.w.Write(p)
+	c.n += int64(written)
+	return written, err
+}
+
+// WriteTo reads the tensor's data from the safetensors file and writes it to
+// w, converting it on the way:
+//
+//   - llama attention q/k weights are repacked and written as float16;
+//   - gemma tensors whose name ends in "norm.weight" have one added to every
+//     element and are written as float32;
+//   - everything else is decoded from bfloat16 in chunks and written as
+//     float32 (Kind 0) or float16 (Kind 1); other kinds write nothing.
+//
+// All output uses the byte order r.bo. It returns the number of bytes written
+// to w and the first error encountered. A file that ends before the tensor's
+// data is complete is reported as io.ErrUnexpectedEOF.
+func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
+	arch, err := getArchFromParams(r.params)
+	if err != nil {
+		return 0, err
+	}
+
+	// end-start is unsigned; reject inverted offsets before it can wrap around
+	if r.end < r.start {
+		return 0, fmt.Errorf("invalid data offsets for '%s': start %d is after end %d", r.t.Name, r.start, r.end)
+	}
+
+	f, err := os.Open(r.filename)
+	if err != nil {
+		return 0, err
+	}
+	defer f.Close()
+
+	if _, err = f.Seek(int64(r.padding+r.start), io.SeekStart); err != nil {
+		return 0, err
+	}
+
+	cw := &writeCounter{w: w}
+
+	switch arch {
+	case "llama":
+		if matches := attnQKWeightRe.FindStringSubmatch(r.t.Name); matches != nil {
+			tData := make([]uint16, (r.end-r.start)/2)
+			if err = binary.Read(f, r.bo, tData); err != nil {
+				return cw.n, err
+			}
+
+			var heads uint32
+			switch matches[attnQKWeightRe.SubexpIndex("layer")] {
+			case "q":
+				heads = r.headCount
+			case "k":
+				heads = r.headCountKV
+				if heads == 0 {
+					// no separate key/value head count; fall back to the head count
+					heads = r.headCount
+				}
+			}
+
+			tData, err = r.repack(tData, int(heads))
+			if err != nil {
+				return cw.n, err
+			}
+
+			// serialize the repacked values again so they can be decoded as bfloat16
+			buf := make([]byte, 0, 2*len(tData))
+			for _, v := range tData {
+				buf = r.bo.AppendUint16(buf, v)
+			}
+
+			// convert bfloat16 -> float32 -> float16
+			tempBuf := make([]uint16, len(tData))
+			for cnt, v := range bfloat16.DecodeFloat32(buf) {
+				tempBuf[cnt] = uint16(float16.Fromfloat32(v))
+			}
+
+			err = binary.Write(cw, r.bo, tempBuf)
+			return cw.n, err
+		}
+
+	case "gemma":
+		if strings.HasSuffix(r.t.Name, "norm.weight") {
+			slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
+
+			data := make([]byte, r.end-r.start)
+			if err = binary.Read(f, r.bo, data); err != nil {
+				return cw.n, err
+			}
+
+			tDataF32, err := r.addOnes(bfloat16.DecodeFloat32(data))
+			if err != nil {
+				return cw.n, err
+			}
+
+			err = binary.Write(cw, r.bo, tDataF32)
+			return cw.n, err
+		}
+	}
+
+	remaining := r.end - r.start
+
+	const bufSize = uint64(10240)
+	data := make([]byte, min(bufSize, remaining))
+	for remaining > 0 {
+		chunk := data[:min(bufSize, remaining)]
+		if _, err = io.ReadFull(f, chunk); err != nil {
+			// io.ReadFull reports io.EOF when nothing could be read; data is
+			// still outstanding here, so the file is truncated
+			if err == io.EOF {
+				err = io.ErrUnexpectedEOF
+			}
+			return cw.n, err
+		}
+		remaining -= uint64(len(chunk))
+
+		// convert bfloat16 -> ieee float32
+		tDataF32 := bfloat16.DecodeFloat32(chunk)
+
+		switch r.t.Kind {
+		case 0:
+			err = binary.Write(cw, r.bo, tDataF32)
+		case 1:
+			// convert float32 -> float16
+			tempBuf := make([]uint16, len(tDataF32))
+			for cnt, v := range tDataF32 {
+				tempBuf[cnt] = uint16(float16.Fromfloat32(v))
+			}
+			err = binary.Write(cw, r.bo, tempBuf)
+		}
+		if err != nil {
+			return cw.n, err
+		}
+	}
+	return cw.n, nil
+}
+
+// getArchFromParams maps the single entry of params.Architectures to the
+// architecture name used for conversion: "MistralForCausalLM" becomes "llama"
+// and "GemmaForCausalLM" becomes "gemma". It returns an error when no
+// architecture is listed, when more than one is listed, or when the listed
+// one is not supported.
+func getArchFromParams(params *Params) (string, error) {
+	archs := params.Architectures
+	if len(archs) == 0 {
+		return "", fmt.Errorf("No architecture specified to convert")
+	}
+	if len(archs) > 1 {
+		return "", fmt.Errorf("Multimodal models are not yet supported")
+	}
+
+	switch name := archs[0]; name {
+	case "MistralForCausalLM":
+		return "llama", nil
+	case "GemmaForCausalLM":
+		return "gemma", nil
+	default:
+		return "", fmt.Errorf("Models based on '%s' are not yet supported", name)
+	}
+}
+
 func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
-	c := llm.ContainerGGUF{
-		ByteOrder: binary.LittleEndian,
-	}
-
-	m := llm.NewGGUFModel(&c)
-	m.Tensors = tensors
-	m.KV["general.architecture"] = "llama"
-	m.KV["general.name"] = name
-	m.KV["llama.context_length"] = uint32(params.ContextSize)
-	m.KV["llama.embedding_length"] = uint32(params.HiddenSize)
-	m.KV["llama.block_count"] = uint32(params.HiddenLayers)
-	m.KV["llama.feed_forward_length"] = uint32(params.IntermediateSize)
-	m.KV["llama.rope.dimension_count"] = uint32(128)
-	m.KV["llama.attention.head_count"] = uint32(params.AttentionHeads)
-	m.KV["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
-	m.KV["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
-	m.KV["llama.rope.freq_base"] = float32(params.RopeFreqBase)
-	m.KV["general.file_type"] = uint32(1)
-	m.KV["tokenizer.ggml.model"] = "llama"
-
-	m.KV["tokenizer.ggml.tokens"] = vocab.Tokens
-	m.KV["tokenizer.ggml.scores"] = vocab.Scores
-	m.KV["tokenizer.ggml.token_type"] = vocab.Types
-
-	m.KV["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
-	m.KV["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
-	m.KV["tokenizer.ggml.unknown_token_id"] = uint32(0)
-	m.KV["tokenizer.ggml.add_bos_token"] = true
-	m.KV["tokenizer.ggml.add_eos_token"] = false
-
-	// llamacpp sets the chat template, however we don't need to set it since we pass it in through a layer
-	// m.KV["tokenizer.chat_template"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" // XXX removeme
-
-	c.V3.NumTensor = uint64(len(tensors))
-	c.V3.NumKV = uint64(len(m.KV))
+	arch, err := getArchFromParams(params)
+	if err != nil {
+		return "", err
+	}
+
+	kv := llm.KV{
+		"general.architecture": arch,
+		"general.name":         name,
+	}
+
+	switch arch {
+	case "llama":
+		kv["llama.context_length"] = uint32(params.ContextSize)
+		kv["llama.embedding_length"] = uint32(params.HiddenSize)
+		kv["llama.block_count"] = uint32(params.HiddenLayers)
+		kv["llama.feed_forward_length"] = uint32(params.IntermediateSize)
+		kv["llama.rope.dimension_count"] = uint32(params.HiddenSize / params.AttentionHeads)
+		slog.Debug(fmt.Sprintf("rope dim count = %d", kv["llama.rope.dimension_count"]))
+		kv["llama.attention.head_count"] = uint32(params.AttentionHeads)
+		kv["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
+		kv["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
+		kv["llama.rope.freq_base"] = float32(params.RopeFreqBase)
+	case "gemma":
+		kv["gemma.context_length"] = uint32(params.ContextSize)
+		kv["gemma.embedding_length"] = uint32(params.HiddenSize)
+		kv["gemma.block_count"] = uint32(params.HiddenLayers)
+		kv["gemma.feed_forward_length"] = uint32(params.IntermediateSize)
+		kv["gemma.attention.head_count"] = uint32(params.AttentionHeads)
+		kv["gemma.attention.head_count_kv"] = uint32(params.KeyValHeads)
+		kv["gemma.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
+		kv["gemma.attention.key_length"] = uint32(params.HeadDimension)
+		kv["gemma.attention.value_length"] = uint32(params.HeadDimension)
+	}
+
+	kv["general.file_type"] = uint32(1)
+	kv["tokenizer.ggml.model"] = "llama"
+
+	kv["tokenizer.ggml.tokens"] = vocab.Tokens
+	kv["tokenizer.ggml.scores"] = vocab.Scores
+	kv["tokenizer.ggml.token_type"] = vocab.Types
+
+	kv["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
+	kv["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
+
+	switch arch {
+	case "llama":
+		kv["tokenizer.ggml.unknown_token_id"] = uint32(0)
+	case "gemma":
+		kv["tokenizer.ggml.padding_token_id"] = uint32(params.PaddingTokenID)
+		kv["tokenizer.ggml.unknown_token_id"] = uint32(3)
+	}
+
+	kv["tokenizer.ggml.add_bos_token"] = true
+	kv["tokenizer.ggml.add_eos_token"] = false
 
 	f, err := os.CreateTemp("", "ollama-gguf")
 	if err != nil {
@@ -322,8 +612,8 @@ func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab)
 	}
 	defer f.Close()
 
-	err = m.Encode(f)
-	if err != nil {
+	m := llm.NewGGUFV3(params.ByteOrder)
+	if err := m.Encode(f, kv, tensors); err != nil {
 		return "", err
 	}
 

+ 1 - 1
go.mod

@@ -9,7 +9,7 @@ require (
 	github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
 	github.com/emirpasic/gods v1.18.1
 	github.com/gin-gonic/gin v1.9.1
-	github.com/golang/protobuf v1.5.0
+	github.com/golang/protobuf v1.5.0 // indirect
 	github.com/google/uuid v1.0.0
 	github.com/mitchellh/mapstructure v1.5.0
 	github.com/olekukonko/tablewriter v0.0.5

+ 25 - 23
llm/ggla.go

@@ -7,16 +7,18 @@ import (
 	"slices"
 )
 
-type ContainerGGLA struct {
+type containerGGLA struct {
 	version uint32
 }
 
-func (c *ContainerGGLA) Name() string {
+func (c *containerGGLA) Name() string {
 	return "ggla"
 }
 
-func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {
-	binary.Read(rs, binary.LittleEndian, &c.version)
+func (c *containerGGLA) Decode(rs io.ReadSeeker) (model, error) {
+	if err := binary.Read(rs, binary.LittleEndian, &c.version); err != nil {
+		return nil, err
+	}
 
 	switch c.version {
 	case 1:
@@ -24,26 +26,26 @@ func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {
 		return nil, errors.New("invalid version")
 	}
 
-	model := newModelGGLA(c)
+	model := newGGLA(c)
 	err := model.decode(rs)
 	return model, err
 }
 
-type ModelGGLA struct {
-	*ContainerGGLA
+type ggla struct {
+	*containerGGLA
 
 	kv      KV
 	tensors []Tensor
 }
 
-func newModelGGLA(container *ContainerGGLA) *ModelGGLA {
-	return &ModelGGLA{
-		ContainerGGLA: container,
+func newGGLA(container *containerGGLA) *ggla {
+	return &ggla{
+		containerGGLA: container,
 		kv:            make(KV),
 	}
 }
 
-func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
+func (m *ggla) decode(rs io.ReadSeeker) error {
 	var r uint32
 	if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
 		return err
@@ -109,7 +111,7 @@ func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
 
 		t.Offset = uint64(offset)
 
-		if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
+		if _, err := rs.Seek(int64(t.size()), io.SeekCurrent); err != nil {
 			return err
 		}
 
@@ -117,46 +119,46 @@ func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
 	}
 }
 
-func (m *ModelGGLA) KV() KV {
+func (m *ggla) KV() KV {
 	return m.kv
 }
 
-func (m *ModelGGLA) Tensor() []Tensor {
+func (m *ggla) Tensor() []Tensor {
 	return m.tensors
 }
 
-func (*ModelGGLA) ModelFamily() string {
+func (*ggla) ModelFamily() string {
 	return "ggla"
 }
 
-func (*ModelGGLA) ModelType() string {
+func (*ggla) ModelType() string {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) FileType() string {
+func (*ggla) FileType() string {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) NumLayers() uint32 {
+func (*ggla) NumLayers() uint32 {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) NumGQA() uint32 {
+func (*ggla) NumGQA() uint32 {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) NumEmbed() uint32 {
+func (*ggla) NumEmbed() uint32 {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) NumHead() uint32 {
+func (*ggla) NumHead() uint32 {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) NumHeadKv() uint32 {
+func (*ggla) NumHeadKv() uint32 {
 	panic("not implemented")
 }
 
-func (*ModelGGLA) NumCtx() uint32 {
+func (*ggla) NumCtx() uint32 {
 	panic("not implemented")
 }

+ 82 - 3
llm/ggml.go

@@ -101,6 +101,85 @@ type model interface {
 	NumCtx() uint32
 }
 
+// KV holds metadata key/value pairs keyed by name, for example
+// "general.architecture".
+type KV map[string]any
+
+// Tensor describes a single tensor: its name, its kind (a numeric type ID
+// such as 0 for FP32 or 1 for FP16), its offset and its shape. The embedded
+// io.WriterTo, when set, writes the tensor's data.
+type Tensor struct {
+	Name   string
+	Kind   uint32
+	Offset uint64
+
+	// Shape is the number of elements in each dimension
+	Shape []uint64
+
+	io.WriterTo
+}
+
+// blockSize returns the number of elements grouped into one block for the
+// tensor's kind: 1 for kinds below 2, 32 for kinds 2 through 9, and 256 for
+// every other kind.
+func (t Tensor) blockSize() uint64 {
+	if t.Kind < 2 {
+		return 1
+	}
+	if t.Kind < 10 {
+		return 32
+	}
+	return 256
+}
+
+// typeSize returns the number of bytes occupied by one block (see blockSize)
+// of the tensor's kind, or 0 when the kind is not recognized.
+func (t Tensor) typeSize() uint64 {
+	blockSize := t.blockSize()
+
+	switch t.Kind {
+	case 0: // FP32
+		return 4
+	case 1: // FP16
+		return 2
+	case 2: // Q4_0
+		return 2 + blockSize/2
+	case 3: // Q4_1
+		return 2 + 2 + blockSize/2
+	case 6: // Q5_0
+		return 2 + 4 + blockSize/2
+	case 7: // Q5_1
+		return 2 + 2 + 4 + blockSize/2
+	case 8: // Q8_0
+		return 2 + blockSize
+	case 9: // Q8_1
+		return 4 + 4 + blockSize
+	case 10: // Q2_K
+		return blockSize/16 + blockSize/4 + 2 + 2
+	case 11: // Q3_K
+		return blockSize/8 + blockSize/4 + 12 + 2
+	case 12: // Q4_K
+		return 2 + 2 + 12 + blockSize/2
+	case 13: // Q5_K
+		return 2 + 2 + 12 + blockSize/8 + blockSize/2
+	case 14: // Q6_K
+		return blockSize/2 + blockSize/4 + blockSize/16 + 2
+	case 15: // Q8_K
+		// Q8_K stores its scale as a 32-bit float, so it takes 4 bytes
+		// rather than 2
+		return 4 + blockSize + 2*blockSize/16
+	case 16: // IQ2_XXS
+		return 2 + 2*blockSize/8
+	case 17: // IQ2_XS
+		return 2 + 2*blockSize/8 + blockSize/32
+	case 18: // IQ3_XXS
+		return 2 + 3*blockSize/8
+	default:
+		return 0
+	}
+}
+
+// parameters returns the total number of elements in the tensor, i.e. the
+// product of all entries in Shape (1 when Shape is empty).
+func (t Tensor) parameters() uint64 {
+	total := uint64(1)
+	for i := range t.Shape {
+		total *= t.Shape[i]
+	}
+	return total
+}
+
+// size returns the tensor's data size in bytes: the element count times the
+// bytes per block, divided by the elements per block.
+func (t Tensor) size() uint64 {
+	scaled := t.parameters() * t.typeSize()
+	return scaled / t.blockSize()
+}
+
 type container interface {
 	Name() string
 	Decode(io.ReadSeeker) (model, error)
@@ -133,11 +212,11 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, error) {
 	case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
 		return nil, ErrUnsupportedFormat
 	case FILE_MAGIC_GGLA:
-		c = &ContainerGGLA{}
+		c = &containerGGLA{}
 	case FILE_MAGIC_GGUF_LE:
-		c = &ContainerGGUF{ByteOrder: binary.LittleEndian}
+		c = &containerGGUF{ByteOrder: binary.LittleEndian}
 	case FILE_MAGIC_GGUF_BE:
-		c = &ContainerGGUF{ByteOrder: binary.BigEndian}
+		c = &containerGGUF{ByteOrder: binary.BigEndian}
 	default:
 		return nil, errors.New("invalid file magic")
 	}

+ 446 - 718
llm/gguf.go

@@ -3,22 +3,14 @@ package llm
 import (
 	"bytes"
 	"encoding/binary"
-	"errors"
 	"fmt"
 	"io"
-	"log/slog"
-	"os"
-	"regexp"
-
-	"github.com/d4l3k/go-bfloat16"
-	"github.com/pdevine/tensor"
-	"github.com/pdevine/tensor/native"
-	"github.com/x448/float16"
+	"strings"
 
 	"github.com/ollama/ollama/format"
 )
 
-type ContainerGGUF struct {
+type containerGGUF struct {
 	ByteOrder binary.ByteOrder
 
 	Version uint32
@@ -39,21 +31,29 @@ type ContainerGGUF struct {
 	}
 }
 
-func (c *ContainerGGUF) Name() string {
+func (c *containerGGUF) Name() string {
 	return "gguf"
 }
 
-func (c *ContainerGGUF) Decode(rs io.ReadSeeker) (model, error) {
-	binary.Read(rs, c.ByteOrder, &c.Version)
+func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
+	if err := binary.Read(rs, c.ByteOrder, &c.Version); err != nil {
+		return nil, err
+	}
 
+	var err error
 	switch c.Version {
 	case 1:
-		binary.Read(rs, c.ByteOrder, &c.V1)
+		err = binary.Read(rs, c.ByteOrder, &c.V1)
+	case 2:
+		err = binary.Read(rs, c.ByteOrder, &c.V2)
 	default:
-		binary.Read(rs, c.ByteOrder, &c.V2)
+		err = binary.Read(rs, c.ByteOrder, &c.V3)
+	}
+	if err != nil {
+		return nil, err
 	}
 
-	model := NewGGUFModel(c)
+	model := newGGUF(c)
 	if err := model.Decode(rs); err != nil {
 		return nil, err
 	}
@@ -72,136 +72,23 @@ const (
 )
 
 const (
-	GGUFTypeUint8 uint32 = iota
-	GGUFTypeInt8
-	GGUFTypeUint16
-	GGUFTypeInt16
-	GGUFTypeUint32
-	GGUFTypeInt32
-	GGUFTypeFloat32
-	GGUFTypeBool
-	GGUFTypeString
-	GGUFTypeArray
-	GGUFTypeUint64
-	GGUFTypeInt64
-	GGUFTypeFloat64
+	ggufTypeUint8 uint32 = iota
+	ggufTypeInt8
+	ggufTypeUint16
+	ggufTypeInt16
+	ggufTypeUint32
+	ggufTypeInt32
+	ggufTypeFloat32
+	ggufTypeBool
+	ggufTypeString
+	ggufTypeArray
+	ggufTypeUint64
+	ggufTypeInt64
+	ggufTypeFloat64
 )
 
-type KV map[string]any
-
-type Tensor struct {
-	Name   string
-	Kind   uint32
-	Offset uint64
-
-	// shape is the number of elements in each dimension
-	Shape []uint64
-
-	FileName      string
-	OffsetPadding uint64
-	FileOffsets   []uint64
-}
-
-func (t Tensor) BlockSize() uint64 {
-	switch {
-	case t.Kind < 2:
-		return 1
-	case t.Kind < 10:
-		return 32
-	default:
-		return 256
-	}
-}
-
-func (t Tensor) TypeSize() uint64 {
-	blockSize := t.BlockSize()
-
-	switch t.Kind {
-	case 0: // FP32
-		return 4
-	case 1: // FP16
-		return 2
-	case 2: // Q4_0
-		return 2 + blockSize/2
-	case 3: // Q4_1
-		return 2 + 2 + blockSize/2
-	case 6: // Q5_0
-		return 2 + 4 + blockSize/2
-	case 7: // Q5_1
-		return 2 + 2 + 4 + blockSize/2
-	case 8: // Q8_0
-		return 2 + blockSize
-	case 9: // Q8_1
-		return 4 + 4 + blockSize
-	case 10: // Q2_K
-		return blockSize/16 + blockSize/4 + 2 + 2
-	case 11: // Q3_K
-		return blockSize/8 + blockSize/4 + 12 + 2
-	case 12: // Q4_K
-		return 2 + 2 + 12 + blockSize/2
-	case 13: // Q5_K
-		return 2 + 2 + 12 + blockSize/8 + blockSize/2
-	case 14: // Q6_K
-		return blockSize/2 + blockSize/4 + blockSize/16 + 2
-	case 15: // Q8_K
-		return 2 + blockSize + 2*blockSize/16
-	case 16: // IQ2_XXS
-		return 2 + 2*blockSize/8
-	case 17: // IQ2_XS
-		return 2 + 2*blockSize/8 + blockSize/32
-	case 18: // IQ3_XXS
-		return 2 + 3*blockSize/8
-	default:
-		return 0
-	}
-}
-
-func (t Tensor) Parameters() uint64 {
-	var count uint64 = 1
-	for _, n := range t.Shape {
-		count *= n
-	}
-	return count
-}
-
-func (t Tensor) Size() uint64 {
-	return t.Parameters() * t.TypeSize() / t.BlockSize()
-}
-
-func (t Tensor) Repack(data []uint16, heads int) ([]uint16, error) {
-	n := tensor.New(tensor.WithShape(int(t.Shape[0]), int(t.Shape[1])), tensor.WithBacking(data))
-	origShape := n.Shape().Clone()
-
-	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
-	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
-		return []uint16{}, err
-	}
-
-	if err := n.T(0, 2, 1, 3); err != nil {
-		return []uint16{}, err
-	}
-
-	if err := n.Reshape(origShape...); err != nil {
-		return []uint16{}, err
-	}
-
-	if err := n.Transpose(); err != nil {
-		return []uint16{}, err
-	}
-	newN, err := native.SelectU16(n, 1)
-	if err != nil {
-		return []uint16{}, err
-	}
-
-	var fullTensor []uint16
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
-	}
-	return fullTensor, nil
-}
-
-type GGUFModel struct {
-	*ContainerGGUF
+type gguf struct {
+	*containerGGUF
 
 	KV
 	Tensors []Tensor
@@ -209,30 +96,40 @@ type GGUFModel struct {
 	parameters uint64
 }
 
-func NewGGUFModel(container *ContainerGGUF) *GGUFModel {
-	return &GGUFModel{
-		ContainerGGUF: container,
+func newGGUF(container *containerGGUF) *gguf {
+	return &gguf{
+		containerGGUF: container,
 		KV:            make(KV),
 	}
 }
 
-func (llm *GGUFModel) NumTensor() uint64 {
-	if llm.Version == 1 {
+// NewGGUFV3 returns an empty GGUF model whose container is set to version 3
+// and uses the given byte order.
+func NewGGUFV3(bo binary.ByteOrder) *gguf {
+	return newGGUF(&containerGGUF{ByteOrder: bo, Version: 3})
+}
+
+func (llm *gguf) numTensor() uint64 {
+	switch llm.Version {
+	case 1:
 		return uint64(llm.V1.NumTensor)
+	case 2:
+		return llm.V2.NumTensor
+	default:
+		return llm.V3.NumTensor
 	}
-
-	return llm.V2.NumTensor
 }
 
-func (llm *GGUFModel) NumKV() uint64 {
-	if llm.Version == 1 {
+func (llm *gguf) numKV() uint64 {
+	switch llm.Version {
+	case 1:
 		return uint64(llm.V1.NumKV)
+	case 2:
+		return llm.V2.NumKV
+	default:
+		return llm.V3.NumKV
 	}
-
-	return llm.V2.NumKV
 }
 
-func (llm *GGUFModel) ModelFamily() string {
+func (llm *gguf) ModelFamily() string {
 	if t, ok := llm.KV["general.architecture"].(string); ok {
 		return t
 	}
@@ -240,7 +137,7 @@ func (llm *GGUFModel) ModelFamily() string {
 	return "unknown"
 }
 
-func (llm *GGUFModel) ModelType() string {
+func (llm *gguf) ModelType() string {
 	if llm.parameters > 0 {
 		return format.HumanNumber(llm.parameters)
 	}
@@ -248,7 +145,7 @@ func (llm *GGUFModel) ModelType() string {
 	return "unknown"
 }
 
-func (llm *GGUFModel) FileType() string {
+func (llm *gguf) FileType() string {
 	if t, ok := llm.KV["general.file_type"].(uint32); ok {
 		return fileType(t)
 	}
@@ -256,463 +153,98 @@ func (llm *GGUFModel) FileType() string {
 	return "unknown"
 }
 
-func (llm *GGUFModel) Encode(f *os.File) error {
-	// this mimics the order of the llama.cpp convert script
-	kOrder := []string{
-		"general.architecture",
-		"general.name",
-		"llama.context_length",
-		"llama.embedding_length",
-		"llama.block_count",
-		"llama.feed_forward_length",
-		"llama.rope.dimension_count",
-		"llama.attention.head_count",
-		"llama.attention.head_count_kv",
-		"llama.attention.layer_norm_rms_epsilon",
-		"llama.rope.freq_base",
-		"general.file_type",
-		"tokenizer.ggml.model",
-		"tokenizer.ggml.tokens",
-		"tokenizer.ggml.scores",
-		"tokenizer.ggml.token_type",
-		"tokenizer.ggml.bos_token_id",
-		"tokenizer.ggml.eos_token_id",
-		"tokenizer.ggml.unknown_token_id",
-		"tokenizer.ggml.add_bos_token",
-		"tokenizer.ggml.add_eos_token",
-		"tokenizer.chat_template",
-	}
-
-	if err := binary.Write(f, llm.ByteOrder, []byte("GGUF")); err != nil {
-		return err
-	}
-
-	if err := binary.Write(f, llm.ByteOrder, uint32(3)); err != nil {
-		return err
-	}
-
-	if err := binary.Write(f, llm.ByteOrder, uint64(llm.V3.NumTensor)); err != nil {
-		return err
-	}
-
-	if err := binary.Write(f, llm.ByteOrder, uint64(llm.V3.NumKV)); err != nil {
-		return err
-	}
-
-	for _, k := range kOrder {
-		val, ok := llm.KV[k]
-		if !ok {
-			continue
-		}
-
-		if err := binary.Write(f, llm.ByteOrder, uint64(len(k))); err != nil {
-			return err
-		}
-		if err := binary.Write(f, llm.ByteOrder, []byte(k)); err != nil {
-			return err
-		}
-
-		switch v := val.(type) {
-		case uint32:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeUint32); err != nil {
-				return err
-			}
-
-			if err := llm.writeUint32(f, v); err != nil {
-				return err
-			}
-		case float32:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeFloat32); err != nil {
-				return err
-			}
-
-			if err := llm.writeF32(f, v); err != nil {
-				return err
-			}
-		case bool:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeBool); err != nil {
-				return err
-			}
-
-			if err := llm.writeBool(f, v); err != nil {
-				return err
-			}
-		case string:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeString); err != nil {
-				return err
-			}
-
-			if err := llm.writeString(f, v); err != nil {
-				return err
-			}
-		case []int32:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeInt32); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil {
-				return err
-			}
-			for _, i := range v {
-				if err := llm.writeInt32(f, i); err != nil {
-					return err
-				}
-			}
-		case []uint32:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeUint32); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil {
-				return err
-			}
-			for _, i := range v {
-				if err := llm.writeUint32(f, i); err != nil {
-					return err
-				}
-			}
-		case []float32:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeFloat32); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil {
-				return err
-			}
-			for _, fl := range v {
-				if err := llm.writeF32(f, fl); err != nil {
-					return err
-				}
-			}
-		case []string:
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeArray); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, GGUFTypeString); err != nil {
-				return err
-			}
-
-			if err := binary.Write(f, llm.ByteOrder, uint64(len(v))); err != nil {
-				return err
-			}
-
-			for _, s := range v {
-				if err := llm.writeString(f, s); err != nil {
-					return err
-				}
-			}
-		}
-	}
-
-	// write layer metadata
-	for _, t := range llm.Tensors {
-		if err := llm.writeString(f, t.Name); err != nil {
+func (llm *gguf) Decode(rs io.ReadSeeker) error {
+	// decode key-values
+	for i := 0; uint64(i) < llm.numKV(); i++ {
+		k, err := readGGUFString(llm, rs)
+		if err != nil {
 			return err
 		}
 
-		// the dimensions of the tensor
-		dims := 1
-		if t.Shape[1] > 0 {
-			dims = 2
-		}
-
-		if err := binary.Write(f, llm.ByteOrder, uint32(dims)); err != nil {
+		t, err := readGGUF[uint32](llm, rs)
+		if err != nil {
 			return err
 		}
 
-		for i := 0; i < dims; i++ {
-			if err := binary.Write(f, llm.ByteOrder, uint64(t.Shape[dims-1-i])); err != nil {
-				return err
-			}
-		}
-
-		if err := binary.Write(f, llm.ByteOrder, uint32(t.Kind)); err != nil {
-			return err
+		var v any
+		switch t {
+		case ggufTypeUint8:
+			v, err = readGGUF[uint8](llm, rs)
+		case ggufTypeInt8:
+			v, err = readGGUF[int8](llm, rs)
+		case ggufTypeUint16:
+			v, err = readGGUF[uint16](llm, rs)
+		case ggufTypeInt16:
+			v, err = readGGUF[int16](llm, rs)
+		case ggufTypeUint32:
+			v, err = readGGUF[uint32](llm, rs)
+		case ggufTypeInt32:
+			v, err = readGGUF[int32](llm, rs)
+		case ggufTypeUint64:
+			v, err = readGGUF[uint64](llm, rs)
+		case ggufTypeInt64:
+			v, err = readGGUF[int64](llm, rs)
+		case ggufTypeFloat32:
+			v, err = readGGUF[float32](llm, rs)
+		case ggufTypeFloat64:
+			v, err = readGGUF[float64](llm, rs)
+		case ggufTypeBool:
+			v, err = readGGUF[bool](llm, rs)
+		case ggufTypeString:
+			v, err = readGGUFString(llm, rs)
+		case ggufTypeArray:
+			v, err = readGGUFArray(llm, rs)
+		default:
+			return fmt.Errorf("invalid type: %d", t)
 		}
 
-		if err := binary.Write(f, llm.ByteOrder, uint64(t.Offset)); err != nil {
+		if err != nil {
 			return err
 		}
-	}
-
-	offset, terr := f.Seek(0, io.SeekCurrent)
-	if terr != nil {
-		return terr
-	}
-	slog.Debug(fmt.Sprintf("tensors offset = %x", offset))
 
-	if err := llm.writePadding(f, 32); err != nil {
-		return err
+		llm.KV[k] = v
 	}
 
-	var dataFile *os.File
-	var currentFile string
-	var err error
-	for _, t := range llm.Tensors {
-		if currentFile != t.FileName {
-			if f != nil {
-				dataFile.Close()
-			}
-			currentFile = t.FileName
-			dataFile, err = os.Open(t.FileName)
-			if err != nil {
-				fmt.Println(err)
-				return err
-			}
-		}
-
-		dataFile.Seek(int64(t.OffsetPadding+t.FileOffsets[0]), 0)
-
-		pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
-		re, err := regexp.Compile(pattern)
+	// decode tensors
+	for i := 0; uint64(i) < llm.numTensor(); i++ {
+		name, err := readGGUFString(llm, rs)
 		if err != nil {
 			return err
 		}
 
-		matches := re.FindAllStringSubmatch(t.Name, -1)
-		if len(matches) > 0 {
-			layerSize := t.FileOffsets[1] - t.FileOffsets[0]
-
-			var err error
-			tData := make([]uint16, layerSize/2)
-			if err = binary.Read(dataFile, llm.ByteOrder, tData); err != nil {
-				return err
-			}
-
-			layerType := matches[0][re.SubexpIndex("layer")]
-			var heads uint32
-			switch layerType {
-			case "q":
-				heads = llm.KV["llama.attention.head_count"].(uint32)
-			case "k":
-				heads = llm.KV["llama.attention.head_count_kv"].(uint32)
-				if heads == 0 {
-					heads = llm.KV["llama.attention.head_count"].(uint32)
-				}
-			}
-
-			tData, err = t.Repack(tData, int(heads))
-			if err != nil {
-				return err
-			}
-
-			var buf []byte
-			for _, n := range tData {
-				buf = binary.LittleEndian.AppendUint16(buf, n)
-			}
-
-			tempBuf := make([]uint16, len(tData))
-			tDataF32 := bfloat16.DecodeFloat32(buf)
-			for cnt, v := range tDataF32 {
-				tDataF16 := float16.Fromfloat32(v)
-				tempBuf[cnt] = uint16(tDataF16)
-			}
-
-			if err = binary.Write(f, llm.ByteOrder, tempBuf); err != nil {
-				return err
-			}
-
-			if err := llm.writePadding(f, 32); err != nil {
-				return err
-			}
-			continue
-		}
-
-		remaining := t.FileOffsets[1] - t.FileOffsets[0]
-
-		bufSize := uint64(10240)
-		var finished bool
-		for {
-			data := make([]byte, min(bufSize, remaining))
-
-			b, err := io.ReadFull(dataFile, data)
-			remaining -= uint64(b)
-
-			if errors.Is(err, io.EOF) || remaining <= 0 {
-				finished = true
-			} else if err != nil {
-				return err
-			}
-
-			// convert bfloat16 -> ieee float32
-			tDataF32 := bfloat16.DecodeFloat32(data)
-
-			switch t.Kind {
-			case 0:
-				if err := binary.Write(f, llm.ByteOrder, tDataF32); err != nil {
-					return err
-				}
-			case 1:
-				// convert float32 -> float16
-				tempBuf := make([]uint16, len(data)/2)
-				for cnt, v := range tDataF32 {
-					tDataF16 := float16.Fromfloat32(v)
-					tempBuf[cnt] = uint16(tDataF16)
-				}
-				if err := binary.Write(f, llm.ByteOrder, tempBuf); err != nil {
-					return err
-				}
-			}
-			if finished {
-				break
-			}
-		}
-
-		if err := llm.writePadding(f, 32); err != nil {
-			return err
-		}
-	}
-	f.Close()
-
-	return nil
-}
-
-func (llm *GGUFModel) writePadding(f *os.File, align int64) error {
-	// gguf file padding is defined in https://github.com/ggerganov/ggml/blob/master/docs/gguf.md#file-structure
-	offset, err := f.Seek(0, io.SeekCurrent)
-	if err != nil {
-		return err
-	}
-	padding := ((offset + align - 1) / align) * align
-	buf := make([]byte, padding-offset)
-	if err := binary.Write(f, llm.ByteOrder, buf); err != nil {
-		return err
-	}
-
-	return nil
-}
-
-func (llm *GGUFModel) writeInt32(f *os.File, v int32) error {
-	if err := binary.Write(f, llm.ByteOrder, v); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (llm *GGUFModel) writeUint32(f *os.File, v uint32) error {
-	if err := binary.Write(f, llm.ByteOrder, v); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (llm *GGUFModel) writeF32(f *os.File, v float32) error {
-	if err := binary.Write(f, llm.ByteOrder, v); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (llm *GGUFModel) writeBool(f *os.File, b bool) error {
-	if err := binary.Write(f, llm.ByteOrder, b); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (llm *GGUFModel) writeString(f *os.File, s string) error {
-	if err := binary.Write(f, llm.ByteOrder, uint64(len(s))); err != nil {
-		return err
-	}
-
-	if err := binary.Write(f, llm.ByteOrder, []byte(s)); err != nil {
-		return err
-	}
-	return nil
-}
-
-func (llm *GGUFModel) Decode(rs io.ReadSeeker) error {
-	// decode key-values
-	for i := 0; uint64(i) < llm.NumKV(); i++ {
-		k, err := llm.readString(rs)
+		// dims is the number of dimensions in the tensor
+		dims, err := readGGUF[uint32](llm, rs)
 		if err != nil {
 			return err
 		}
 
-		vtype := llm.readU32(rs)
-
-		var v any
-		switch vtype {
-		case GGUFTypeUint8:
-			v = llm.readU8(rs)
-		case GGUFTypeInt8:
-			v = llm.readI8(rs)
-		case GGUFTypeUint16:
-			v = llm.readU16(rs)
-		case GGUFTypeInt16:
-			v = llm.readI16(rs)
-		case GGUFTypeUint32:
-			v = llm.readU32(rs)
-		case GGUFTypeInt32:
-			v = llm.readI32(rs)
-		case GGUFTypeUint64:
-			v = llm.readU64(rs)
-		case GGUFTypeInt64:
-			v = llm.readI64(rs)
-		case GGUFTypeFloat32:
-			v = llm.readF32(rs)
-		case GGUFTypeFloat64:
-			v = llm.readF64(rs)
-		case GGUFTypeBool:
-			v = llm.readBool(rs)
-		case GGUFTypeString:
-			s, err := llm.readString(rs)
-			if err != nil {
-				return err
-			}
-
-			v = s
-		case GGUFTypeArray:
-			a, err := llm.readArray(rs)
+		shape := [4]uint64{1, 1, 1, 1}
+		for i := 0; uint32(i) < dims; i++ {
+			shape[i], err = readGGUF[uint64](llm, rs)
 			if err != nil {
 				return err
 			}
-
-			v = a
-		default:
-			return fmt.Errorf("invalid type: %d", vtype)
 		}
 
-		llm.KV[k] = v
-	}
-
-	// decode tensors
-	for i := 0; uint64(i) < llm.NumTensor(); i++ {
-		name, err := llm.readString(rs)
+		kind, err := readGGUF[uint32](llm, rs)
 		if err != nil {
 			return err
 		}
 
-		// dims is the number of dimensions in the tensor
-		dims := llm.readU32(rs)
-
-		shape := [4]uint64{1, 1, 1, 1}
-		for i := 0; uint32(i) < dims; i++ {
-			shape[i] = llm.readU64(rs)
+		offset, err := readGGUF[uint64](llm, rs)
+		if err != nil {
+			return err
 		}
 
 		tensor := Tensor{
 			Name:   name,
-			Kind:   llm.readU32(rs),
-			Offset: llm.readU64(rs),
+			Kind:   kind,
+			Offset: offset,
 			Shape:  shape[:],
 		}
 
 		llm.Tensors = append(llm.Tensors, tensor)
-		llm.parameters += tensor.Parameters()
+		llm.parameters += tensor.parameters()
 	}
 
 	alignment, ok := llm.KV["general.alignment"].(uint32)
@@ -725,12 +257,13 @@ func (llm *GGUFModel) Decode(rs io.ReadSeeker) error {
 		return err
 	}
 
-	if _, err := rs.Seek(int64(alignment)-offset%int64(alignment), io.SeekCurrent); err != nil {
+	padding := llm.padding(offset, int64(alignment))
+	if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
 		return err
 	}
 
 	for _, tensor := range llm.Tensors {
-		padded := (int64(tensor.Size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
+		padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
 		if _, err := rs.Seek(padded, io.SeekCurrent); err != nil {
 			return err
 		}
@@ -739,7 +272,7 @@ func (llm *GGUFModel) Decode(rs io.ReadSeeker) error {
 	return nil
 }
 
-func (llm *GGUFModel) NumLayers() uint32 {
+func (llm *gguf) NumLayers() uint32 {
 	value, exists := llm.KV[fmt.Sprintf("%s.block_count", llm.ModelFamily())]
 	if !exists {
 		return 0
@@ -748,7 +281,7 @@ func (llm *GGUFModel) NumLayers() uint32 {
 	return value.(uint32)
 }
 
-func (llm *GGUFModel) NumHead() uint32 {
+func (llm *gguf) NumHead() uint32 {
 	value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count", llm.ModelFamily())]
 	if !exists {
 		return 0
@@ -757,7 +290,7 @@ func (llm *GGUFModel) NumHead() uint32 {
 	return value.(uint32)
 }
 
-func (llm *GGUFModel) NumEmbed() uint32 {
+func (llm *gguf) NumEmbed() uint32 {
 	value, exists := llm.KV[fmt.Sprintf("%s.embedding_length", llm.ModelFamily())]
 	if !exists {
 		return 0
@@ -766,7 +299,7 @@ func (llm *GGUFModel) NumEmbed() uint32 {
 	return value.(uint32)
 }
 
-func (llm *GGUFModel) NumHeadKv() uint32 {
+func (llm *gguf) NumHeadKv() uint32 {
 	value, exists := llm.KV[fmt.Sprintf("%s.attention.head_count_kv", llm.ModelFamily())]
 	if !exists {
 		return 0
@@ -775,7 +308,7 @@ func (llm *GGUFModel) NumHeadKv() uint32 {
 	return value.(uint32)
 }
 
-func (llm *GGUFModel) NumCtx() uint32 {
+func (llm *gguf) NumCtx() uint32 {
 	value, exists := llm.KV[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
 	if !exists {
 		return 0
@@ -784,7 +317,7 @@ func (llm *GGUFModel) NumCtx() uint32 {
 	return value.(uint32)
 }
 
-func (llm *GGUFModel) NumGQA() uint32 {
+func (llm *gguf) NumGQA() uint32 {
 	numHeadKv := llm.NumHeadKv()
 	if numHeadKv == 0 {
 		return 0
@@ -793,78 +326,28 @@ func (llm *GGUFModel) NumGQA() uint32 {
 	return llm.NumHead() / numHeadKv
 }
 
-func (llm GGUFModel) readU8(r io.Reader) uint8 {
-	var u8 uint8
-	binary.Read(r, llm.ByteOrder, &u8)
-	return u8
-}
-
-func (llm GGUFModel) readI8(r io.Reader) int8 {
-	var i8 int8
-	binary.Read(r, llm.ByteOrder, &i8)
-	return i8
-}
-
-func (llm GGUFModel) readU16(r io.Reader) uint16 {
-	var u16 uint16
-	binary.Read(r, llm.ByteOrder, &u16)
-	return u16
-}
-
-func (llm GGUFModel) readI16(r io.Reader) int16 {
-	var i16 int16
-	binary.Read(r, llm.ByteOrder, &i16)
-	return i16
-}
-
-func (llm GGUFModel) readU32(r io.Reader) uint32 {
-	var u32 uint32
-	binary.Read(r, llm.ByteOrder, &u32)
-	return u32
-}
-
-func (llm GGUFModel) readI32(r io.Reader) int32 {
-	var i32 int32
-	binary.Read(r, llm.ByteOrder, &i32)
-	return i32
+// readGGUF decodes one value of type T from r with binary.Read, using the
+// container's byte order. The decoded value is returned together with any
+// error from binary.Read, so T must be a type binary.Read accepts.
+func readGGUF[T any](llm *gguf, r io.Reader) (T, error) {
+	var t T
+	err := binary.Read(r, llm.ByteOrder, &t)
+	return t, err
 }
 
-func (llm GGUFModel) readU64(r io.Reader) uint64 {
-	var u64 uint64
-	binary.Read(r, llm.ByteOrder, &u64)
-	return u64
-}
-
-func (llm GGUFModel) readI64(r io.Reader) int64 {
-	var i64 int64
-	binary.Read(r, llm.ByteOrder, &i64)
-	return i64
-}
-
-func (llm GGUFModel) readF32(r io.Reader) float32 {
-	var f32 float32
-	binary.Read(r, llm.ByteOrder, &f32)
-	return f32
-}
-
-func (llm GGUFModel) readF64(r io.Reader) float64 {
-	var f64 float64
-	binary.Read(r, llm.ByteOrder, &f64)
-	return f64
-}
+// writeGGUF writes the type tag t followed by the value v to w. Both are
+// encoded with binary.Write in the container's byte order; the first write
+// error is returned.
+func writeGGUF[V any](llm *gguf, w io.Writer, t uint32, v V) error {
+	if err := binary.Write(w, llm.ByteOrder, t); err != nil {
+		return err
+	}
 
-func (llm GGUFModel) readBool(r io.Reader) bool {
-	var b bool
-	binary.Read(r, llm.ByteOrder, &b)
-	return b
+	return binary.Write(w, llm.ByteOrder, v)
 }
 
-func (llm GGUFModel) readStringV1(r io.Reader) (string, error) {
-	var nameLength uint32
-	binary.Read(r, llm.ByteOrder, &nameLength)
+func readGGUFV1String(llm *gguf, r io.Reader) (string, error) {
+	var length uint64
+	if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
+		return "", err
+	}
 
 	var b bytes.Buffer
-	if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
+	if _, err := io.CopyN(&b, r, int64(length)); err != nil {
 		return "", err
 	}
 
@@ -874,102 +357,347 @@ func (llm GGUFModel) readStringV1(r io.Reader) (string, error) {
 	return b.String(), nil
 }
 
-func (llm GGUFModel) readString(r io.Reader) (string, error) {
+// readGGUFString reads a length-prefixed string from r. Version 1 containers
+// are handed to readGGUFV1String; otherwise the length is read as a uint64
+// in the container's byte order and that many bytes are copied out of r.
+func readGGUFString(llm *gguf, r io.Reader) (string, error) {
 	if llm.Version == 1 {
-		return llm.readStringV1(r)
+		return readGGUFV1String(llm, r)
 	}
 
-	var nameLength uint64
-	binary.Read(r, llm.ByteOrder, &nameLength)
+	var length uint64
+	if err := binary.Read(r, llm.ByteOrder, &length); err != nil {
+		return "", err
+	}
 
 	var b bytes.Buffer
+	// NOTE(review): length comes from the input and is converted to int64
+	// without a bounds check.
-	if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
+	if _, err := io.CopyN(&b, r, int64(length)); err != nil {
 		return "", err
 	}
 
 	return b.String(), nil
 }
 
-func (llm *GGUFModel) readArrayV1(r io.Reader) (arr []any, err error) {
-	atype := llm.readU32(r)
-	n := llm.readU32(r)
+// writeGGUFString writes s to w as a string value: the ggufTypeString tag,
+// the byte length of s as a uint64, and then the bytes of s. The tag and
+// length use the container's byte order.
+func writeGGUFString(llm *gguf, w io.Writer, s string) error {
+	if err := binary.Write(w, llm.ByteOrder, ggufTypeString); err != nil {
+		return err
+	}
 
-	for i := 0; uint32(i) < n; i++ {
-		switch atype {
-		case GGUFTypeUint8:
-			arr = append(arr, llm.readU8(r))
-		case GGUFTypeInt8:
-			arr = append(arr, llm.readI8(r))
-		case GGUFTypeUint16:
-			arr = append(arr, llm.readU16(r))
-		case GGUFTypeInt16:
-			arr = append(arr, llm.readI16(r))
-		case GGUFTypeUint32:
-			arr = append(arr, llm.readU32(r))
-		case GGUFTypeInt32:
-			arr = append(arr, llm.readI32(r))
-		case GGUFTypeFloat32:
-			arr = append(arr, llm.readF32(r))
-		case GGUFTypeBool:
-			arr = append(arr, llm.readBool(r))
-		case GGUFTypeString:
-			s, err := llm.readStringV1(r)
-			if err != nil {
-				return nil, err
-			}
+	if err := binary.Write(w, llm.ByteOrder, uint64(len(s))); err != nil {
+		return err
+	}
+
+	_, err := io.Copy(w, strings.NewReader(s))
+	return err
+}
+
+// readGGUFV1Array reads an array in the version 1 layout: a uint32 element
+// type, a uint32 element count, and then that many elements of the given
+// type. String elements are read with readGGUFV1String. An unrecognized
+// element type or any read failure returns a nil slice and an error.
+func readGGUFV1Array(llm *gguf, r io.Reader) (a []any, err error) {
+	t, err := readGGUF[uint32](llm, r)
+	if err != nil {
+		return nil, err
+	}
+
+	n, err := readGGUF[uint32](llm, r)
+	if err != nil {
+		return nil, err
+	}
 
-			arr = append(arr, s)
+	for i := 0; uint32(i) < n; i++ {
+		var e any
+		switch t {
+		case ggufTypeUint8:
+			e, err = readGGUF[uint8](llm, r)
+		case ggufTypeInt8:
+			e, err = readGGUF[int8](llm, r)
+		case ggufTypeUint16:
+			e, err = readGGUF[uint16](llm, r)
+		case ggufTypeInt16:
+			e, err = readGGUF[int16](llm, r)
+		case ggufTypeUint32:
+			e, err = readGGUF[uint32](llm, r)
+		case ggufTypeInt32:
+			e, err = readGGUF[int32](llm, r)
+		case ggufTypeUint64:
+			e, err = readGGUF[uint64](llm, r)
+		case ggufTypeInt64:
+			e, err = readGGUF[int64](llm, r)
+		case ggufTypeFloat32:
+			e, err = readGGUF[float32](llm, r)
+		case ggufTypeFloat64:
+			e, err = readGGUF[float64](llm, r)
+		case ggufTypeBool:
+			e, err = readGGUF[bool](llm, r)
+		case ggufTypeString:
+			e, err = readGGUFV1String(llm, r)
 		default:
-			return nil, fmt.Errorf("invalid array type: %d", atype)
+			return nil, fmt.Errorf("invalid array type: %d", t)
 		}
+		if err != nil {
+			return nil, err
+		}
+
+		a = append(a, e)
 	}
 
 	return
 }
 
-func (llm *GGUFModel) readArray(r io.Reader) (arr []any, err error) {
+// readGGUFArray reads an array value from r. Version 1 containers are handed
+// to readGGUFV1Array; otherwise the layout is a uint32 element type, a
+// uint64 element count, and then that many elements of the given type.
+// String elements are read with readGGUFString. An unrecognized element type
+// or any read failure returns a nil slice and an error.
+func readGGUFArray(llm *gguf, r io.Reader) (a []any, err error) {
 	if llm.Version == 1 {
-		return llm.readArrayV1(r)
+		return readGGUFV1Array(llm, r)
 	}
 
-	atype := llm.readU32(r)
-	n := llm.readU64(r)
+	t, err := readGGUF[uint32](llm, r)
+	if err != nil {
+		return nil, err
+	}
 
-	for i := 0; uint64(i) < n; i++ {
-		switch atype {
-		case GGUFTypeUint8:
-			arr = append(arr, llm.readU8(r))
-		case GGUFTypeInt8:
-			arr = append(arr, llm.readI8(r))
-		case GGUFTypeUint16:
-			arr = append(arr, llm.readU16(r))
-		case GGUFTypeInt16:
-			arr = append(arr, llm.readI16(r))
-		case GGUFTypeUint32:
-			arr = append(arr, llm.readU32(r))
-		case GGUFTypeInt32:
-			arr = append(arr, llm.readI32(r))
-		case GGUFTypeUint64:
-			arr = append(arr, llm.readU64(r))
-		case GGUFTypeInt64:
-			arr = append(arr, llm.readI64(r))
-		case GGUFTypeFloat32:
-			arr = append(arr, llm.readF32(r))
-		case GGUFTypeFloat64:
-			arr = append(arr, llm.readF64(r))
-		case GGUFTypeBool:
-			arr = append(arr, llm.readBool(r))
-		case GGUFTypeString:
-			s, err := llm.readString(r)
-			if err != nil {
-				return nil, err
-			}
+	n, err := readGGUF[uint64](llm, r)
+	if err != nil {
+		return nil, err
+	}
 
-			arr = append(arr, s)
+	for i := 0; uint64(i) < n; i++ {
+		var e any
+		switch t {
+		case ggufTypeUint8:
+			e, err = readGGUF[uint8](llm, r)
+		case ggufTypeInt8:
+			e, err = readGGUF[int8](llm, r)
+		case ggufTypeUint16:
+			e, err = readGGUF[uint16](llm, r)
+		case ggufTypeInt16:
+			e, err = readGGUF[int16](llm, r)
+		case ggufTypeUint32:
+			e, err = readGGUF[uint32](llm, r)
+		case ggufTypeInt32:
+			e, err = readGGUF[int32](llm, r)
+		case ggufTypeUint64:
+			e, err = readGGUF[uint64](llm, r)
+		case ggufTypeInt64:
+			e, err = readGGUF[int64](llm, r)
+		case ggufTypeFloat32:
+			e, err = readGGUF[float32](llm, r)
+		case ggufTypeFloat64:
+			e, err = readGGUF[float64](llm, r)
+		case ggufTypeBool:
+			e, err = readGGUF[bool](llm, r)
+		case ggufTypeString:
+			e, err = readGGUFString(llm, r)
 		default:
-			return nil, fmt.Errorf("invalid array type: %d", atype)
+			return nil, fmt.Errorf("invalid array type: %d", t)
 		}
+		if err != nil {
+			return nil, err
+		}
+
+		a = append(a, e)
 	}
 
 	return
 }
+
+// writeGGUFArray serializes s to w as an array value: the ggufTypeArray tag,
+// the element type tag t, the element count as a uint64, and then every
+// element in order. Each piece is encoded by binary.Write in the container's
+// byte order, and the first write error is returned.
+func writeGGUFArray[S ~[]E, E any](llm *gguf, w io.Writer, t uint32, s S) error {
+	// array tag, element type and element count make up the header
+	header := []any{ggufTypeArray, t, uint64(len(s))}
+	for _, field := range header {
+		if err := binary.Write(w, llm.ByteOrder, field); err != nil {
+			return err
+		}
+	}
+
+	for i := range s {
+		if err := binary.Write(w, llm.ByteOrder, s[i]); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// ggufKVOrder lists, per table name, the metadata keys that Encode writes
+// and the order in which it writes them. Encode always uses the "llama"
+// table, which also carries the gemma.* keys, and skips any listed key that
+// is absent from the KV it is given.
+var ggufKVOrder = map[string][]string{
+	"llama": {
+		"general.architecture",
+		"general.name",
+		"llama.context_length",
+		"llama.embedding_length",
+		"llama.block_count",
+		"llama.feed_forward_length",
+		"llama.rope.dimension_count",
+		"llama.attention.head_count",
+		"llama.attention.head_count_kv",
+		"llama.attention.layer_norm_rms_epsilon",
+		"llama.rope.freq_base",
+		"gemma.context_length",
+		"gemma.embedding_length",
+		"gemma.block_count",
+		"gemma.feed_forward_length",
+		"gemma.attention.head_count",
+		"gemma.attention.head_count_kv",
+		"gemma.attention.layer_norm_rms_epsilon",
+		"gemma.attention.key_length",
+		"gemma.attention.value_length",
+		"general.file_type",
+		"tokenizer.ggml.model",
+		"tokenizer.ggml.tokens",
+		"tokenizer.ggml.scores",
+		"tokenizer.ggml.token_type",
+		"tokenizer.ggml.bos_token_id",
+		"tokenizer.ggml.eos_token_id",
+		"tokenizer.ggml.unknown_token_id",
+		"tokenizer.ggml.padding_token_id",
+		"tokenizer.ggml.add_bos_token",
+		"tokenizer.ggml.add_eos_token",
+		"tokenizer.chat_template",
+	},
+}
+
+// Encode writes kv and tensors to ws as a GGUF file. Only version 3
+// containers are supported; any other llm.Version returns an error before
+// anything is written.
+//
+// The output is laid out as: the "GGUF" magic, the version, the tensor
+// count, the key-value count, the key-value entries, one info record per
+// tensor (name, dimension count, dimensions, kind, offset), zero padding up
+// to a 32-byte boundary, and finally each tensor's data as produced by
+// tensor.WriteTo, each followed by zero padding up to a 32-byte boundary.
+//
+// Only keys listed in ggufKVOrder["llama"] are written, in that order; other
+// keys in kv are skipped. The key-value count stored in the header (and in
+// llm.V3.NumKV) is the number of entries actually written, so the header
+// always agrees with the data that follows it. A value whose type has no
+// encoding here, or a tensor with an empty Shape, yields an error.
+func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
+	// collect the keys that will be emitted up front so the count written
+	// to the header matches the number of entries that follow
+	keys := make([]string, 0, len(kv))
+	for _, k := range ggufKVOrder["llama"] {
+		if _, ok := kv[k]; ok {
+			keys = append(keys, k)
+		}
+	}
+
+	switch llm.Version {
+	case 3:
+		llm.V3.NumTensor = uint64(len(tensors))
+		llm.V3.NumKV = uint64(len(keys))
+	default:
+		return fmt.Errorf("not implemented: ggufv%d", llm.Version)
+	}
+
+	if err := binary.Write(ws, llm.ByteOrder, []byte("GGUF")); err != nil {
+		return err
+	}
+
+	if err := binary.Write(ws, llm.ByteOrder, llm.Version); err != nil {
+		return err
+	}
+
+	if err := binary.Write(ws, llm.ByteOrder, llm.numTensor()); err != nil {
+		return err
+	}
+
+	if err := binary.Write(ws, llm.ByteOrder, llm.numKV()); err != nil {
+		return err
+	}
+
+	for _, k := range keys {
+		// keys are a uint64 length followed by the key bytes, no type tag
+		if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
+			return err
+		}
+
+		if err := binary.Write(ws, llm.ByteOrder, []byte(k)); err != nil {
+			return err
+		}
+
+		var err error
+		switch v := kv[k].(type) {
+		case uint32:
+			err = writeGGUF(llm, ws, ggufTypeUint32, v)
+		case float32:
+			err = writeGGUF(llm, ws, ggufTypeFloat32, v)
+		case bool:
+			err = writeGGUF(llm, ws, ggufTypeBool, v)
+		case string:
+			err = writeGGUFString(llm, ws, v)
+		case []int32:
+			err = writeGGUFArray(llm, ws, ggufTypeInt32, v)
+		case []uint32:
+			err = writeGGUFArray(llm, ws, ggufTypeUint32, v)
+		case []float32:
+			err = writeGGUFArray(llm, ws, ggufTypeFloat32, v)
+		case []string:
+			// string arrays are written inline: each element is a uint64
+			// length followed by its bytes, with no per-element type tag
+			if err := binary.Write(ws, llm.ByteOrder, ggufTypeArray); err != nil {
+				return err
+			}
+
+			if err := binary.Write(ws, llm.ByteOrder, ggufTypeString); err != nil {
+				return err
+			}
+
+			if err := binary.Write(ws, llm.ByteOrder, uint64(len(v))); err != nil {
+				return err
+			}
+
+			for _, e := range v {
+				if err := binary.Write(ws, llm.ByteOrder, uint64(len(e))); err != nil {
+					return err
+				}
+
+				if err := binary.Write(ws, llm.ByteOrder, []byte(e)); err != nil {
+					return err
+				}
+			}
+		default:
+			// the key is already on disk; writing nothing for its value
+			// would silently corrupt the file, so fail instead
+			err = fmt.Errorf("unsupported type %T for key %q", v, k)
+		}
+		if err != nil {
+			return err
+		}
+	}
+
+	for _, tensor := range tensors {
+		if len(tensor.Shape) == 0 {
+			return fmt.Errorf("tensor %q has no shape", tensor.Name)
+		}
+
+		if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
+			return err
+		}
+
+		if err := binary.Write(ws, llm.ByteOrder, []byte(tensor.Name)); err != nil {
+			return err
+		}
+
+		// a second dimension is only recorded when it is present and non-zero
+		dims := 1
+		if len(tensor.Shape) > 1 && tensor.Shape[1] > 0 {
+			dims = 2
+		}
+
+		if err := binary.Write(ws, llm.ByteOrder, uint32(dims)); err != nil {
+			return err
+		}
+
+		// dimensions are written in the reverse of their order in Shape
+		for i := 0; i < dims; i++ {
+			if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
+				return err
+			}
+		}
+
+		if err := binary.Write(ws, llm.ByteOrder, tensor.Kind); err != nil {
+			return err
+		}
+
+		if err := binary.Write(ws, llm.ByteOrder, tensor.Offset); err != nil {
+			return err
+		}
+	}
+
+	offset, err := ws.Seek(0, io.SeekCurrent)
+	if err != nil {
+		return err
+	}
+
+	// padding returns the aligned offset, so the gap to fill is padding-offset
+	padding := llm.padding(offset, 32)
+	if _, err := ws.Write(make([]byte, padding-offset)); err != nil {
+		return err
+	}
+
+	for _, tensor := range tensors {
+		if _, err := tensor.WriteTo(ws); err != nil {
+			return err
+		}
+
+		offset, err := ws.Seek(0, io.SeekCurrent)
+		if err != nil {
+			return err
+		}
+
+		padding := llm.padding(offset, 32)
+		if _, err := ws.Write(make([]byte, padding-offset)); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// padding returns offset rounded up to a multiple of align (for a
+// non-negative offset and a positive align). The result is the aligned
+// offset itself, not the number of pad bytes; Encode subtracts offset from
+// it to get the gap to fill.
+//
+// NOTE(review): Decode passes the result straight to Seek with
+// io.SeekCurrent, which looks like it skips the whole aligned offset rather
+// than just the gap — confirm.
+func (gguf) padding(offset, align int64) int64 {
+	ceiling := offset + align - 1
+	return ceiling - ceiling%align
+}

+ 3 - 1
readline/history.go

@@ -142,7 +142,9 @@ func (h *History) Save() error {
 	for cnt := 0; cnt < h.Size(); cnt++ {
 		v, _ := h.Buf.Get(cnt)
 		line, _ := v.([]rune)
-		buf.WriteString(string(line) + "\n")
+		if _, err := buf.WriteString(string(line) + "\n"); err != nil {
+			return err
+		}
 	}
 	buf.Flush()
 	f.Close()

+ 16 - 8
server/images.go

@@ -321,7 +321,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 
 			pathName := realpath(modelFileDir, c.Args)
 
-			ggufName, err := convertSafetensors(name, pathName)
+			ggufName, err := convertSafetensors(name, pathName, fn)
 			if err != nil {
 				var pathErr *fs.PathError
 				switch {
@@ -336,6 +336,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 
 			if ggufName != "" {
 				pathName = ggufName
+				slog.Debug(fmt.Sprintf("new image layer path: %s", pathName))
 				defer os.RemoveAll(ggufName)
 			}
 
@@ -422,10 +423,13 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 		CREATE:
 			for {
 				fn(api.ProgressResponse{Status: "creating model layer"})
+				if _, err := bin.Seek(offset, io.SeekStart); err != nil {
+					return err
+				}
 
-				bin.Seek(offset, io.SeekStart)
 				ggml, err := llm.DecodeGGML(bin)
 				if err != nil {
+					slog.Error(fmt.Sprintf("error decoding gguf file: %q", err))
 					switch {
 					case errors.Is(err, io.EOF):
 						break CREATE
@@ -621,8 +625,8 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 	return nil
 }
 
-func convertSafetensors(name, fn string) (string, error) {
-	r, err := zip.OpenReader(fn)
+func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
+	r, err := zip.OpenReader(path)
 	if err != nil {
 		return "", err
 	}
@@ -634,6 +638,7 @@ func convertSafetensors(name, fn string) (string, error) {
 	}
 	defer os.RemoveAll(tempDir)
 
+	fn(api.ProgressResponse{Status: "unpacking model metadata"})
 	for _, f := range r.File {
 		fpath := filepath.Join(tempDir, f.Name)
 		outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
@@ -662,6 +667,7 @@ func convertSafetensors(name, fn string) (string, error) {
 
 	SupportedArchs := []string{
 		"MistralForCausalLM",
+		"GemmaForCausalLM",
 	}
 
 	for _, arch := range params.Architectures {
@@ -670,22 +676,24 @@ func convertSafetensors(name, fn string) (string, error) {
 		}
 	}
 
-	t, err := convert.GetSafeTensors(tempDir)
+	fn(api.ProgressResponse{Status: "processing safetensors"})
+	t, err := convert.GetSafeTensors(tempDir, params)
 	if err != nil {
 		return "", err
 	}
 
-	vocab, err := convert.LoadTokens(tempDir)
+	vocab, err := convert.LoadTokens(tempDir, params)
 	if err != nil {
 		return "", err
 	}
 
-	fn, err = convert.WriteGGUF(name, t, params, vocab)
+	fn(api.ProgressResponse{Status: "converting model"})
+	path, err = convert.WriteGGUF(name, t, params, vocab)
 	if err != nil {
 		return "", err
 	}
 
-	return fn, nil
+	return path, nil
 }
 
 func CopyModel(src, dest string) error {

+ 12 - 2
server/routes_test.go

@@ -3,6 +3,7 @@ package server
 import (
 	"bytes"
 	"context"
+	"encoding/binary"
 	"encoding/json"
 	"fmt"
 	"io"
@@ -31,13 +32,22 @@ func Test_Routes(t *testing.T) {
 	}
 
 	createTestFile := func(t *testing.T, name string) string {
+		t.Helper()
+
 		f, err := os.CreateTemp(t.TempDir(), name)
 		assert.Nil(t, err)
 		defer f.Close()
 
-		_, err = f.Write([]byte("GGUF"))
+		err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
+		assert.Nil(t, err)
+
+		err = binary.Write(f, binary.LittleEndian, uint32(3))
 		assert.Nil(t, err)
-		_, err = f.Write([]byte{0x2, 0})
+
+		err = binary.Write(f, binary.LittleEndian, uint64(0))
+		assert.Nil(t, err)
+
+		err = binary.Write(f, binary.LittleEndian, uint64(0))
 		assert.Nil(t, err)
 
 		return f.Name()