llama.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513
  1. package llama
  2. /*
  3. #cgo CFLAGS: -O2 -std=c11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
  4. #cgo CXXFLAGS: -O2 -std=c++11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
  5. #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
  6. #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
  7. #cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
  8. #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  9. #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  10. #cgo darwin,amd64 LDFLAGS: -framework Foundation
  11. #cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  12. #cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  13. #cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
  14. #cgo linux CFLAGS: -D_GNU_SOURCE
  15. #cgo linux CXXFLAGS: -D_GNU_SOURCE
  16. #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/Linux/arm64
  17. #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
  18. #cgo windows CFLAGS: -Wno-discarded-qualifiers
  19. #cgo windows LDFLAGS: -lmsvcrt -static-libstdc++ -static-libgcc -static
  20. #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/Windows/arm64
  21. #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/Windows/amd64
  22. #cgo avx CFLAGS: -mavx
  23. #cgo avx CXXFLAGS: -mavx
  24. #cgo avx2 CFLAGS: -mavx2 -mfma -mf16c
  25. #cgo avx2 CXXFLAGS: -mavx2 -mfma -mf16c
  26. #cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1
  27. #cgo cuda CXXFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1
  28. #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1
  29. #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -D__HIP_PLATFORM_AMD__=1 -D__HIP_ROCclr__=1
  30. #cgo rocm LDFLAGS: -L${SRCDIR} -lggml_rocm -lhipblas -lamdhip64 -lrocblas
  31. #cgo cuda_v11 LDFLAGS: -lggml_cuda_v11 -L/usr/local/cuda-11/lib64
  32. #cgo cuda_v12 LDFLAGS: -lggml_cuda_v12 -L/usr/local/cuda-12/lib64
  33. #cgo windows,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt
  34. #cgo windows,rocm LDFLAGS: -lggml_rocm -lhipblas -lamdhip64 -lrocblas
  35. #cgo linux,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt -lresolv
  36. #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib -lpthread -ldl -lrt -lresolv
  37. #include <stdlib.h>
  38. #include "llama.h"
  39. #include "clip.h"
  40. #include "llava.h"
  41. #include "sampling_ext.h"
  42. bool llamaProgressCallback(float progress, void *user_data);
  43. */
  44. import "C"
  45. import (
  46. _ "embed"
  47. "errors"
  48. "fmt"
  49. "runtime"
  50. "runtime/cgo"
  51. "strings"
  52. "unsafe"
  53. )
// CpuFeatures is an exported package-level string that defaults to empty.
// Nothing in this file assigns it.
// NOTE(review): presumably populated at build/link time to describe the CPU
// feature set the binary was compiled for — confirm against the build scripts.
var CpuFeatures = ""
// BackendInit calls llama_backend_init in the underlying C library.
func BackendInit() {
	C.llama_backend_init()
}
  58. func PrintSystemInfo() string {
  59. return C.GoString(C.llama_print_system_info())
  60. }
// ContextParams wraps the C llama_context_params struct so it can be built in
// Go (see NewContextParams) and handed to NewContextWithModel.
type ContextParams struct {
	c C.struct_llama_context_params
}
  64. func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
  65. params := C.llama_context_default_params()
  66. params.n_ctx = C.uint(numCtx)
  67. params.n_batch = C.uint(batchSize)
  68. params.n_seq_max = C.uint(numSeqMax)
  69. params.n_threads = C.int(threads)
  70. params.n_threads_batch = params.n_threads
  71. params.embeddings = C.bool(true)
  72. params.flash_attn = C.bool(flashAttention)
  73. return ContextParams{c: params}
  74. }
// Context wraps a C llama_context pointer.
type Context struct {
	c *C.struct_llama_context
	// numThreads is the n_threads value the context was created with; it is
	// passed to the llava image embedding call.
	numThreads int
}
// KvCacheClear calls llama_kv_cache_clear on the underlying context.
func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}
  82. func (c *Context) Decode(batch *Batch) error {
  83. // Positive return values does not mean a fatal error, but rather a warning.
  84. // 0 - success
  85. // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
  86. // < 0 - error
  87. code := int(C.llama_decode(c.c, batch.c))
  88. if code < 0 {
  89. return fmt.Errorf("llama_decode failed with code %d", code)
  90. }
  91. if code > 0 {
  92. return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
  93. }
  94. return nil
  95. }
  96. func (c *Context) Model() *Model {
  97. return &Model{c: C.llama_get_model(c.c)}
  98. }
  99. func (c *Context) GetLogitsIth(i int) []float32 {
  100. return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
  101. }
  102. func (c *Context) SampleTokenGreedy(logits []float32) int {
  103. candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
  104. defer C.free(unsafe.Pointer(candidates))
  105. for i, logit := range logits {
  106. ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
  107. ptr.id = C.int(i)
  108. ptr.logit = C.float(logit)
  109. ptr.p = 0.0
  110. }
  111. return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
  112. data: candidates,
  113. size: C.size_t(len(logits)),
  114. sorted: C.bool(false),
  115. }))
  116. }
// KvCacheSeqAdd forwards seqId, the position range [p0, p1 arguments as given]
// and delta to llama_kv_cache_seq_add.
// NOTE(review): presumably shifts the positions of the sequence's cached
// entries in that range by delta — confirm range semantics against llama.h.
func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}
  120. func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
  121. return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
  122. }
// KvCacheSeqCp forwards the source and destination sequence ids together with
// p0 and p1 to llama_kv_cache_seq_cp.
func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}
  126. // Get the embeddings for a sequence id
  127. func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
  128. embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
  129. if embeddings == nil {
  130. return nil
  131. }
  132. return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
  133. }
  134. func (c *Context) GetEmbeddingsIth(i int) []float32 {
  135. return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))), c.Model().NEmbd())
  136. }
// ModelParams holds the Go-side options that LoadModelFromFile copies onto the
// C llama_model_params struct.
type ModelParams struct {
	// NumGpuLayers is copied to n_gpu_layers.
	NumGpuLayers int
	// MainGpu is copied to main_gpu.
	MainGpu int
	// UseMmap is copied to use_mmap.
	UseMmap bool
	// UseMlock is copied to use_mlock.
	UseMlock bool
	// TensorSplit, when non-empty, is passed to C as tensor_split (a pointer to
	// its first element).
	TensorSplit []float32
	// Progress, when non-nil, is invoked from the C progress callback with the
	// reported progress value.
	Progress func(float32)
	// VocabOnly is copied to vocab_only.
	VocabOnly bool
}
  146. //export llamaProgressCallback
  147. func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
  148. handle := *(*cgo.Handle)(userData)
  149. callback := handle.Value().(func(float32))
  150. callback(float32(progress))
  151. return true
  152. }
  153. func LoadModelFromFile(modelPath string, params ModelParams) *Model {
  154. cparams := C.llama_model_default_params()
  155. cparams.n_gpu_layers = C.int(params.NumGpuLayers)
  156. cparams.main_gpu = C.int32_t(params.MainGpu)
  157. cparams.use_mmap = C.bool(params.UseMmap)
  158. cparams.use_mlock = C.bool(params.UseMlock)
  159. cparams.vocab_only = C.bool(params.VocabOnly)
  160. if len(params.TensorSplit) > 0 {
  161. tensorSplitData := &params.TensorSplit[0]
  162. var tensorSplitPin runtime.Pinner
  163. tensorSplitPin.Pin(tensorSplitData)
  164. defer tensorSplitPin.Unpin()
  165. cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
  166. }
  167. if params.Progress != nil {
  168. handle := cgo.NewHandle(params.Progress)
  169. defer handle.Delete()
  170. var handlePin runtime.Pinner
  171. handlePin.Pin(&handle)
  172. defer handlePin.Unpin()
  173. cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
  174. cparams.progress_callback_user_data = unsafe.Pointer(&handle)
  175. }
  176. return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
  177. }
// FreeModel passes the model's underlying C pointer to llama_free_model.
func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}
  181. func NewContextWithModel(model *Model, params ContextParams) *Context {
  182. return &Context{
  183. c: C.llama_new_context_with_model(model.c, params.c),
  184. numThreads: int(params.c.n_threads),
  185. }
  186. }
  187. func (m *Model) NumVocab() int {
  188. return int(C.llama_n_vocab(m.c))
  189. }
  190. func (m *Model) TokenIsEog(token int) bool {
  191. return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
  192. }
  193. func (m *Model) AddBOSToken() bool {
  194. return bool(C.llama_add_bos_token(m.c))
  195. }
  196. func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
  197. cLoraPath := C.CString(loraPath)
  198. defer C.free(unsafe.Pointer(cLoraPath))
  199. loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
  200. err := -1
  201. if loraAdapter != nil {
  202. err = int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale)))
  203. }
  204. if err != 0 {
  205. return errors.New("error applying lora from file")
  206. }
  207. return nil
  208. }
// Batch wraps a C llama_batch together with the sizes it was created with.
type Batch struct {
	c C.struct_llama_batch
	// batchSize is the nTokens capacity passed to NewBatch; it bounds the Go
	// slices built over the C arrays in Add.
	batchSize int
	// embedSize is the embed value passed to NewBatch; non-zero means the
	// batch holds embeddings rather than tokens.
	embedSize int
}
  214. // Creates a new batch for either word tokens if embed is 0 or
  215. // image embeddings if embed is specified. Batches cannot contain
  216. // both types at the same time
  217. func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
  218. return &Batch{
  219. c: C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
  220. batchSize: nTokens,
  221. embedSize: embed,
  222. }
  223. }
  224. func (b *Batch) NumTokens() int {
  225. return int(b.c.n_tokens)
  226. }
// IsEmbedding reports whether the batch was created with a non-zero embed
// size, i.e. whether Add stores embeddings instead of tokens.
func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}
  230. // Add adds either a token or an image embedding to the batch depending on the type
  231. // when the batch was initialized. The other argument will be ignored. Adds to the
  232. // batch with the given position for the given sequence ids, and optionally instructs
  233. // to include logits.
  234. func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
  235. if !b.IsEmbedding() {
  236. unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
  237. } else {
  238. copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
  239. }
  240. unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
  241. unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
  242. for i, s := range seqIds {
  243. unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
  244. }
  245. if logits {
  246. unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
  247. }
  248. b.c.n_tokens += 1
  249. }
// Clear resets the entry count to zero so the batch can be refilled. The
// underlying C arrays are not modified.
func (b *Batch) Clear() {
	b.c.n_tokens = 0
}
// Free zeroes the recorded capacity and releases the C batch with
// llama_batch_free.
func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}
// Model wraps a C llama_model pointer.
type Model struct {
	c *C.struct_llama_model
}
  260. func (m *Model) TokenToPiece(token int) string {
  261. tokenLen := 12
  262. buf := make([]byte, tokenLen)
  263. tokenLen = int(C.llama_token_to_piece(
  264. m.c,
  265. C.int32_t(token),
  266. (*C.char)(unsafe.Pointer(&buf[0])),
  267. C.int32_t(tokenLen),
  268. C.int32_t(0),
  269. C.bool(true),
  270. ))
  271. if tokenLen < 0 {
  272. tokenLen = -tokenLen
  273. buf = make([]byte, tokenLen)
  274. C.llama_token_to_piece(
  275. m.c,
  276. C.int32_t(token),
  277. (*C.char)(unsafe.Pointer(&buf[0])),
  278. C.int32_t(tokenLen),
  279. C.int32_t(0),
  280. C.bool(true),
  281. )
  282. }
  283. return strings.TrimRight(string(buf), "\x00")
  284. }
  285. func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
  286. maxTokens := len(text) + 2
  287. cTokens := make([]C.llama_token, maxTokens)
  288. cText := C.CString(text)
  289. defer C.free(unsafe.Pointer(cText))
  290. result := C.llama_tokenize(
  291. m.c,
  292. cText,
  293. C.int32_t(len(text)),
  294. &cTokens[0],
  295. C.int32_t(maxTokens),
  296. C.bool(addSpecial),
  297. C.bool(parseSpecial),
  298. )
  299. // if the result is negative, reallocate and retry with the correct buffer size
  300. if result < 0 {
  301. maxTokens = int(-result)
  302. cTokens = make([]C.llama_token, maxTokens)
  303. result = C.llama_tokenize(
  304. m.c,
  305. cText,
  306. C.int32_t(len(text)),
  307. &cTokens[0],
  308. C.int32_t(maxTokens),
  309. C.bool(addSpecial),
  310. C.bool(parseSpecial),
  311. )
  312. if result < 0 {
  313. return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
  314. }
  315. }
  316. tokens := make([]int, result)
  317. for i := range result {
  318. tokens[i] = int(cTokens[i])
  319. }
  320. return tokens, nil
  321. }
  322. func (m *Model) NEmbd() int {
  323. return int(C.llama_n_embd(m.c))
  324. }
  325. func Quantize(infile, outfile string, ftype uint32) error {
  326. cinfile := C.CString(infile)
  327. defer C.free(unsafe.Pointer(cinfile))
  328. coutfile := C.CString(outfile)
  329. defer C.free(unsafe.Pointer(coutfile))
  330. params := C.llama_model_quantize_default_params()
  331. params.nthread = -1
  332. params.ftype = ftype
  333. if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
  334. return fmt.Errorf("llama_model_quantize: %d", rc)
  335. }
  336. return nil
  337. }
// llava

// ClipContext wraps a C clip_ctx pointer obtained from clip_model_load.
type ClipContext struct {
	c *C.struct_clip_ctx
}
  342. func NewClipContext(modelPath string) *ClipContext {
  343. mp := C.CString(modelPath)
  344. defer C.free(unsafe.Pointer(mp))
  345. cc := C.clip_model_load(mp, 1)
  346. return &ClipContext{c: cc}
  347. }
// Free releases the underlying clip context with clip_free.
func (c *ClipContext) Free() {
	C.clip_free(c.c)
}
  351. func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
  352. c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
  353. numTokens := int(c.n_image_pos)
  354. numEmbed := llamaContext.Model().NEmbd()
  355. s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens)
  356. embed := make([][]float32, numTokens)
  357. rows := make([]float32, len(s))
  358. copy(rows, s)
  359. for i := range embed {
  360. embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
  361. }
  362. C.llava_image_embed_free(c)
  363. return embed
  364. }
// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo

// SamplingContext wraps a C llama_sampling_context pointer created by
// llama_sampling_cinit.
type SamplingContext struct {
	c *C.struct_llama_sampling_context
}
// SamplingParams holds the Go-side sampling options that NewSamplingContext
// copies onto the C llama_sampling_cparams struct. Each field maps to the
// C field of the corresponding name (for example RepeatLastN is copied to
// penalty_last_n and Grammar is passed as a C string).
type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TfsZ           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}
  388. func NewSamplingContext(params SamplingParams) *SamplingContext {
  389. var cparams C.struct_llama_sampling_cparams
  390. cparams.top_k = C.int32_t(params.TopK)
  391. cparams.top_p = C.float(params.TopP)
  392. cparams.min_p = C.float(params.MinP)
  393. cparams.tfs_z = C.float(params.TfsZ)
  394. cparams.typical_p = C.float(params.TypicalP)
  395. cparams.temp = C.float(params.Temp)
  396. cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
  397. cparams.penalty_repeat = C.float(params.PenaltyRepeat)
  398. cparams.penalty_freq = C.float(params.PenaltyFreq)
  399. cparams.penalty_present = C.float(params.PenaltyFreq)
  400. cparams.mirostat = C.int32_t(params.Mirostat)
  401. cparams.mirostat_tau = C.float(params.MirostatTau)
  402. cparams.mirostat_eta = C.float(params.MirostatEta)
  403. cparams.penalize_nl = C.bool(params.PenalizeNl)
  404. cparams.seed = C.uint32_t(params.Seed)
  405. grammar := C.CString(params.Grammar)
  406. defer C.free(unsafe.Pointer(grammar))
  407. cparams.grammar = grammar
  408. context := &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
  409. runtime.SetFinalizer(context, func(s *SamplingContext) { C.llama_sampling_cfree(s.c) })
  410. return context
  411. }
// Reset calls llama_sampling_creset on the underlying sampling context.
func (s *SamplingContext) Reset() {
	C.llama_sampling_creset(s.c)
}
  415. func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
  416. // TODO (jmorganca): handle nil for all args
  417. if ctxConfig == nil {
  418. return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
  419. }
  420. return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
  421. }
// Accept forwards the token id and the applyGrammar flag to
// llama_sampling_caccept for this sampling context and ctxMain.
func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
	C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
}