convert_mixtral.go

package convert

import (
	"fmt"
	"io"
	"slices"
	"strings"

	"github.com/ollama/ollama/llm"
)
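// mixtral extends the base llama converter with the mixture-of-experts
// fields from the HF config: the total expert count and the number of
// experts routed per token.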
type mixtral struct {
	llama
	NumLocalExperts    uint32 `json:"num_local_experts"`
	NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
}
var _ Converter = (*mixtral)(nil)
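// KV builds the model metadata from the base llama KV, adding the expert
// counts when they are present in the config.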
func (p *mixtral) KV(t *Tokenizer) llm.KV {
	kv := p.llama.KV(t)

	if p.NumLocalExperts > 0 {
		kv["llama.expert_count"] = p.NumLocalExperts
	}

	if p.NumExpertsPerToken > 0 {
		kv["llama.expert_used_count"] = p.NumExpertsPerToken
	}

	return kv
}
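// Tensors merges the per-expert feed-forward weights of each layer into a
// single stacked tensor per weight type, then hands the remaining tensors
// to the base llama converter.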
func (p *mixtral) Tensors(ts []Tensor) []llm.Tensor {
	oldnew := []string{
		"model.layers", "blk",
		"w1", "ffn_gate_exps",
		"w2", "ffn_down_exps",
		"w3", "ffn_up_exps",
	}
	for i := range p.NumLocalExperts {
		oldnew = append(oldnew, fmt.Sprintf(".block_sparse_moe.experts.%d.", i), ".")
	}
	// group experts of the same layer (model.layers.%d) and type (w[123]) into a single tensor
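	// e.g. "model.layers.0.block_sparse_moe.experts.3.w1.weight" and its
	// siblings for every other expert index all map to "blk.0.ffn_gate_exps.weight"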
	namer := strings.NewReplacer(oldnew...)
	experts := make(map[string]experts)
	// merge experts into a single tensor while removing them from ts
	ts = slices.DeleteFunc(ts, func(t Tensor) bool {
		if !strings.Contains(t.Name(), ".block_sparse_moe.experts.") {
			return false
		}

		name := namer.Replace(t.Name())
		experts[name] = append(experts[name], t)
		return true
	})
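	// at this point every expert tensor has been pulled out of ts into its
	// group; only the regular (non-expert) tensors remain in ts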
	var out []llm.Tensor
	for n, e := range experts {
		// TODO(mxyng): sanity check experts
		out = append(out, llm.Tensor{
			Name:     n,
			Kind:     e[0].Kind(),
			Shape:    append([]uint64{uint64(len(e))}, e[0].Shape()...),
			WriterTo: e,
		})
	}
	return append(out, p.llama.Tensors(ts)...)
}
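// experts is the group of tensors that make up one merged expert tensor:
// the same layer and weight type, with one entry per expert. It implements
// io.WriterTo so llm.Tensor can write the group as a single stacked tensor.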
type experts []Tensor
func (e experts) WriteTo(w io.Writer) (int64, error) {
	// TODO(mxyng): experts _should_ already be sorted numerically by expert index, but this should verify that
	var written int64
	for _, t := range e {
		// the canonical merged experts tensor stacks all experts along a new 0th axis,
		// e.g. `tensor.Stack(0, e[0], e[1:]...)`, which requires allocating temporary buffers.
		// writing each expert tensor in sequence produces the same bytes without the copies
		n, err := t.WriteTo(w)
		if err != nil {
			return written, err
		}
		written += n
	}

	return written, nil
}
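// Illustration (not part of the converter): for row-major tensor data,
// stacking n same-shape tensors along a new leading axis and writing the
// result out is byte-for-byte the same as writing the tensors back to back.
// A minimal sketch over hypothetical in-memory expert buffers e0, e1, e2:
//
//	var buf bytes.Buffer
//	for _, expert := range [][]byte{e0, e1, e2} { // hypothetical raw tensor bytes
//		buf.Write(expert) // same output as Stack(0, ...), no temporary copies
//	}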