ggml.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package llm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. )
  8. type ModelFamily string
  9. type ModelType uint32
  10. const (
  11. ModelType3B ModelType = 26
  12. ModelType7B ModelType = 32
  13. ModelType13B ModelType = 40
  14. ModelType34B ModelType = 48
  15. ModelType30B ModelType = 60
  16. ModelType65B ModelType = 80
  17. )
  18. func (mt ModelType) String() string {
  19. switch mt {
  20. case ModelType3B:
  21. return "3B"
  22. case ModelType7B:
  23. return "7B"
  24. case ModelType13B:
  25. return "13B"
  26. case ModelType34B:
  27. return "34B"
  28. case ModelType30B:
  29. return "30B"
  30. case ModelType65B:
  31. return "65B"
  32. default:
  33. return "Unknown"
  34. }
  35. }
  36. type FileType interface {
  37. String() string
  38. }
  39. type GGML struct {
  40. magic uint32
  41. container
  42. model
  43. }
  44. type model interface {
  45. ModelFamily() ModelFamily
  46. ModelType() ModelType
  47. FileType() FileType
  48. }
  49. type container interface {
  50. Name() string
  51. Decode(io.Reader) error
  52. }
  53. type containerGGML struct {
  54. }
  55. func (c *containerGGML) Name() string {
  56. return "ggml"
  57. }
  58. func (c *containerGGML) Decode(r io.Reader) error {
  59. return nil
  60. }
  61. type containerGGMF struct {
  62. version uint32
  63. }
  64. func (c *containerGGMF) Name() string {
  65. return "ggmf"
  66. }
  67. func (c *containerGGMF) Decode(r io.Reader) error {
  68. var version uint32
  69. binary.Read(r, binary.LittleEndian, &version)
  70. switch version {
  71. case 1:
  72. default:
  73. return errors.New("invalid version")
  74. }
  75. c.version = version
  76. return nil
  77. }
  78. type containerGGJT struct {
  79. version uint32
  80. }
  81. func (c *containerGGJT) Name() string {
  82. return "ggjt"
  83. }
  84. func (c *containerGGJT) Decode(r io.Reader) error {
  85. var version uint32
  86. binary.Read(r, binary.LittleEndian, &version)
  87. switch version {
  88. case 1, 2, 3:
  89. default:
  90. return errors.New("invalid version")
  91. }
  92. c.version = version
  93. return nil
  94. }
  95. type containerLORA struct {
  96. version uint32
  97. }
  98. func (c *containerLORA) Name() string {
  99. return "ggla"
  100. }
  101. func (c *containerLORA) Decode(r io.Reader) error {
  102. var version uint32
  103. binary.Read(r, binary.LittleEndian, &version)
  104. switch version {
  105. case 1:
  106. default:
  107. return errors.New("invalid version")
  108. }
  109. c.version = version
  110. return nil
  111. }
  112. const (
  113. // / Magic constant for `ggml` files (unversioned).
  114. FILE_MAGIC_GGML = 0x67676d6c
  115. // / Magic constant for `ggml` files (versioned, ggmf).
  116. FILE_MAGIC_GGMF = 0x67676d66
  117. // / Magic constant for `ggml` files (versioned, ggjt).
  118. FILE_MAGIC_GGJT = 0x67676a74
  119. // / Magic constant for `ggla` files (LoRA adapter).
  120. FILE_MAGIC_GGLA = 0x67676C61
  121. )
  122. func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
  123. var ggml GGML
  124. binary.Read(r, binary.LittleEndian, &ggml.magic)
  125. switch ggml.magic {
  126. case FILE_MAGIC_GGML:
  127. ggml.container = &containerGGML{}
  128. case FILE_MAGIC_GGMF:
  129. ggml.container = &containerGGMF{}
  130. case FILE_MAGIC_GGJT:
  131. ggml.container = &containerGGJT{}
  132. case FILE_MAGIC_GGLA:
  133. ggml.container = &containerLORA{}
  134. default:
  135. return nil, errors.New("invalid file magic")
  136. }
  137. if err := ggml.Decode(r); err != nil {
  138. return nil, err
  139. }
  140. // different model types may have different layouts for hyperparameters
  141. switch hint {
  142. case ModelFamilyLlama:
  143. var llama llamaModel
  144. binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
  145. ggml.model = &llama
  146. // TODO: sanity check hyperparameters
  147. default:
  148. return nil, fmt.Errorf("unsupported model type: %s", hint)
  149. }
  150. // final model type
  151. return &ggml, nil
  152. }