// MIT License
// Copyright (c) 2023 go-skynet authors
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.
//go:generate cmake -S . -B build
//go:generate cmake --build build
package llama
// #cgo LDFLAGS: -Lbuild -lbinding -lllama -lggml_static -lstdc++
// #cgo darwin LDFLAGS: -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
// #cgo darwin CXXFLAGS: -std=c++11
// #include <stdlib.h>
// #include "binding/binding.h"
import "C"

import (
	"fmt"
	"os"
	"strings"
	"sync"
	"unsafe"
)
// LLama wraps a llama.cpp model loaded through the C binding layer.
// Instances are created with New and must be released with Free.
type LLama struct {
	state       unsafe.Pointer // opaque handle returned by C.load_model
	embeddings  bool           // true when the model was loaded with embeddings enabled
	contextSize int            // context size the model was loaded with
}
  38. func New(model string, opts ...ModelOption) (*LLama, error) {
  39. mo := NewModelOptions(opts...)
  40. modelPath := C.CString(model)
  41. result := C.load_model(modelPath, C.int(mo.ContextSize), C.int(mo.Seed), C.bool(mo.F16Memory), C.bool(mo.MLock), C.bool(mo.Embeddings), C.bool(mo.MMap), C.bool(mo.LowVRAM), C.bool(mo.VocabOnly), C.int(mo.NGPULayers), C.int(mo.NBatch), C.CString(mo.MainGPU), C.CString(mo.TensorSplit), C.bool(mo.NUMA))
  42. if result == nil {
  43. return nil, fmt.Errorf("failed loading model")
  44. }
  45. ll := &LLama{state: result, contextSize: mo.ContextSize, embeddings: mo.Embeddings}
  46. return ll, nil
  47. }
// Free releases the C-side model state. The LLama must not be used after
// calling Free.
func (l *LLama) Free() {
	C.llama_binding_free_model(l.state)
}
  51. func (l *LLama) LoadState(state string) error {
  52. d := C.CString(state)
  53. w := C.CString("rb")
  54. result := C.load_state(l.state, d, w)
  55. if result != 0 {
  56. return fmt.Errorf("error while loading state")
  57. }
  58. return nil
  59. }
  60. func (l *LLama) SaveState(dst string) error {
  61. d := C.CString(dst)
  62. w := C.CString("wb")
  63. C.save_state(l.state, d, w)
  64. _, err := os.Stat(dst)
  65. return err
  66. }
  67. // Token Embeddings
  68. func (l *LLama) TokenEmbeddings(tokens []int, opts ...PredictOption) ([]float32, error) {
  69. if !l.embeddings {
  70. return []float32{}, fmt.Errorf("model loaded without embeddings")
  71. }
  72. po := NewPredictOptions(opts...)
  73. outSize := po.Tokens
  74. if po.Tokens == 0 {
  75. outSize = 9999999
  76. }
  77. floats := make([]float32, outSize)
  78. myArray := (*C.int)(C.malloc(C.size_t(len(tokens)) * C.sizeof_int))
  79. // Copy the values from the Go slice to the C array
  80. for i, v := range tokens {
  81. (*[1<<31 - 1]int32)(unsafe.Pointer(myArray))[i] = int32(v)
  82. }
  83. params := C.llama_allocate_params(C.CString(""), C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
  84. C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
  85. C.bool(po.IgnoreEOS), C.bool(po.F16KV),
  86. C.int(po.Batch), C.int(po.NKeep), nil, C.int(0),
  87. C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
  88. C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
  89. C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
  90. C.CString(po.MainGPU), C.CString(po.TensorSplit),
  91. C.bool(po.PromptCacheRO),
  92. )
  93. ret := C.get_token_embeddings(params, l.state, myArray, C.int(len(tokens)), (*C.float)(&floats[0]))
  94. if ret != 0 {
  95. return floats, fmt.Errorf("embedding inference failed")
  96. }
  97. return floats, nil
  98. }
  99. // Embeddings
  100. func (l *LLama) Embeddings(text string, opts ...PredictOption) ([]float32, error) {
  101. if !l.embeddings {
  102. return []float32{}, fmt.Errorf("model loaded without embeddings")
  103. }
  104. po := NewPredictOptions(opts...)
  105. input := C.CString(text)
  106. if po.Tokens == 0 {
  107. po.Tokens = 99999999
  108. }
  109. floats := make([]float32, po.Tokens)
  110. reverseCount := len(po.StopPrompts)
  111. reversePrompt := make([]*C.char, reverseCount)
  112. var pass **C.char
  113. for i, s := range po.StopPrompts {
  114. cs := C.CString(s)
  115. reversePrompt[i] = cs
  116. pass = &reversePrompt[0]
  117. }
  118. params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
  119. C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
  120. C.bool(po.IgnoreEOS), C.bool(po.F16KV),
  121. C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
  122. C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
  123. C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
  124. C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
  125. C.CString(po.MainGPU), C.CString(po.TensorSplit),
  126. C.bool(po.PromptCacheRO),
  127. )
  128. ret := C.get_embeddings(params, l.state, (*C.float)(&floats[0]))
  129. if ret != 0 {
  130. return floats, fmt.Errorf("embedding inference failed")
  131. }
  132. return floats, nil
  133. }
  134. func (l *LLama) Eval(text string, opts ...PredictOption) error {
  135. po := NewPredictOptions(opts...)
  136. input := C.CString(text)
  137. if po.Tokens == 0 {
  138. po.Tokens = 99999999
  139. }
  140. reverseCount := len(po.StopPrompts)
  141. reversePrompt := make([]*C.char, reverseCount)
  142. var pass **C.char
  143. for i, s := range po.StopPrompts {
  144. cs := C.CString(s)
  145. reversePrompt[i] = cs
  146. pass = &reversePrompt[0]
  147. }
  148. params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
  149. C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
  150. C.bool(po.IgnoreEOS), C.bool(po.F16KV),
  151. C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
  152. C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
  153. C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
  154. C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
  155. C.CString(po.MainGPU), C.CString(po.TensorSplit),
  156. C.bool(po.PromptCacheRO),
  157. )
  158. ret := C.eval(params, l.state, input)
  159. if ret != 0 {
  160. return fmt.Errorf("inference failed")
  161. }
  162. C.llama_free_params(params)
  163. return nil
  164. }
  165. func (l *LLama) Predict(text string, opts ...PredictOption) (string, error) {
  166. po := NewPredictOptions(opts...)
  167. if po.TokenCallback != nil {
  168. setCallback(l.state, po.TokenCallback)
  169. }
  170. input := C.CString(text)
  171. if po.Tokens == 0 {
  172. po.Tokens = 99999999
  173. }
  174. out := make([]byte, po.Tokens)
  175. reverseCount := len(po.StopPrompts)
  176. reversePrompt := make([]*C.char, reverseCount)
  177. var pass **C.char
  178. for i, s := range po.StopPrompts {
  179. cs := C.CString(s)
  180. reversePrompt[i] = cs
  181. pass = &reversePrompt[0]
  182. }
  183. params := C.llama_allocate_params(input, C.int(po.Seed), C.int(po.Threads), C.int(po.Tokens), C.int(po.TopK),
  184. C.float(po.TopP), C.float(po.Temperature), C.float(po.Penalty), C.int(po.Repeat),
  185. C.bool(po.IgnoreEOS), C.bool(po.F16KV),
  186. C.int(po.Batch), C.int(po.NKeep), pass, C.int(reverseCount),
  187. C.float(po.TailFreeSamplingZ), C.float(po.TypicalP), C.float(po.FrequencyPenalty), C.float(po.PresencePenalty),
  188. C.int(po.Mirostat), C.float(po.MirostatETA), C.float(po.MirostatTAU), C.bool(po.PenalizeNL), C.CString(po.LogitBias),
  189. C.CString(po.PathPromptCache), C.bool(po.PromptCacheAll), C.bool(po.MLock), C.bool(po.MMap),
  190. C.CString(po.MainGPU), C.CString(po.TensorSplit),
  191. C.bool(po.PromptCacheRO),
  192. )
  193. ret := C.llama_predict(params, l.state, (*C.char)(unsafe.Pointer(&out[0])), C.bool(po.DebugMode))
  194. if ret != 0 {
  195. return "", fmt.Errorf("inference failed")
  196. }
  197. res := C.GoString((*C.char)(unsafe.Pointer(&out[0])))
  198. res = strings.TrimPrefix(res, " ")
  199. res = strings.TrimPrefix(res, text)
  200. res = strings.TrimPrefix(res, "\n")
  201. for _, s := range po.StopPrompts {
  202. res = strings.TrimRight(res, s)
  203. }
  204. C.llama_free_params(params)
  205. if po.TokenCallback != nil {
  206. setCallback(l.state, nil)
  207. }
  208. return res, nil
  209. }
// CGo only allows us to use static calls from C to Go, we can't just dynamically pass in func's.
// This is the next best thing, we register the callbacks in this map and call tokenCallback from
// the C code. We also attach a finalizer to LLama, so it will unregister the callback when the
// garbage collection frees it.

// SetTokenCallback registers a callback for the individual tokens created when running Predict. It
// will be called once for each token. The callback shall return true as long as the model should
// continue predicting the next token. When the callback returns false the predictor will return.
// The tokens are just converted into Go strings, they are not trimmed or otherwise changed. Also
// the tokens may not be valid UTF-8.
// Pass in nil to remove a callback.
//
// It is safe to call this method while a prediction is running.
func (l *LLama) SetTokenCallback(callback func(token string) bool) {
	setCallback(l.state, callback)
}
var (
	// m guards callbacks; tokenCallback may be invoked from C while
	// setCallback runs on another goroutine.
	m sync.Mutex
	// callbacks maps a model's C state pointer to its registered token callback.
	callbacks = map[uintptr]func(string) bool{}
)
  229. //export tokenCallback
  230. func tokenCallback(statePtr unsafe.Pointer, token *C.char) bool {
  231. m.Lock()
  232. defer m.Unlock()
  233. if callback, ok := callbacks[uintptr(statePtr)]; ok {
  234. return callback(C.GoString(token))
  235. }
  236. return true
  237. }
  238. // setCallback can be used to register a token callback for LLama. Pass in a nil callback to
  239. // remove the callback.
  240. func setCallback(statePtr unsafe.Pointer, callback func(string) bool) {
  241. m.Lock()
  242. defer m.Unlock()
  243. if callback == nil {
  244. delete(callbacks, uintptr(statePtr))
  245. } else {
  246. callbacks[uintptr(statePtr)] = callback
  247. }
  248. }