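// Package llama provides Go bindings to llama.cpp: model loading,
// tokenization, batched decoding, embeddings, sampling, and llava image
// embedding support.
//
// A typical generation loop looks roughly like this (a sketch; error
// handling and cleanup are omitted):
//
//	llama.BackendInit()
//	model := llama.LoadModelFromFile(path, llama.NewModelParams(99, 0, func(float32) {}))
//	ctx := llama.NewContextWithModel(model, llama.NewContextParams(2048, 4, false))
//	tokens, _ := model.Tokenize(prompt, true, true)
//	batch := llama.NewBatch(512, 0, 1)
//	for i, t := range tokens {
//		batch.Add(t, i, []int{0}, i == len(tokens)-1)
//	}
//	_ = ctx.Decode(batch)
//	next := ctx.SampleTokenGreedy(ctx.GetLogitsIth(batch.NumTokens() - 1))
//	fmt.Print(model.TokenToPiece(next))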
package llama

// #cgo CFLAGS: -std=c11 -DNDEBUG -DLOG_DISABLE_LOGS
// #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
// #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
// #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 LDFLAGS: -framework Foundation
// #cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
// #cgo linux CFLAGS: -D_GNU_SOURCE
// #cgo linux CXXFLAGS: -D_GNU_SOURCE
// #cgo windows CFLAGS: -Wno-discarded-qualifiers
// #cgo windows LDFLAGS: -lmsvcrt
// #cgo avx CFLAGS: -mavx
// #cgo avx CXXFLAGS: -mavx
// #cgo avx2 CFLAGS: -mavx2 -mfma
// #cgo avx2 CXXFLAGS: -mavx2 -mfma
// #cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm LDFLAGS: -L${SRCDIR} -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
// #cgo windows,cuda LDFLAGS: -L${SRCDIR} -L"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.3/lib/x64" -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt
// #cgo windows,rocm LDFLAGS: -L${SRCDIR} -L"C:/Program Files/AMD/ROCm/5.7/lib" -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
// #cgo linux,cuda LDFLAGS: -L${SRCDIR} -L/usr/local/cuda/lib64 -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt
// #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib
// #include <stdlib.h>
// #include "llama.h"
// #include "clip.h"
// #include "llava.h"
// #include "sampling_ext.h"
//
// bool llamaProgressCallback(float progress, void *user_data);
import "C"

import (
	_ "embed"
	"errors"
	"fmt"
	"runtime"
	"runtime/cgo"
	"strings"
	"unsafe"
)

//go:embed ggml-common.h
var ggmlCommon string

//go:embed ggml-metal.metal
var ggmlMetal string

// TODO: write me somewhere else
func init() {
	// Build the Metal shader library in memory by inlining the common
	// header, then point ggml at the start and end of the source.
	metal := strings.ReplaceAll(ggmlMetal, `#include "ggml-common.h"`, ggmlCommon)
	cMetal := C.CString(metal)
	C.ggml_metallib_start = cMetal
	C.ggml_metallib_end = (*C.char)(unsafe.Pointer(uintptr(unsafe.Pointer(cMetal)) + uintptr(len(metal))))
}
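
// BackendInit initializes the llama.cpp backend and must be called once
// before loading models or creating contexts.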
func BackendInit() {
	C.llama_backend_init()
}
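
// PrintSystemInfo reports the capabilities llama.cpp was built with and
// detected at runtime (AVX, Metal, and so on).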
func PrintSystemInfo() string {
	return C.GoString(C.llama_print_system_info())
}

type ContextParams struct {
	c C.struct_llama_context_params
}
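
// NewContextParams builds context parameters for the given context size,
// thread count, and flash-attention setting; embeddings output is always
// enabled here.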
func NewContextParams(numCtx int, threads int, flashAttention bool) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_threads = C.uint(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	return ContextParams{c: params}
}

type ModelParams struct {
	c C.struct_llama_model_params
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := cgo.Handle(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}
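
// NewModelParams builds model parameters for the given number of GPU
// layers and main GPU, reporting load progress through callback (which
// must be non-nil, since the progress callback is always installed).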
func NewModelParams(numGpuLayers int, mainGpu int, callback func(float32)) ModelParams {
	params := C.llama_model_default_params()
	params.n_gpu_layers = C.int(numGpuLayers)
	params.main_gpu = C.int32_t(mainGpu)

	// Route llama.cpp's progress callback through a cgo.Handle so the Go
	// closure can be recovered on the C side; the handle is released once
	// this copy of the params is garbage collected.
	handle := cgo.NewHandle(callback)
	params.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
	params.progress_callback_user_data = unsafe.Pointer(handle)
	runtime.SetFinalizer(&params, func(p *C.struct_llama_model_params) {
		handle.Delete()
	})

	return ModelParams{c: params}
}

type Context struct {
	c *C.struct_llama_context
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) Decode(batch Batch) error {
	// A positive return value is a warning rather than a fatal error:
	//   0   - success
	//   1   - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	//   < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increasing the context. code: %d", code)
	}

	return nil
}

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}
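
// GetLogitsIth returns the logits for the i-th token of the last Decode
// call; the returned slice aliases memory owned by the C context.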
func (c *Context) GetLogitsIth(i int) []float32 {
	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
}
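
// SampleTokenGreedy returns the token id with the highest logit,
// marshalling the logits into a llama_token_data_array first.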
func (c *Context) SampleTokenGreedy(logits []float32) int {
	// Build a C-allocated array of llama_token_data, one entry per vocab token.
	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
	defer C.free(unsafe.Pointer(candidates))

	for i, logit := range logits {
		ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
		ptr.id = C.int(i)
		ptr.logit = C.float(logit)
		ptr.p = 0.0
	}

	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
		data:   candidates,
		size:   C.size_t(len(logits)),
		sorted: C.bool(false),
	}))
}
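
// KvCacheSeqRm removes cached tokens in positions [p0, p1) for the given
// sequence; per llama.cpp convention, a negative p0 or p1 means "from the
// start" or "to the end" respectively.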
func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))), c.Model().NEmbd())
}

func LoadModelFromFile(modelPath string, params ModelParams) *Model {
	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	return &Model{c: C.llama_load_model_from_file(cPath, params.c)}
}

func NewContextWithModel(model *Model, params ContextParams) *Context {
	return &Context{c: C.llama_new_context_with_model(model.c, params.c)}
}

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

func (m *Model) ApplyLoraFromFile(loraPath string, scale float32, baseModelPath string, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	var cBaseModelPath *C.char
	if baseModelPath != "" {
		cBaseModelPath = C.CString(baseModelPath)
		defer C.free(unsafe.Pointer(cBaseModelPath))
	}

	code := int(C.llama_model_apply_lora_from_file(m.c, cLoraPath, C.float(scale), cBaseModelPath, C.int32_t(threads)))
	if code != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Batch struct {
	c C.struct_llama_batch
}
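
// NewBatch allocates a llama_batch with capacity for nTokens tokens,
// embd embedding dimensions (0 for token batches), and maxSeq sequence
// ids per token; free it with Free.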
func NewBatch(nTokens int, embd int, maxSeq int) Batch {
	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

// Add adds a token to the batch with the given position for the given
// sequence ids, and optionally instructs to include logits.
// Note: the hardcoded 512 assumes the batch was created with a capacity
// of at least 512 tokens; writes past the real capacity are out of bounds.
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	C.llama_batch_free(b.c)
}

func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
	// llama_token is 32 bits wide, so Go's 64-bit ints must be converted
	// rather than reinterpreted in place.
	cTokens := make([]C.llama_token, len(tokens))
	for i, t := range tokens {
		cTokens[i] = C.llama_token(t)
	}

	return Batch{c: C.llama_batch_get_one(&cTokens[0], C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
	buf := make([]byte, 12)
	C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(len(buf)),
		C.bool(true),
	)

	return strings.TrimRight(string(buf), "\x00")
}
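
// Tokenize converts text into token ids, optionally adding and parsing
// special tokens. len(text)+2 tokens is a safe upper bound: tokenization
// never yields more tokens than bytes, plus room for BOS and EOS.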
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)
	if result < 0 {
		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
	}

	tokens := make([]int, result)
	for i := 0; i < int(result); i++ {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}

func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}
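
// Quantize rewrites the model at infile to outfile with the requested
// llama_ftype; an nthread of -1 leaves the thread count up to llama.cpp.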
func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = C.enum_llama_ftype(ftype)

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}

// llava
type ClipContext struct {
	c *C.struct_clip_ctx
}
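
// NewClipContext loads the CLIP model used for llava image encoding; the
// second argument to clip_model_load is its verbosity level.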
func NewClipContext(modelPath string) *ClipContext {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	cc := C.clip_model_load(mp, 1)
	return &ClipContext{c: cc}
}

type LlavaContext struct {
	c *C.struct_llava_context
}

type LlavaImageEmbed struct {
	c *C.struct_llava_image_embed
}

func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed {
	return &LlavaImageEmbed{c: C.llava_image_embed_make_with_bytes(clipContext.c, C.int(runtime.NumCPU()), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))}
}

func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
	// n_past is an int32 on the C side; round-trip through a C int rather
	// than aliasing the 64-bit Go int.
	cNPast := C.int(*nPast)
	C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), &cNPast)
	*nPast = int(cNPast)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_llama_sampling_context
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	TfsZ           float32
	TypicalP       float32
	Temp           float32
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}
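
// NewSamplingContext copies SamplingParams into their C counterparts and
// initializes a sampling context through the sampling_ext shim.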
func NewSamplingContext(params SamplingParams) *SamplingContext {
	var cparams C.struct_llama_sampling_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.tfs_z = C.float(params.TfsZ)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.penalize_nl = C.bool(params.PenalizeNl)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))
	cparams.grammar = grammar

	return &SamplingContext{c: C.llama_sampling_cinit(&cparams)}
}

func (s *SamplingContext) Free() {
	C.llama_sampling_cfree(s.c)
}

func (s *SamplingContext) Reset() {
	C.llama_sampling_creset(s.c)
}

func (s *SamplingContext) Sample(ctxMain *Context, ctxConfig *Context, idx int) int {
	// TODO (jmorganca): handle nil for all args
	if ctxConfig == nil {
		return int(C.llama_sampling_csample(s.c, ctxMain.c, nil, C.int(idx)))
	}

	return int(C.llama_sampling_csample(s.c, ctxMain.c, ctxConfig.c, C.int(idx)))
}

func (s *SamplingContext) Accept(ctxMain *Context, id int, applyGrammar bool) {
	C.llama_sampling_caccept(s.c, ctxMain.c, C.llama_token(id), C.bool(applyGrammar))
}