ggml.go 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. package llm
  2. import (
  3. "encoding/binary"
  4. "errors"
  5. "fmt"
  6. "io"
  7. )
  type (
  	// ModelFamily is a string tag naming a family of models. DecodeGGML
  	// receives one as its hint argument.
  	ModelFamily string

  	// ModelType is a numeric code for a model size class. String maps the
  	// known codes to short labels such as "7B".
  	ModelType uint32
  )

  // Known ModelType codes. The numeric values are fixed; String returns the
  // label that matches each constant's suffix.
  const (
  	ModelType3B  ModelType = 26
  	ModelType7B  ModelType = 32
  	ModelType13B ModelType = 40
  	ModelType30B ModelType = 60
  	ModelType65B ModelType = 80
  )

  // String returns the size label for mt ("3B", "7B", "13B", "30B" or "65B"),
  // or "Unknown" when mt is not one of the declared constants.
  func (mt ModelType) String() string {
  	// Start from the fallback so only recognised codes overwrite it.
  	label := "Unknown"
  	switch mt {
  	case ModelType3B:
  		label = "3B"
  	case ModelType7B:
  		label = "7B"
  	case ModelType13B:
  		label = "13B"
  	case ModelType30B:
  		label = "30B"
  	case ModelType65B:
  		label = "65B"
  	}
  	return label
  }
  // FileType describes a model file's type. Its only requirement is a
  // String method, i.e. the same method set as fmt.Stringer.
  type FileType interface {
  	fmt.Stringer
  }
  36. type GGML struct {
  37. magic uint32
  38. container
  39. model
  40. }
  41. type model interface {
  42. ModelFamily() ModelFamily
  43. ModelType() ModelType
  44. FileType() FileType
  45. }
  // container abstracts the file container format selected by a file's
  // magic number.
  type container interface {
  	// Name returns the container's short identifier.
  	Name() string

  	// Decode consumes any container-specific header data from the reader
  	// and returns an error if that data is not acceptable.
  	Decode(io.Reader) error
  }
  // containerGGML is the unversioned "ggml" container. It carries no state
  // and has no header beyond the file magic.
  type containerGGML struct{}

  // Name returns the container's identifier, "ggml".
  func (*containerGGML) Name() string { return "ggml" }

  // Decode reads nothing from the reader and always succeeds.
  func (*containerGGML) Decode(io.Reader) error { return nil }
  // containerGGMF is the versioned "ggmf" container.
  type containerGGMF struct {
  	// version is the container version accepted by the last successful
  	// Decode; it stays zero until then.
  	version uint32
  }

  // Name returns the container's identifier, "ggmf".
  func (c *containerGGMF) Name() string {
  	return "ggmf"
  }

  // Decode reads a little-endian uint32 version from r and records it on c.
  //
  // Only version 1 is accepted; any other value yields an "invalid version"
  // error and leaves c unchanged. A failed or short read is returned wrapped,
  // so errors.Is still matches io.EOF / io.ErrUnexpectedEOF, rather than
  // being misreported as an invalid version 0.
  func (c *containerGGMF) Decode(r io.Reader) error {
  	var version uint32
  	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
  		return fmt.Errorf("reading ggmf version: %w", err)
  	}

  	if version != 1 {
  		return errors.New("invalid version")
  	}

  	c.version = version
  	return nil
  }
  // containerGGJT is the versioned "ggjt" container.
  type containerGGJT struct {
  	// version is the container version accepted by the last successful
  	// Decode; it stays zero until then.
  	version uint32
  }

  // Name returns the container's identifier, "ggjt".
  func (c *containerGGJT) Name() string {
  	return "ggjt"
  }

  // Decode reads a little-endian uint32 version from r and records it on c.
  //
  // Versions 1, 2 and 3 are accepted; any other value yields an
  // "invalid version" error and leaves c unchanged. A failed or short read is
  // returned wrapped, so errors.Is still matches io.EOF /
  // io.ErrUnexpectedEOF, rather than being misreported as an invalid
  // version 0.
  func (c *containerGGJT) Decode(r io.Reader) error {
  	var version uint32
  	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
  		return fmt.Errorf("reading ggjt version: %w", err)
  	}

  	switch version {
  	case 1, 2, 3:
  		// supported versions
  	default:
  		return errors.New("invalid version")
  	}

  	c.version = version
  	return nil
  }
  // containerLORA is the versioned "ggla" container used for LoRA adapter
  // files.
  type containerLORA struct {
  	// version is the container version accepted by the last successful
  	// Decode; it stays zero until then.
  	version uint32
  }

  // Name returns the container's identifier, "ggla".
  func (c *containerLORA) Name() string {
  	return "ggla"
  }

  // Decode reads a little-endian uint32 version from r and records it on c.
  //
  // Only version 1 is accepted; any other value yields an "invalid version"
  // error and leaves c unchanged. A failed or short read is returned wrapped,
  // so errors.Is still matches io.EOF / io.ErrUnexpectedEOF, rather than
  // being misreported as an invalid version 0.
  func (c *containerLORA) Decode(r io.Reader) error {
  	var version uint32
  	if err := binary.Read(r, binary.LittleEndian, &version); err != nil {
  		return fmt.Errorf("reading ggla version: %w", err)
  	}

  	if version != 1 {
  		return errors.New("invalid version")
  	}

  	c.version = version
  	return nil
  }
  // File magic numbers. Each value is the container's four-character name
  // packed most-significant byte first into 32 bits; for example 0x67676d6c
  // is the ASCII bytes 'g', 'g', 'm', 'l'.
  const (
  	// FILE_MAGIC_GGML is the magic constant for `ggml` files (unversioned).
  	FILE_MAGIC_GGML = 0x67676d6c

  	// FILE_MAGIC_GGMF is the magic constant for `ggml` files (versioned, ggmf).
  	FILE_MAGIC_GGMF = 0x67676d66

  	// FILE_MAGIC_GGJT is the magic constant for `ggml` files (versioned, ggjt).
  	FILE_MAGIC_GGJT = 0x67676a74

  	// FILE_MAGIC_GGLA is the magic constant for `ggla` files (LoRA adapter).
  	FILE_MAGIC_GGLA = 0x67676c61
  )
  119. func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
  120. var ggml GGML
  121. binary.Read(r, binary.LittleEndian, &ggml.magic)
  122. switch ggml.magic {
  123. case FILE_MAGIC_GGML:
  124. ggml.container = &containerGGML{}
  125. case FILE_MAGIC_GGMF:
  126. ggml.container = &containerGGMF{}
  127. case FILE_MAGIC_GGJT:
  128. ggml.container = &containerGGJT{}
  129. case FILE_MAGIC_GGLA:
  130. ggml.container = &containerLORA{}
  131. default:
  132. return nil, errors.New("invalid file magic")
  133. }
  134. if err := ggml.Decode(r); err != nil {
  135. return nil, err
  136. }
  137. // different model types may have different layouts for hyperparameters
  138. switch hint {
  139. case ModelFamilyLlama:
  140. var llama llamaModel
  141. binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
  142. ggml.model = &llama
  143. // TODO: sanity check hyperparameters
  144. default:
  145. return nil, fmt.Errorf("unsupported model type: %s", hint)
  146. }
  147. // final model type
  148. return &ggml, nil
  149. }