convert_mixtral.go

package convert

import (
	"fmt"
	"io"
	"slices"
	"strings"

	"github.com/ollama/ollama/llm"
)
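
// mixtral extends the llama converter with the mixture-of-experts (MoE)
// fields found in a Mixtral config.json.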
type mixtral struct {
	llama

	NumLocalExperts    uint32 `json:"num_local_experts"`
	NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
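
// KV reuses the llama key-value metadata and, when the config provides them,
// records how many experts exist and how many are routed per token.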
func (p *mixtral) KV(t *Tokenizer) llm.KV {
	kv := p.llama.KV(t)

	if p.NumLocalExperts > 0 {
		kv["llama.expert_count"] = p.NumLocalExperts
	}

	if p.NumExpertsPerToken > 0 {
		kv["llama.expert_used_count"] = p.NumExpertsPerToken
	}

	return kv
}
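
// Tensors merges the per-expert feed-forward weights into single stacked
// tensors before delegating everything else to the llama converter. For
// example, the namer below rewrites
//
//	model.layers.0.block_sparse_moe.experts.3.w1.weight
//
// to
//
//	blk.0.ffn_gate_exps.weight
//
// so all experts of one layer and projection collect under one name.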
func (p *mixtral) Tensors(ts []Tensor) []llm.Tensor {
	oldnew := []string{
		"model.layers", "blk",
		"w1", "ffn_gate_exps",
		"w2", "ffn_down_exps",
		"w3", "ffn_up_exps",
	}

	for i := range p.NumLocalExperts {
		oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
	}

	// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
	namer := strings.NewReplacer(oldnew...)
	experts := make(map[string]experts)

	// merge experts into a single tensor while removing them from ts
	ts = slices.DeleteFunc(ts, func(t Tensor) bool {
		if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
			return false
		}

		name := namer.Replace(t.Name())
		experts[name] = append(experts[name], t)
		return true
	})

	var out []llm.Tensor
	for n, e := range experts {
		// TODO(mxyng): sanity check experts
		out = append(out, llm.Tensor{
			Name: n,
			Kind: e[0].Kind(),
			// prepend the expert count so the merged shape is [num_experts, ...original shape]
			Shape:    append([]uint64{uint64(len(e))}, e[0].Shape()...),
			WriterTo: e,
		})
	}

	return append(out, p.llama.Tensors(ts)...)
}
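
// Replacements extends the llama rename table with the MoE router gate.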
func (p *mixtral) Replacements() []string {
	return append(
		p.llama.Replacements(),
		"block_sparse_moe.gate", "ffn_gate_inp",
	)
}
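
// experts collects the source tensors that make up one merged expert tensor.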
type experts []Tensor
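
// WriteTo writes each expert tensor back to back, which produces the same
// byte layout as stacking the experts along a new leading axis.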
func (e experts) WriteTo(w io.Writer) (int64, error) {
	// TODO(mxyng): experts _should_ already be numerically sorted by expert index, but this should verify that
	var written int64
	for _, t := range e {
		// the canonical merged experts tensor stacks all experts along a new, 0 axis,
		// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers;
		// writing each expert tensor in sequence accomplishes the same thing without the copies
		n, err := t.WriteTo(w)
		if err != nil {
			return written, err
		}

		written += n
	}

	return written, nil
}