llama.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340
  1. package llama
  2. // #cgo CFLAGS: -std=c11 -DNDEBUG -DLOG_DISABLE_LOGS
  3. // #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
  4. // #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  5. // #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  6. // #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
  7. // #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  8. // #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  9. // #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
  10. // #cgo linux CFLAGS: -D_GNU_SOURCE
  11. // #cgo linux CXXFLAGS: -D_GNU_SOURCE
  12. // #cgo windows LDFLAGS: -lmsvcrt
  13. // #cgo avx CFLAGS: -mavx
  14. // #cgo avx CXXFLAGS: -mavx
  15. // #cgo avx2 CFLAGS: -mavx2 -mfma
  16. // #cgo avx2 CXXFLAGS: -mavx2 -mfma
  17. // #cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  18. // #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  19. // #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  20. // #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  21. // #cgo rocm LDFLAGS: -L${SRCDIR} -lggml-hipblas -lhipblas -lamdhip64 -lrocblas
  22. // #cgo windows,cuda LDFLAGS: -L. -L"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.3/lib/x64" -lggml-cuda -lcuda -lcudart -lcublas -lcublasLt
  23. // #cgo windows,rocm LDFLAGS: -L. -L"C:/Program Files/AMD/ROCm/5.7/lib"
  24. // #cgo linux,cuda LDFLAGS: -L${SRCDIR} -L/usr/local/cuda/lib64 -lggml-cuda -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt
  25. // #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib
  26. // #include <stdlib.h>
  27. // #include "llama.h"
  28. // #include "clip.h"
  29. // #include "llava.h"
  30. // #include "sampling_ext.h"
  31. import "C"
  32. import (
  33. "fmt"
  34. "runtime"
  35. "strings"
  36. "unsafe"
  37. )
// BackendInit initializes the underlying llama.cpp backend via
// llama_backend_init. Call once before using any other binding in this
// package.
func BackendInit() {
	C.llama_backend_init()
}
// PrintSystemInfo returns the system/capability string reported by
// llama_print_system_info (e.g. which SIMD features the build supports).
func PrintSystemInfo() string {
	return C.GoString(C.llama_print_system_info())
}
// ContextParams wraps the C llama_context_params struct used when creating
// a new inference context.
type ContextParams struct {
	c C.struct_llama_context_params
}
  47. func NewContextParams() ContextParams {
  48. params := C.llama_context_default_params()
  49. params.seed = C.uint(1234)
  50. params.n_ctx = C.uint(2048)
  51. params.n_threads = C.uint(runtime.NumCPU())
  52. params.n_threads_batch = params.n_threads
  53. return ContextParams{c: params}
  54. }
// ModelParams wraps the C llama_model_params struct used when loading a
// model from disk.
type ModelParams struct {
	c C.struct_llama_model_params
}
  58. func NewModelParams() ModelParams {
  59. params := C.llama_model_default_params()
  60. params.n_gpu_layers = 999
  61. return ModelParams{c: params}
  62. }
// Context wraps a C llama_context handle for running inference.
type Context struct {
	c *C.struct_llama_context
}
// KvCacheClear clears the context's entire KV cache via llama_kv_cache_clear.
func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}
  69. func (c *Context) Decode(batch Batch) error {
  70. // Positive return values does not mean a fatal error, but rather a warning.
  71. // 0 - success
  72. // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
  73. // < 0 - error
  74. code := int(C.llama_decode(c.c, batch.c))
  75. if code < 0 {
  76. return fmt.Errorf("llama_decode failed with code %d", code)
  77. }
  78. if code > 0 {
  79. return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
  80. }
  81. return nil
  82. }
// Model returns a wrapper around the model associated with this context.
// The returned Model shares the underlying C handle; it is not a copy.
func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}
  86. func (c *Context) GetLogitsIth(i int) []float32 {
  87. return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
  88. }
  89. func (c *Context) SampleTokenGreedy(logits []float32) int {
  90. candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
  91. defer C.free(unsafe.Pointer(candidates))
  92. for i, logit := range logits {
  93. ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
  94. ptr.id = C.int(i)
  95. ptr.logit = C.float(logit)
  96. ptr.p = 0.0
  97. }
  98. return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
  99. data: candidates,
  100. size: C.size_t(len(logits)),
  101. sorted: C.bool(false),
  102. }))
  103. }
// KvCacheSeqRm removes cells for sequence seqId in positions [p0, p1) from
// the KV cache and reports whether the operation succeeded.
// NOTE(review): the exact interpretation of negative p0/p1 (e.g. -1 meaning
// "whole range") comes from the C API — confirm against llama.h.
func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}
  107. func LoadModelFromFile(modelPath string, params ModelParams) *Model {
  108. return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), params.c)}
  109. }
// NewContextWithModel creates a new inference context for model with the
// given context parameters.
func NewContextWithModel(model *Model, params ContextParams) *Context {
	return &Context{c: C.llama_new_context_with_model(model.c, params.c)}
}
// NumVocab returns the model's vocabulary size.
func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}
// TokenIsEog reports whether token is an end-of-generation token (e.g. EOS)
// for this model.
func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}
// Batch wraps a C llama_batch, a buffer of tokens/positions/sequence ids
// submitted to Decode in one call.
type Batch struct {
	c C.struct_llama_batch
}
// NewBatch allocates a batch via llama_batch_init with capacity for nTokens
// tokens, embd embedding size (0 for token-id batches), and maxSeq sequence
// ids per token. The caller owns the batch and must release it with Free.
func NewBatch(nTokens int, embd int, maxSeq int) Batch {
	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
}
// NumTokens returns the number of tokens currently stored in the batch.
func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}
// Add adds a token to the batch with the given position for the given
// sequence ids, and optionally instructs to include logits.
//
// NOTE(review): every unsafe.Slice view below hard-codes a capacity of 512,
// but the batch's real capacity was chosen by the caller of NewBatch. If the
// batch is filled past min(allocated capacity, 512) these writes go out of
// bounds — confirm all callers stay within that bound, or thread the real
// capacity through Batch.
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
	// seq_id is an array of per-token pointers to int arrays; write each
	// sequence id into the current token's array.
	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}
	if logits {
		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
	}
	b.c.n_tokens += 1
}
// Clear resets the batch's token count so its storage can be reused.
// The underlying buffers are not zeroed or freed.
func (b *Batch) Clear() {
	b.c.n_tokens = 0
}
// Free releases the C-side storage allocated by llama_batch_init.
// The batch must not be used after Free.
func (b *Batch) Free() {
	C.llama_batch_free(b.c)
}
// BatchGetOne builds a single-sequence batch that views the given tokens
// starting at position pos0 for sequence seqId.
// NOTE(review): the batch aliases the Go slice's backing array rather than
// copying it — tokens must be non-empty and must stay alive (and unmoved)
// while the batch is in use.
func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
}
// Model wraps a C llama_model handle.
type Model struct {
	c *C.struct_llama_model
}
  154. func (m *Model) TokenToPiece(token int) string {
  155. buf := make([]byte, 12)
  156. C.llama_token_to_piece(
  157. m.c,
  158. C.int32_t(token),
  159. (*C.char)(unsafe.Pointer(&buf[0])),
  160. C.int32_t(12),
  161. C.bool(true),
  162. )
  163. return strings.TrimRight(string(buf), "\x00")
  164. }
  165. func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
  166. cTokens := make([]C.llama_token, maxTokens)
  167. cText := C.CString(text)
  168. defer C.free(unsafe.Pointer(cText))
  169. result := C.llama_tokenize(
  170. m.c,
  171. cText,
  172. C.int32_t(len(text)),
  173. &cTokens[0],
  174. C.int32_t(maxTokens),
  175. C.bool(addSpecial),
  176. C.bool(parseSpecial),
  177. )
  178. if result < 0 {
  179. return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
  180. }
  181. tokens := make([]int, result)
  182. for i := 0; i < int(result); i++ {
  183. tokens[i] = int(cTokens[i])
  184. }
  185. return tokens, nil
  186. }
  187. func Quantize(infile, outfile string, ftype uint32) error {
  188. cinfile := C.CString(infile)
  189. defer C.free(unsafe.Pointer(cinfile))
  190. coutfile := C.CString(outfile)
  191. defer C.free(unsafe.Pointer(coutfile))
  192. params := C.llama_model_quantize_default_params()
  193. params.nthread = -1
  194. params.ftype = ftype
  195. if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
  196. return fmt.Errorf("llama_model_quantize: %d", rc)
  197. }
  198. return nil
  199. }
  200. // llava
// ClipContext wraps a C clip_ctx handle (CLIP vision encoder for llava).
type ClipContext struct {
	c *C.struct_clip_ctx
}
  204. func NewClipContext(modelPath string) *ClipContext {
  205. mp := C.CString(modelPath)
  206. defer C.free(unsafe.Pointer(mp))
  207. cc := C.clip_model_load(mp, 1)
  208. return &ClipContext{c: cc}
  209. }
// LlavaContext wraps a C llava_context handle.
type LlavaContext struct {
	c *C.struct_llava_context
}
// LlavaImageEmbed wraps a C llava_image_embed, an image embedding produced
// by the CLIP encoder for consumption by the language model.
type LlavaImageEmbed struct {
	c *C.struct_llava_image_embed
}
// NewLlavaImageEmbed encodes raw image bytes into an embedding using the
// given CLIP context, with one encoder thread per logical CPU.
// NOTE(review): data must be non-empty (the &data[0] below panics on an
// empty slice) — confirm callers guarantee this.
func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed {
	return &LlavaImageEmbed{c: C.llava_image_embed_make_with_bytes(clipContext.c, C.int(runtime.NumCPU()), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))}
}
// LlavaEvalImageEmbed feeds an image embedding into the llama context in
// chunks of nBatch tokens. nPast is updated in place by the C side to
// reflect the new number of consumed positions.
func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
	C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast)))
}
  222. // sampling
  223. // TODO: this is a temporary wrapper to allow calling C++ code from CGo
// SamplingContext wraps the C++ llama_sampling_context exposed through the
// sampling_ext.h C shim.
type SamplingContext struct {
	c *C.struct_llama_sampling_context
}
// SamplingParams holds the Go-side sampling configuration that is copied
// into a C llama_sampling_cparams in NewSamplingContext.
type SamplingParams struct {
	TopK           int     // top-k filtering cutoff
	TopP           float32 // nucleus (top-p) sampling threshold
	TfsZ           float32 // tail-free sampling parameter
	TypicalP       float32 // locally-typical sampling parameter
	Temp           float32 // softmax temperature
	PenaltyRepeat  float32 // repetition penalty
	PenaltyFreq    float32 // frequency penalty
	PenaltyPresent float32 // presence penalty
	Mirostat       int     // mirostat mode (0 = off)
	MirostatTau    float32 // mirostat target entropy
	MirostatEta    float32 // mirostat learning rate
	PenalizeNl     bool    // whether newline tokens are penalized
	Seed           uint32  // RNG seed
	Grammar        string  // GBNF grammar constraining output (empty = none)
}
  243. func NewSamplingContext(params SamplingParams) *SamplingContext {
  244. var cparams C.struct_llama_sampling_cparams
  245. cparams.top_k = C.int32_t(params.TopK)
  246. cparams.top_p = C.float(params.TopP)
  247. cparams.tfs_z = C.float(params.TfsZ)
  248. cparams.typical_p = C.float(params.TypicalP)
  249. cparams.temp = C.float(params.Temp)
  250. cparams.penalty_repeat = C.float(params.PenaltyRepeat)
  251. cparams.penalty_freq = C.float(params.PenaltyFreq)
  252. cparams.penalty_present = C.float(params.PenaltyFreq)
  253. cparams.mirostat = C.int32_t(params.Mirostat)
  254. cparams.mirostat_tau = C.float(params.MirostatTau)
  255. cparams.mirostat_eta = C.float(params.MirostatEta)
  256. cparams.penalize_nl = C.bool(params.PenalizeNl)
  257. cparams.seed = C.uint32_t(params.Seed)
  258. grammar := C.CString(params.Grammar)
  259. defer C.free(unsafe.Pointer(grammar))
  260. cparams.grammar = grammar
  261. return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
  262. }
// Free releases the C sampling context. The context must not be used after
// Free.
func (s *SamplingContext) Free() {
	C.llama_sampling_cfree(s.c)
}
// Reset clears the sampling context's internal state (e.g. accepted-token
// history) so it can be reused for a new generation.
func (s *SamplingContext) Reset() {
	C.llama_sampling_creset(s.c)
}
  269. func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
  270. // TODO (jmorganca): handle nil for all args
  271. if ctxConfig == nil {
  272. return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
  273. }
  274. return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
  275. }
// Accept records token id as accepted in the sampling context's history;
// applyGrammar additionally advances the grammar state with the token.
func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
	C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
}