llama.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466
  1. package llama
  2. /*
  3. #cgo CFLAGS: -std=c11 -DNDEBUG -DLOG_DISABLE_LOGS
  4. #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
  5. #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  6. #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  7. #cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
  8. #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  9. #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  10. #cgo darwin,amd64 LDFLAGS: -framework Foundation
  11. #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/Darwin/amd64
  12. #cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  13. #cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  14. #cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
  15. #cgo linux CFLAGS: -D_GNU_SOURCE
  16. #cgo linux CXXFLAGS: -D_GNU_SOURCE
  17. #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/Linux/arm64
  18. #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
  19. #cgo windows CFLAGS: -Wno-discarded-qualifiers
  20. #cgo windows LDFLAGS: -lmsvcrt
  21. #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/Windows/arm64
  22. #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/Windows/amd64
  23. #cgo avx CFLAGS: -mavx
  24. #cgo avx CXXFLAGS: -mavx
  25. #cgo avx2 CFLAGS: -mavx2 -mfma
  26. #cgo avx2 CXXFLAGS: -mavx2 -mfma
  27. #cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  28. #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  29. #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  30. #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  31. #cgo rocm LDFLAGS: -L${SRCDIR} -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
  32. #cgo windows,cuda LDFLAGS: -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt
  33. #cgo windows,rocm LDFLAGS: -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
  34. #cgo linux,cuda LDFLAGS: -L/usr/local/cuda/lib64 -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt
  35. #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib
  36. #include <stdlib.h>
  37. #include "llama.h"
  38. #include "clip.h"
  39. #include "llava.h"
  40. #include "sampling_ext.h"
  41. bool llamaProgressCallback(float progress, void *user_data);
  42. */
  43. import "C"
  44. import (
  45. _ "embed"
  46. "errors"
  47. "fmt"
  48. "runtime"
  49. "runtime/cgo"
  50. "strings"
  51. "unsafe"
  52. )
  53. func BackendInit() {
  54. C.llama_backend_init()
  55. }
  56. func PrintSystemInfo() string {
  57. return C.GoString(C.llama_print_system_info())
  58. }
  59. type ContextParams struct {
  60. c C.struct_llama_context_params
  61. }
  62. func NewContextParams(numCtx int, threads int, flashAttention bool) ContextParams {
  63. params := C.llama_context_default_params()
  64. params.n_ctx = C.uint(numCtx)
  65. params.n_threads = C.uint(runtime.NumCPU())
  66. params.n_threads_batch = params.n_threads
  67. params.embeddings = C.bool(true)
  68. params.flash_attn = C.bool(flashAttention)
  69. params.n_threads = C.uint(threads)
  70. return ContextParams{c: params}
  71. }
  72. type Context struct {
  73. c *C.struct_llama_context
  74. }
  75. func (c *Context) KvCacheClear() {
  76. C.llama_kv_cache_clear(c.c)
  77. }
  78. func (c *Context) Decode(batch Batch) error {
  79. // Positive return values does not mean a fatal error, but rather a warning.
  80. // 0 - success
  81. // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
  82. // < 0 - error
  83. code := int(C.llama_decode(c.c, batch.c))
  84. if code < 0 {
  85. return fmt.Errorf("llama_decode failed with code %d", code)
  86. }
  87. if code > 0 {
  88. return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
  89. }
  90. return nil
  91. }
  92. func (c *Context) Model() *Model {
  93. return &Model{c: C.llama_get_model(c.c)}
  94. }
  95. func (c *Context) GetLogitsIth(i int) []float32 {
  96. return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
  97. }
  98. func (c *Context) SampleTokenGreedy(logits []float32) int {
  99. candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
  100. defer C.free(unsafe.Pointer(candidates))
  101. for i, logit := range logits {
  102. ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
  103. ptr.id = C.int(i)
  104. ptr.logit = C.float(logit)
  105. ptr.p = 0.0
  106. }
  107. return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
  108. data: candidates,
  109. size: C.size_t(len(logits)),
  110. sorted: C.bool(false),
  111. }))
  112. }
  113. func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
  114. C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
  115. }
  116. func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
  117. return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
  118. }
  119. // Get the embeddings for a sequence id
  120. func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
  121. embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
  122. if embeddings == nil {
  123. return nil
  124. }
  125. return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
  126. }
  127. func (c *Context) GetEmbeddingsIth(i int) []float32 {
  128. return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))), c.Model().NEmbd())
  129. }
  130. type ModelParams struct {
  131. NumGpuLayers int
  132. MainGpu int
  133. UseMmap bool
  134. UseMlock bool
  135. TensorSplit []float32
  136. Progress func(float32)
  137. }
  138. //export llamaProgressCallback
  139. func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
  140. handle := cgo.Handle(userData)
  141. callback := handle.Value().(func(float32))
  142. callback(float32(progress))
  143. return true
  144. }
  145. func LoadModelFromFile(modelPath string, params ModelParams) *Model {
  146. cparams := C.llama_model_default_params()
  147. cparams.n_gpu_layers = C.int(params.NumGpuLayers)
  148. cparams.main_gpu = C.int32_t(params.MainGpu)
  149. cparams.use_mmap = C.bool(params.UseMmap)
  150. cparams.use_mlock = C.bool(params.UseMlock)
  151. if len(params.TensorSplit) > 0 {
  152. tensorSplitData := &params.TensorSplit[0]
  153. var tensorSplitPin runtime.Pinner
  154. tensorSplitPin.Pin(tensorSplitData)
  155. defer tensorSplitPin.Unpin()
  156. cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
  157. }
  158. if params.Progress != nil {
  159. handle := cgo.NewHandle(params.Progress)
  160. defer handle.Delete()
  161. cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
  162. cparams.progress_callback_user_data = unsafe.Pointer(handle)
  163. }
  164. return &Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
  165. }
  166. func NewContextWithModel(model *Model, params ContextParams) *Context {
  167. return &Context{c: C.llama_new_context_with_model(model.c, params.c)}
  168. }
  169. func (m *Model) NumVocab() int {
  170. return int(C.llama_n_vocab(m.c))
  171. }
  172. func (m *Model) TokenIsEog(token int) bool {
  173. return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
  174. }
  175. func (m *Model) ShouldAddBOSToken() bool {
  176. addBos := int(C.llama_add_bos_token(m.c))
  177. if addBos != -1 {
  178. return addBos != 0
  179. } else {
  180. return C.llama_vocab_type(m.c) == C.LLAMA_VOCAB_TYPE_SPM
  181. }
  182. }
  183. func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
  184. cLoraPath := C.CString(loraPath)
  185. defer C.free(unsafe.Pointer(cLoraPath))
  186. var cBaseModelPath *C.char
  187. if baseModelPath != "" {
  188. cBaseModelPath = C.CString(baseModelPath)
  189. }
  190. code := int(C.llama_model_apply_lora_from_file(m.c, cLoraPath, C.float(scale), cBaseModelPath, C.int32_t(threads)))
  191. if code != 0 {
  192. return errors.New("error applying lora from file")
  193. }
  194. return nil
  195. }
  196. type Batch struct {
  197. c C.struct_llama_batch
  198. batchSize int
  199. }
  200. func NewBatch(nTokens int, embd int, maxSeq int) Batch {
  201. return Batch{
  202. c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq)),
  203. batchSize: nTokens,
  204. }
  205. }
  206. func (b *Batch) NumTokens() int {
  207. return int(b.c.n_tokens)
  208. }
  209. // Add adds a token to the batch with the given position for the given
  210. // sequence ids, and optionally instructs to include logits.
  211. func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
  212. unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
  213. unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
  214. unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))
  215. for i, s := range seqIds {
  216. unsafe.Slice((unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
  217. }
  218. if logits {
  219. unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
  220. }
  221. b.c.n_tokens += 1
  222. }
  223. func (b *Batch) Clear() {
  224. b.c.n_tokens = 0
  225. }
  226. func (b *Batch) Free() {
  227. b.batchSize = 0
  228. C.llama_batch_free(b.c)
  229. }
  230. func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
  231. return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
  232. }
  233. type Model struct {
  234. c *C.struct_llama_model
  235. }
  236. func (m *Model) TokenToPiece(token int) string {
  237. tokenLen := 12
  238. buf := make([]byte, tokenLen)
  239. tokenLen = int(C.llama_token_to_piece(
  240. m.c,
  241. C.int32_t(token),
  242. (*C.char)(unsafe.Pointer(&buf[0])),
  243. C.int32_t(tokenLen),
  244. C.int32_t(0),
  245. C.bool(true),
  246. ))
  247. if tokenLen < 0 {
  248. tokenLen = -tokenLen
  249. buf = make([]byte, tokenLen)
  250. C.llama_token_to_piece(
  251. m.c,
  252. C.int32_t(token),
  253. (*C.char)(unsafe.Pointer(&buf[0])),
  254. C.int32_t(tokenLen),
  255. C.int32_t(0),
  256. C.bool(true),
  257. )
  258. }
  259. return strings.TrimRight(string(buf), "\x00")
  260. }
  261. func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
  262. maxTokens := len(text) + 2
  263. cTokens := make([]C.llama_token, maxTokens)
  264. cText := C.CString(text)
  265. defer C.free(unsafe.Pointer(cText))
  266. result := C.llama_tokenize(
  267. m.c,
  268. cText,
  269. C.int32_t(len(text)),
  270. &cTokens[0],
  271. C.int32_t(maxTokens),
  272. C.bool(addSpecial),
  273. C.bool(parseSpecial),
  274. )
  275. if result < 0 {
  276. return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
  277. }
  278. tokens := make([]int, result)
  279. for i := range result {
  280. tokens[i] = int(cTokens[i])
  281. }
  282. return tokens, nil
  283. }
  284. func (m *Model) NEmbd() int {
  285. return int(C.llama_n_embd(m.c))
  286. }
  287. func Quantize(infile, outfile string, ftype uint32) error {
  288. cinfile := C.CString(infile)
  289. defer C.free(unsafe.Pointer(cinfile))
  290. coutfile := C.CString(outfile)
  291. defer C.free(unsafe.Pointer(coutfile))
  292. params := C.llama_model_quantize_default_params()
  293. params.nthread = -1
  294. params.ftype = ftype
  295. if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
  296. return fmt.Errorf("llama_model_quantize: %d", rc)
  297. }
  298. return nil
  299. }
  300. // llava
  301. type ClipContext struct {
  302. c *C.struct_clip_ctx
  303. }
  304. func NewClipContext(modelPath string) *ClipContext {
  305. mp := C.CString(modelPath)
  306. defer C.free(unsafe.Pointer(mp))
  307. cc := C.clip_model_load(mp, 1)
  308. return &ClipContext{c: cc}
  309. }
  310. type LlavaContext struct {
  311. c *C.struct_llava_context
  312. }
  313. type LlavaImageEmbed struct {
  314. c *C.struct_llava_image_embed
  315. }
  316. func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed {
  317. return &LlavaImageEmbed{c: C.llava_image_embed_make_with_bytes(clipContext.c, C.int(runtime.NumCPU()), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))}
  318. }
  319. func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
  320. C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), (*C.int)(unsafe.Pointer(nPast)))
  321. }
  322. // sampling
  323. // TODO: this is a temporary wrapper to allow calling C++ code from CGo
  324. type SamplingContext struct {
  325. c *C.struct_llama_sampling_context
  326. }
  327. type SamplingParams struct {
  328. TopK int
  329. TopP float32
  330. MinP float32
  331. TfsZ float32
  332. TypicalP float32
  333. Temp float32
  334. RepeatLastN int
  335. PenaltyRepeat float32
  336. PenaltyFreq float32
  337. PenaltyPresent float32
  338. Mirostat int
  339. MirostatTau float32
  340. MirostatEta float32
  341. PenalizeNl bool
  342. Seed uint32
  343. Grammar string
  344. }
  345. func NewSamplingContext(params SamplingParams) *SamplingContext {
  346. var cparams C.struct_llama_sampling_cparams
  347. cparams.top_k = C.int32_t(params.TopK)
  348. cparams.top_p = C.float(params.TopP)
  349. cparams.min_p = C.float(params.MinP)
  350. cparams.tfs_z = C.float(params.TfsZ)
  351. cparams.typical_p = C.float(params.TypicalP)
  352. cparams.temp = C.float(params.Temp)
  353. cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
  354. cparams.penalty_repeat = C.float(params.PenaltyRepeat)
  355. cparams.penalty_freq = C.float(params.PenaltyFreq)
  356. cparams.penalty_present = C.float(params.PenaltyFreq)
  357. cparams.mirostat = C.int32_t(params.Mirostat)
  358. cparams.mirostat_tau = C.float(params.MirostatTau)
  359. cparams.mirostat_eta = C.float(params.MirostatEta)
  360. cparams.penalize_nl = C.bool(params.PenalizeNl)
  361. cparams.seed = C.uint32_t(params.Seed)
  362. grammar := C.CString(params.Grammar)
  363. defer C.free(unsafe.Pointer(grammar))
  364. cparams.grammar = grammar
  365. return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
  366. }
  367. func (s *SamplingContext) Free() {
  368. if s != nil {
  369. C.llama_sampling_cfree(s.c)
  370. }
  371. }
  372. func (s *SamplingContext) Reset() {
  373. C.llama_sampling_creset(s.c)
  374. }
  375. func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
  376. // TODO (jmorganca): handle nil for all args
  377. if ctxConfig == nil {
  378. return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
  379. }
  380. return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
  381. }
  382. func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
  383. C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
  384. }