mistral.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. package convert
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "os"
  7. "regexp"
  8. "strings"
  9. "github.com/d4l3k/go-bfloat16"
  10. "github.com/pdevine/tensor"
  11. "github.com/pdevine/tensor/native"
  12. "github.com/x448/float16"
  13. "github.com/ollama/ollama/llm"
  14. )
  15. type MistralModel struct {
  16. ModelData
  17. }
  18. func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
  19. layerSize := r.end - r.start
  20. var err error
  21. tData := make([]uint16, layerSize/2)
  22. if err = binary.Read(f, r.bo, tData); err != nil {
  23. return err
  24. }
  25. var heads uint32
  26. if strings.Contains(r.t.Name, "attn_q") {
  27. heads = uint32(r.params.AttentionHeads)
  28. } else if strings.Contains(r.t.Name, "attn_k") {
  29. heads = uint32(r.params.KeyValHeads)
  30. if heads == 0 {
  31. heads = uint32(r.params.AttentionHeads)
  32. }
  33. } else {
  34. return fmt.Errorf("unknown layer type")
  35. }
  36. tData, err = repack(tData, int(heads), r.t.Shape)
  37. if err != nil {
  38. return err
  39. }
  40. var buf []byte
  41. for _, n := range tData {
  42. buf = r.bo.AppendUint16(buf, n)
  43. }
  44. tempBuf := make([]uint16, len(tData))
  45. tDataF32 := bfloat16.DecodeFloat32(buf)
  46. for cnt, v := range tDataF32 {
  47. tDataF16 := float16.Fromfloat32(v)
  48. tempBuf[cnt] = uint16(tDataF16)
  49. }
  50. if err = binary.Write(w, r.bo, tempBuf); err != nil {
  51. return err
  52. }
  53. return nil
  54. }
  55. func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
  56. n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
  57. origShape := n.Shape().Clone()
  58. // reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
  59. if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
  60. return nil, err
  61. }
  62. if err := n.T(0, 2, 1, 3); err != nil {
  63. return nil, err
  64. }
  65. if err := n.Reshape(origShape...); err != nil {
  66. return nil, err
  67. }
  68. if err := n.Transpose(); err != nil {
  69. return nil, err
  70. }
  71. newN, err := native.SelectU16(n, 1)
  72. if err != nil {
  73. return nil, err
  74. }
  75. var fullTensor []uint16
  76. for _, v := range newN {
  77. fullTensor = append(fullTensor, v...)
  78. }
  79. return fullTensor, nil
  80. }
  81. func (m *MistralModel) GetTensors() error {
  82. t, err := GetSafeTensors(m.Path, m.Params)
  83. if err != nil {
  84. return err
  85. }
  86. m.Tensors = []llm.Tensor{}
  87. pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
  88. re, err := regexp.Compile(pattern)
  89. if err != nil {
  90. return err
  91. }
  92. for _, l := range t {
  93. matches := re.FindAllStringSubmatch(l.Name, -1)
  94. if len(matches) > 0 {
  95. wt := l.WriterTo.(safetensorWriterTo)
  96. wt.handler = mistralLayerHandler
  97. l.WriterTo = wt
  98. }
  99. m.Tensors = append(m.Tensors, l)
  100. }
  101. return nil
  102. }
  103. func (m *MistralModel) LoadVocab() error {
  104. v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
  105. if err != nil {
  106. return err
  107. }
  108. m.Vocab = v
  109. return nil
  110. }
  111. func (m *MistralModel) WriteGGUF() (string, error) {
  112. kv := llm.KV{
  113. "general.architecture": "llama",
  114. "general.name": m.Name,
  115. "llama.context_length": uint32(m.Params.ContextSize),
  116. "llama.embedding_length": uint32(m.Params.HiddenSize),
  117. "llama.block_count": uint32(m.Params.HiddenLayers),
  118. "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
  119. "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
  120. "llama.attention.head_count": uint32(m.Params.AttentionHeads),
  121. "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  122. "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  123. "general.file_type": uint32(1),
  124. "tokenizer.ggml.model": "llama",
  125. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  126. "tokenizer.ggml.scores": m.Vocab.Scores,
  127. "tokenizer.ggml.token_type": m.Vocab.Types,
  128. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  129. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  130. "tokenizer.ggml.add_bos_token": true,
  131. "tokenizer.ggml.add_eos_token": false,
  132. "tokenizer.ggml.unknown_token_id": uint32(0),
  133. }
  134. f, err := os.CreateTemp("", "ollama-gguf")
  135. if err != nil {
  136. return "", err
  137. }
  138. defer f.Close()
  139. mod := llm.NewGGUFV3(m.Params.ByteOrder)
  140. if err := mod.Encode(f, kv, m.Tensors); err != nil {
  141. return "", err
  142. }
  143. return f.Name(), nil
  144. }