llama.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package convert
  2. import (
  3. "encoding/binary"
  4. "fmt"
  5. "io"
  6. "log/slog"
  7. "os"
  8. "regexp"
  9. "strings"
  10. "github.com/nlpodyssey/gopickle/pytorch"
  11. "github.com/pdevine/tensor"
  12. "github.com/pdevine/tensor/native"
  13. "github.com/x448/float16"
  14. "github.com/ollama/ollama/llm"
  15. )
  16. type LlamaModel struct {
  17. ModelData
  18. }
  19. func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
  20. slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
  21. data := r.storage.(*pytorch.HalfStorage).Data
  22. tData := make([]uint16, len(data))
  23. for cnt, v := range data {
  24. tData[cnt] = uint16(float16.Fromfloat32(v))
  25. }
  26. var err error
  27. var heads uint32
  28. if strings.Contains(r.t.Name, "attn_q") {
  29. heads = uint32(r.params.AttentionHeads)
  30. } else if strings.Contains(r.t.Name, "attn_k") {
  31. heads = uint32(r.params.KeyValHeads)
  32. if heads == 0 {
  33. heads = uint32(r.params.AttentionHeads)
  34. }
  35. } else {
  36. return fmt.Errorf("unknown layer type")
  37. }
  38. slog.Debug(fmt.Sprintf("heads = %d", heads))
  39. tData, err = llamaRepack(tData, int(heads), r.t.Shape)
  40. if err != nil {
  41. return err
  42. }
  43. if err = binary.Write(w, r.bo, tData); err != nil {
  44. return err
  45. }
  46. return nil
  47. }
  48. func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
  49. n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
  50. origShape := n.Shape().Clone()
  51. // reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
  52. if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
  53. return nil, err
  54. }
  55. if err := n.T(0, 2, 1, 3); err != nil {
  56. return nil, err
  57. }
  58. if err := n.Reshape(origShape...); err != nil {
  59. return nil, err
  60. }
  61. if err := n.Transpose(); err != nil {
  62. return nil, err
  63. }
  64. newN, err := native.SelectU16(n, 1)
  65. if err != nil {
  66. return nil, err
  67. }
  68. var fullTensor []uint16
  69. for _, v := range newN {
  70. fullTensor = append(fullTensor, v...)
  71. }
  72. return fullTensor, nil
  73. }
  74. func (m *LlamaModel) GetTensors() error {
  75. t, err := m.Format.GetTensors(m.Path, m.Params)
  76. if err != nil {
  77. return err
  78. }
  79. m.Tensors = []llm.Tensor{}
  80. pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
  81. re, err := regexp.Compile(pattern)
  82. if err != nil {
  83. return err
  84. }
  85. for _, l := range t {
  86. matches := re.FindAllStringSubmatch(l.Name, -1)
  87. if len(matches) > 0 {
  88. slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
  89. wt := l.WriterTo.(torchWriterTo)
  90. wt.handler = llamaLayerHandler
  91. l.WriterTo = wt
  92. }
  93. m.Tensors = append(m.Tensors, l)
  94. }
  95. return nil
  96. }
  97. func (m *LlamaModel) LoadVocab() error {
  98. var v *Vocab
  99. var err error
  100. slog.Debug("loading vocab")
  101. v, err = LoadSentencePieceTokens(m.Path, m.Params)
  102. if err != nil {
  103. return err
  104. }
  105. slog.Debug("vocab loaded")
  106. m.Vocab = v
  107. return nil
  108. }
  109. func (m *LlamaModel) WriteGGUF(ws io.WriteSeeker) error {
  110. kv := llm.KV{
  111. "general.architecture": "llama",
  112. "general.name": m.Name,
  113. "llama.vocab_size": uint32(len(m.Vocab.Tokens)),
  114. "llama.context_length": uint32(m.Params.ContextSize),
  115. "llama.embedding_length": uint32(m.Params.HiddenSize),
  116. "llama.block_count": uint32(m.Params.HiddenLayers),
  117. "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
  118. "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
  119. "llama.attention.head_count": uint32(m.Params.AttentionHeads),
  120. "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  121. "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  122. "general.file_type": uint32(1),
  123. "tokenizer.ggml.model": "llama",
  124. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  125. "tokenizer.ggml.scores": m.Vocab.Scores,
  126. "tokenizer.ggml.token_type": m.Vocab.Types,
  127. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  128. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  129. "tokenizer.ggml.unknown_token_id": uint32(0),
  130. "tokenizer.ggml.add_bos_token": true,
  131. "tokenizer.ggml.add_eos_token": false,
  132. }
  133. f, err := os.CreateTemp("", "ollama-gguf")
  134. if err != nil {
  135. return err
  136. }
  137. defer f.Close()
  138. return llm.NewGGUFV3(m.Params.ByteOrder).Encode(f, kv, m.Tensors)
  139. }