mixtral.go 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. package convert
import (
	"fmt"
	"os"
	"regexp"

	"github.com/ollama/ollama/llm"
)
  7. type MixtralModel struct {
  8. ModelData
  9. }
  10. func (m *MixtralModel) GetTensors() error {
  11. t, err := m.Format.GetTensors(m.Path, m.Params)
  12. if err != nil {
  13. return err
  14. }
  15. m.Tensors = []llm.Tensor{}
  16. pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
  17. re, err := regexp.Compile(pattern)
  18. if err != nil {
  19. return err
  20. }
  21. for _, l := range t {
  22. matches := re.FindAllStringSubmatch(l.Name, -1)
  23. if len(matches) > 0 {
  24. wt := l.WriterTo.(safetensorWriterTo)
  25. wt.handler = mistralLayerHandler
  26. l.WriterTo = wt
  27. }
  28. m.Tensors = append(m.Tensors, l)
  29. }
  30. return nil
  31. }
  32. func (m *MixtralModel) LoadVocab() error {
  33. v, err := LoadSentencePieceTokens(m.Path, m.Params)
  34. if err != nil {
  35. return err
  36. }
  37. m.Vocab = v
  38. return nil
  39. }
  40. func (m *MixtralModel) WriteGGUF() (string, error) {
  41. kv := llm.KV{
  42. "general.architecture": "llama",
  43. "general.name": m.Name,
  44. "llama.block_count": uint32(m.Params.HiddenLayers),
  45. "llama.context_length": uint32(m.Params.ContextSize),
  46. "llama.embedding_length": uint32(m.Params.HiddenSize),
  47. "llama.feed_forward_length": uint32(m.Params.IntermediateSize),
  48. "llama.attention.head_count": uint32(m.Params.AttentionHeads),
  49. "llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
  50. "llama.rope.freq_base": float32(m.Params.RopeFrequencyBase),
  51. "llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
  52. "llama.expert_count": uint32(m.Params.Experts),
  53. "llama.expert_used_count": uint32(m.Params.ExpertsUsed),
  54. "llama.vocab_size": uint32(len(m.Vocab.Tokens)),
  55. "llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
  56. "general.file_type": uint32(1),
  57. "tokenizer.ggml.model": "llama",
  58. "tokenizer.ggml.tokens": m.Vocab.Tokens,
  59. "tokenizer.ggml.scores": m.Vocab.Scores,
  60. "tokenizer.ggml.token_type": m.Vocab.Types,
  61. "tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
  62. "tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
  63. "tokenizer.ggml.unknown_token_id": uint32(0),
  64. "tokenizer.ggml.add_bos_token": true,
  65. "tokenizer.ggml.add_eos_token": false,
  66. }
  67. f, err := os.CreateTemp("", "ollama-gguf")
  68. if err != nil {
  69. return "", err
  70. }
  71. defer f.Close()
  72. mod := llm.NewGGUFV3(m.Params.ByteOrder)
  73. if err := mod.Encode(f, kv, m.Tensors); err != nil {
  74. return "", err
  75. }
  76. return f.Name(), nil
  77. }