gguf.go 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. package llm
  2. import (
  3. "bytes"
  4. "encoding/binary"
  5. "errors"
  6. "fmt"
  7. "io"
  8. "log"
  9. "path"
  10. "sync"
  11. )
  12. type containerGGUF struct {
  13. Version uint32
  14. V1 struct {
  15. NumTensor uint32
  16. NumKV uint32
  17. }
  18. V2 struct {
  19. NumTensor uint64
  20. NumKV uint64
  21. }
  22. }
  23. func (c *containerGGUF) Name() string {
  24. return "gguf"
  25. }
  26. func (c *containerGGUF) Decode(r io.Reader) (model, error) {
  27. binary.Read(r, binary.LittleEndian, &c.Version)
  28. switch c.Version {
  29. case 1:
  30. binary.Read(r, binary.LittleEndian, &c.V1)
  31. case 2:
  32. binary.Read(r, binary.LittleEndian, &c.V2)
  33. default:
  34. return nil, errors.New("invalid version")
  35. }
  36. model := newGGUFModel(c)
  37. if err := model.Decode(r); err != nil {
  38. return nil, err
  39. }
  40. return model, nil
  41. }
  42. const (
  43. ggufTypeUint8 uint32 = iota
  44. ggufTypeInt8
  45. ggufTypeUint16
  46. ggufTypeInt16
  47. ggufTypeUint32
  48. ggufTypeInt32
  49. ggufTypeFloat32
  50. ggufTypeBool
  51. ggufTypeString
  52. ggufTypeArray
  53. ggufTypeUint64
  54. ggufTypeInt64
  55. ggufTypeFloat64
  56. )
  57. type kv map[string]any
  58. type ggufModel struct {
  59. *containerGGUF
  60. kv
  61. }
  62. func newGGUFModel(container *containerGGUF) *ggufModel {
  63. return &ggufModel{
  64. containerGGUF: container,
  65. kv: make(kv),
  66. }
  67. }
  68. func (llm *ggufModel) NumKV() uint64 {
  69. if llm.Version == 1 {
  70. return uint64(llm.V1.NumKV)
  71. }
  72. return llm.V2.NumKV
  73. }
  74. func (llm *ggufModel) ModelFamily() ModelFamily {
  75. t, ok := llm.kv["general.architecture"].(string)
  76. if ok {
  77. return ModelFamily(t)
  78. }
  79. log.Printf("unknown model family: %T", t)
  80. return ModelFamilyUnknown
  81. }
  82. func (llm *ggufModel) ModelType() ModelType {
  83. switch llm.ModelFamily() {
  84. case ModelFamilyLlama:
  85. blocks, ok := llm.kv["llama.block_count"].(uint32)
  86. if ok {
  87. return ModelType(blocks)
  88. }
  89. }
  90. return ModelType7B
  91. }
  92. func (llm *ggufModel) FileType() FileType {
  93. switch llm.ModelFamily() {
  94. case ModelFamilyLlama:
  95. t, ok := llm.kv["general.file_type"].(uint32)
  96. if ok {
  97. return llamaFileType(t)
  98. }
  99. }
  100. return llamaFileTypeF16
  101. }
  102. func (llm *ggufModel) Decode(r io.Reader) error {
  103. read := llm.readString
  104. if llm.Version == 1 {
  105. read = llm.readStringV1
  106. }
  107. for i := 0; uint64(i) < llm.NumKV(); i++ {
  108. k, err := read(r)
  109. if err != nil {
  110. return err
  111. }
  112. vtype := llm.readU32(r)
  113. var v any
  114. switch vtype {
  115. case ggufTypeUint8:
  116. v = llm.readU8(r)
  117. case ggufTypeInt8:
  118. v = llm.readI8(r)
  119. case ggufTypeUint16:
  120. v = llm.readU16(r)
  121. case ggufTypeInt16:
  122. v = llm.readI16(r)
  123. case ggufTypeUint32:
  124. v = llm.readU32(r)
  125. case ggufTypeInt32:
  126. v = llm.readI32(r)
  127. case ggufTypeUint64:
  128. v = llm.readU64(r)
  129. case ggufTypeInt64:
  130. v = llm.readI64(r)
  131. case ggufTypeFloat32:
  132. v = llm.readF32(r)
  133. case ggufTypeFloat64:
  134. v = llm.readF64(r)
  135. case ggufTypeBool:
  136. v = llm.readBool(r)
  137. case ggufTypeString:
  138. fn := llm.readString
  139. if llm.Version == 1 {
  140. fn = llm.readStringV1
  141. }
  142. s, err := fn(r)
  143. if err != nil {
  144. return err
  145. }
  146. v = s
  147. case ggufTypeArray:
  148. fn := llm.readArray
  149. if llm.Version == 1 {
  150. fn = llm.readArrayV1
  151. }
  152. a, err := fn(r)
  153. if err != nil {
  154. return err
  155. }
  156. v = a
  157. default:
  158. return fmt.Errorf("invalid type: %d", vtype)
  159. }
  160. llm.kv[k] = v
  161. }
  162. return nil
  163. }
  164. func (ggufModel) readU8(r io.Reader) uint8 {
  165. var u8 uint8
  166. binary.Read(r, binary.LittleEndian, &u8)
  167. return u8
  168. }
  169. func (ggufModel) readI8(r io.Reader) int8 {
  170. var i8 int8
  171. binary.Read(r, binary.LittleEndian, &i8)
  172. return i8
  173. }
  174. func (ggufModel) readU16(r io.Reader) uint16 {
  175. var u16 uint16
  176. binary.Read(r, binary.LittleEndian, &u16)
  177. return u16
  178. }
  179. func (ggufModel) readI16(r io.Reader) int16 {
  180. var i16 int16
  181. binary.Read(r, binary.LittleEndian, &i16)
  182. return i16
  183. }
  184. func (ggufModel) readU32(r io.Reader) uint32 {
  185. var u32 uint32
  186. binary.Read(r, binary.LittleEndian, &u32)
  187. return u32
  188. }
  189. func (ggufModel) readI32(r io.Reader) int32 {
  190. var i32 int32
  191. binary.Read(r, binary.LittleEndian, &i32)
  192. return i32
  193. }
  194. func (ggufModel) readU64(r io.Reader) uint64 {
  195. var u64 uint64
  196. binary.Read(r, binary.LittleEndian, &u64)
  197. return u64
  198. }
  199. func (ggufModel) readI64(r io.Reader) int64 {
  200. var i64 int64
  201. binary.Read(r, binary.LittleEndian, &i64)
  202. return i64
  203. }
  204. func (ggufModel) readF32(r io.Reader) float32 {
  205. var f32 float32
  206. binary.Read(r, binary.LittleEndian, &f32)
  207. return f32
  208. }
  209. func (ggufModel) readF64(r io.Reader) float64 {
  210. var f64 float64
  211. binary.Read(r, binary.LittleEndian, &f64)
  212. return f64
  213. }
  214. func (ggufModel) readBool(r io.Reader) bool {
  215. var b bool
  216. binary.Read(r, binary.LittleEndian, &b)
  217. return b
  218. }
  219. func (ggufModel) readStringV1(r io.Reader) (string, error) {
  220. var nameLength uint32
  221. binary.Read(r, binary.LittleEndian, &nameLength)
  222. var b bytes.Buffer
  223. if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
  224. return "", err
  225. }
  226. // gguf v1 strings are null-terminated
  227. b.Truncate(b.Len() - 1)
  228. return b.String(), nil
  229. }
  230. func (llm ggufModel) readString(r io.Reader) (string, error) {
  231. var nameLength uint64
  232. binary.Read(r, binary.LittleEndian, &nameLength)
  233. var b bytes.Buffer
  234. if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
  235. return "", err
  236. }
  237. return b.String(), nil
  238. }
  239. func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
  240. atype := llm.readU32(r)
  241. n := llm.readU32(r)
  242. for i := 0; uint32(i) < n; i++ {
  243. switch atype {
  244. case ggufTypeUint8:
  245. arr = append(arr, llm.readU8(r))
  246. case ggufTypeInt8:
  247. arr = append(arr, llm.readU8(r))
  248. case ggufTypeUint16:
  249. arr = append(arr, llm.readU16(r))
  250. case ggufTypeInt16:
  251. arr = append(arr, llm.readI16(r))
  252. case ggufTypeUint32:
  253. arr = append(arr, llm.readU32(r))
  254. case ggufTypeInt32:
  255. arr = append(arr, llm.readI32(r))
  256. case ggufTypeFloat32:
  257. arr = append(arr, llm.readF32(r))
  258. case ggufTypeBool:
  259. arr = append(arr, llm.readBool(r))
  260. case ggufTypeString:
  261. s, err := llm.readStringV1(r)
  262. if err != nil {
  263. return nil, err
  264. }
  265. arr = append(arr, s)
  266. default:
  267. return nil, fmt.Errorf("invalid array type: %d", atype)
  268. }
  269. }
  270. return
  271. }
  272. func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
  273. atype := llm.readU32(r)
  274. n := llm.readU64(r)
  275. for i := 0; uint64(i) < n; i++ {
  276. switch atype {
  277. case ggufTypeUint8:
  278. arr = append(arr, llm.readU8(r))
  279. case ggufTypeInt8:
  280. arr = append(arr, llm.readU8(r))
  281. case ggufTypeUint16:
  282. arr = append(arr, llm.readU16(r))
  283. case ggufTypeInt16:
  284. arr = append(arr, llm.readI16(r))
  285. case ggufTypeUint32:
  286. arr = append(arr, llm.readU32(r))
  287. case ggufTypeInt32:
  288. arr = append(arr, llm.readI32(r))
  289. case ggufTypeUint64:
  290. arr = append(arr, llm.readU64(r))
  291. case ggufTypeInt64:
  292. arr = append(arr, llm.readI64(r))
  293. case ggufTypeFloat32:
  294. arr = append(arr, llm.readF32(r))
  295. case ggufTypeFloat64:
  296. arr = append(arr, llm.readF64(r))
  297. case ggufTypeBool:
  298. arr = append(arr, llm.readBool(r))
  299. case ggufTypeString:
  300. s, err := llm.readString(r)
  301. if err != nil {
  302. return nil, err
  303. }
  304. arr = append(arr, s)
  305. default:
  306. return nil, fmt.Errorf("invalid array type: %d", atype)
  307. }
  308. }
  309. return
  310. }
  311. var (
  312. ggufGPU = path.Join("llama.cpp", "gguf", "build", "gpu", "bin")
  313. ggufCPU = path.Join("llama.cpp", "gguf", "build", "cpu", "bin")
  314. )
  315. var (
  316. ggufInit sync.Once
  317. ggufRunnerPath string
  318. )
  319. func ggufRunner() ModelRunner {
  320. ggufInit.Do(func() {
  321. ggufRunnerPath = chooseRunner(ggufGPU, ggufCPU)
  322. })
  323. return ModelRunner{Path: ggufRunnerPath}
  324. }