// llama.go — cgo bindings to llama.cpp.
  1. package llama
  2. //go:generate make -j 8
  3. /*
  4. #cgo CFLAGS: -O2 -std=c11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
  5. #cgo CXXFLAGS: -O2 -std=c++11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
  6. #cgo amd64,avx CFLAGS: -mavx
  7. #cgo amd64,avx CXXFLAGS: -mavx
  8. #cgo amd64,avx2 CFLAGS: -mavx2 -mfma
  9. #cgo amd64,avx2 CXXFLAGS: -mavx2 -mfma
  10. #cgo amd64,f16c CFLAGS: -mf16c
  11. #cgo amd64,f16c CXXFLAGS: -mf16c
  12. #cgo amd64,fma CFLAGS: -mfma
  13. #cgo amd64,fma CXXFLAGS: -mfma
  14. #cgo avx CFLAGS: -mavx
  15. #cgo avx CXXFLAGS: -mavx
  16. #cgo avx2 CFLAGS: -mavx2 -mfma -mf16c
  17. #cgo avx2 CXXFLAGS: -mavx2 -mfma -mf16c
  18. #cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  19. #cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  20. #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  21. #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  22. #cgo cuda_jetpack5 LDFLAGS: -lggml_cuda_jetpack5 -L/usr/local/cuda-11/lib64
  23. #cgo cuda_jetpack6 LDFLAGS: -lggml_cuda_jetpack6 -L/usr/local/cuda-12/lib64
  24. #cgo cuda_v11 LDFLAGS: -lggml_cuda_v11 -L/usr/local/cuda-11/lib64
  25. #cgo cuda_v12 LDFLAGS: -lggml_cuda_v12 -L/usr/local/cuda-12/lib64
  26. #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  27. #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
  28. #cgo darwin,amd64 LDFLAGS: -framework Foundation
  29. #cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  30. #cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
  31. #cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
  32. #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
  33. #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
  34. #cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
  35. #cgo linux CFLAGS: -D_GNU_SOURCE
  36. #cgo linux CXXFLAGS: -D_GNU_SOURCE
  37. #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
  38. #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
  39. #cgo linux,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
  40. #cgo linux,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
  41. #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/Linux/arm64
  42. #cgo linux,arm64,sve CFLAGS: -march=armv8.6-a+sve
  43. #cgo linux,arm64,sve CXXFLAGS: -march=armv8.6-a+sve
  44. #cgo linux,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt -lresolv
  45. #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib -lpthread -ldl -lrt -lresolv
  46. #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  47. #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
  48. #cgo rocm LDFLAGS: -L${SRCDIR} -lggml_rocm -lhipblas -lamdhip64 -lrocblas
  49. #cgo windows CFLAGS: -Wno-discarded-qualifiers -D_WIN32_WINNT=0x602
  50. #cgo windows CXXFLAGS: -D_WIN32_WINNT=0x602
  51. #cgo windows LDFLAGS: -lmsvcrt
  52. #cgo windows LDFLAGS: -lmsvcrt -static-libstdc++ -static-libgcc -static
  53. #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/Windows/amd64
  54. #cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/Windows/amd64
  55. #cgo windows,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
  56. #cgo windows,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
  57. #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/Windows/arm64
  58. #cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/Windows/arm64
  59. #cgo windows,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt
  60. #cgo windows,rocm LDFLAGS: -lggml_rocm -lhipblas -lamdhip64 -lrocblas
  61. #include <stdlib.h>
  62. #include "llama.h"
  63. #include "clip.h"
  64. #include "ggml.h"
  65. #include "llava.h"
  66. #include "mllama.h"
  67. bool llamaProgressCallback(float progress, void *user_data);
  68. typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
  69. COMPILER inline get_compiler() {
  70. #if defined(__clang__)
  71. return COMP_CLANG;
  72. #elif defined(__GNUC__)
  73. return COMP_GCC;
  74. #else
  75. return UNKNOWN_COMPILER;
  76. #endif
  77. }
  78. */
  79. import "C"
  80. import (
  81. _ "embed"
  82. "errors"
  83. "fmt"
  84. "math"
  85. "runtime"
  86. "runtime/cgo"
  87. "slices"
  88. "strings"
  89. "unsafe"
  90. )
// CpuFeatures lists detected CPU features; currently always empty here
// (presumably populated elsewhere or at build time — TODO confirm).
var CpuFeatures = ""
// BackendInit initializes the llama.cpp backend. Call once before using
// any other function in this package.
func BackendInit() {
	C.llama_backend_init()
}
  95. func PrintSystemInfo() string {
  96. var compiler string
  97. switch C.get_compiler() {
  98. case C.COMP_UNKNOWN:
  99. compiler = "cgo(unknown_compiler)"
  100. case C.COMP_GCC:
  101. compiler = "cgo(gcc)"
  102. case C.COMP_CLANG:
  103. compiler = "cgo(clang)"
  104. }
  105. return C.GoString(C.llama_print_system_info()) + compiler
  106. }
  107. func GetModelArch(modelPath string) (string, error) {
  108. mp := C.CString(modelPath)
  109. defer C.free(unsafe.Pointer(mp))
  110. gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
  111. if gguf_ctx == nil {
  112. return "", errors.New("unable to load model file")
  113. }
  114. defer C.gguf_free(gguf_ctx)
  115. key := C.CString("general.architecture")
  116. defer C.free(unsafe.Pointer(key))
  117. arch_index := C.gguf_find_key(gguf_ctx, key)
  118. if int(arch_index) < 0 {
  119. return "", errors.New("unknown model architecture")
  120. }
  121. arch := C.gguf_get_val_str(gguf_ctx, arch_index)
  122. return C.GoString(arch), nil
  123. }
// ContextParams wraps llama.cpp's llama_context_params.
type ContextParams struct {
	c C.struct_llama_context_params
}

// NewContextParams builds context parameters from the given settings.
// Embeddings output is always enabled, and the batch thread count
// mirrors threads.
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	return ContextParams{c: params}
}
// Context wraps a llama.cpp inference context.
type Context struct {
	c          *C.struct_llama_context
	numThreads int // thread count, reused by the vision encoders
}

// ErrKvCacheFull is returned by Decode when no KV cache slot is
// available; retry with a smaller batch or a larger context.
var ErrKvCacheFull = errors.New("could not find a kv cache slot")
  143. func (c *Context) Decode(batch *Batch) error {
  144. // Positive return values does not mean a fatal error, but rather a warning.
  145. // 0 - success
  146. // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
  147. // < 0 - error
  148. code := int(C.llama_decode(c.c, batch.c))
  149. if code < 0 {
  150. return fmt.Errorf("llama_decode failed with code %d", code)
  151. }
  152. if code > 0 {
  153. return ErrKvCacheFull
  154. }
  155. return nil
  156. }
// Model returns the model associated with this context.
func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

// GetLogitsIth returns the logits for the i-th entry of the last decoded
// batch as a vocab-sized slice. The slice aliases memory owned by the C
// context; copy it if it must outlive the next decode.
func (c *Context) GetLogitsIth(i int) ([]float32, error) {
	logits := (*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i))))
	if logits == nil {
		return nil, errors.New("unable to get logits")
	}
	return unsafe.Slice(logits, c.Model().NumVocab()), nil
}

// KvCacheSeqAdd shifts cache positions [p0, p1) of sequence seqId by delta.
func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

// KvCacheSeqRm removes cache entries in [p0, p1) for sequence seqId,
// reporting whether the removal succeeded.
func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

// KvCacheSeqCp copies cache entries in [p0, p1) from one sequence to another.
func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

// KvCacheClear empties the entire KV cache.
func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

// KvCacheDefrag requests defragmentation of the KV cache.
func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if embeddings == nil {
		return nil
	}
	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}

// GetEmbeddingsIth returns the embeddings for the i-th entry of the last
// decoded batch, or nil if unavailable. The slice aliases C memory.
func (c *Context) GetEmbeddingsIth(i int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if embeddings == nil {
		return nil
	}
	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
// ModelParams configures model loading.
type ModelParams struct {
	NumGpuLayers int  // number of layers to offload to the GPU
	MainGpu      int  // index of the primary GPU
	UseMmap      bool // memory-map the model file
	UseMlock     bool // lock model memory to prevent swapping
	TensorSplit  []float32 // per-GPU split proportions; empty for default
	Progress     func(float32) // optional load-progress callback, values presumably in [0,1]
	VocabOnly    bool // load only the vocabulary, not the weights
}
// llamaProgressCallback is the cgo trampoline registered as llama.cpp's
// model-load progress callback. userData is a pinned *cgo.Handle wrapping
// the Go func(float32) supplied in ModelParams.Progress. Always returns
// true (continue loading).
//
//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}
  213. func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
  214. cparams := C.llama_model_default_params()
  215. cparams.n_gpu_layers = C.int(params.NumGpuLayers)
  216. cparams.main_gpu = C.int32_t(params.MainGpu)
  217. cparams.use_mmap = C.bool(params.UseMmap)
  218. cparams.use_mlock = C.bool(params.UseMlock)
  219. cparams.vocab_only = C.bool(params.VocabOnly)
  220. if len(params.TensorSplit) > 0 {
  221. tensorSplitData := &params.TensorSplit[0]
  222. var tensorSplitPin runtime.Pinner
  223. tensorSplitPin.Pin(tensorSplitData)
  224. defer tensorSplitPin.Unpin()
  225. cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
  226. }
  227. if params.Progress != nil {
  228. handle := cgo.NewHandle(params.Progress)
  229. defer handle.Delete()
  230. var handlePin runtime.Pinner
  231. handlePin.Pin(&handle)
  232. defer handlePin.Unpin()
  233. cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
  234. cparams.progress_callback_user_data = unsafe.Pointer(&handle)
  235. }
  236. m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
  237. if m.c == nil {
  238. return nil, fmt.Errorf("unable to load model: %s", modelPath)
  239. }
  240. return &m, nil
  241. }
// FreeModel releases a model loaded with LoadModelFromFile.
func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}
  245. func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
  246. c := Context{
  247. c: C.llama_new_context_with_model(model.c, params.c),
  248. numThreads: int(params.c.n_threads),
  249. }
  250. if c.c == nil {
  251. return nil, errors.New("unable to create llama context")
  252. }
  253. return &c, nil
  254. }
// NumVocab returns the model's vocabulary size.
func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

// TokenIsEog reports whether token is an end-of-generation token.
func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

// AddBOSToken reports whether the model expects a BOS token to be
// prepended to input.
func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.c))
}
  264. func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
  265. cLoraPath := C.CString(loraPath)
  266. defer C.free(unsafe.Pointer(cLoraPath))
  267. loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
  268. if loraAdapter == nil {
  269. return errors.New("unable to load lora")
  270. }
  271. err := -1
  272. if loraAdapter != nil {
  273. err = int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale)))
  274. }
  275. if err != 0 {
  276. return errors.New("error applying lora from file")
  277. }
  278. return nil
  279. }
// Batch wraps llama_batch together with the Go-side sizing needed to
// index its C arrays safely.
type Batch struct {
	c         C.struct_llama_batch
	batchSize int // maximum entries per sequence
	maxSeq    int // maximum number of sequences
	embedSize int // 0 for token batches; embedding width otherwise
}
  286. // Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
  287. // Batches cannot contain both types at the same time. batchSize is the maximum number of entries
  288. // that can be added per sequence
  289. func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
  290. b := Batch{
  291. c: C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
  292. batchSize: batchSize,
  293. maxSeq: maxSeq,
  294. embedSize: embedSize,
  295. }
  296. // Check to see if any of the allocations in llama_batch_init() failed
  297. nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
  298. b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
  299. slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)
  300. if nilPointer {
  301. C.llama_batch_free(b.c)
  302. return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
  303. }
  304. return &b, nil
  305. }
// Size returns the maximum number of entries per sequence.
func (b *Batch) Size() int {
	return b.batchSize
}

// allocSize is the total slot capacity requested from llama_batch_init
// (batchSize entries for each of maxSeq sequences).
func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

// NumTokens returns the number of entries currently in the batch.
func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

// IsEmbedding reports whether this batch carries image embeddings rather
// than word tokens.
func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}
// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		// Token batch: write the token id into slot n_tokens.
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		// Embedding batch: copy the row into the flat embd array at this
		// slot's offset (embedSize floats per slot).
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))
	// seq_id is an array of per-slot arrays; fill this slot's id list.
	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}
	// logits marks which slots should produce logits: 1 = yes, 0 = no.
	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}
	b.c.n_tokens += 1
}
// Clear empties the batch without releasing its storage.
func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

// Free releases the C-side batch arrays; the batch is unusable afterwards.
func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}
// Model wraps a loaded llama.cpp model.
type Model struct {
	c *C.struct_llama_model
}
  350. func (m *Model) TokenToPiece(token int) string {
  351. tokenLen := 12
  352. buf := make([]byte, tokenLen)
  353. tokenLen = int(C.llama_token_to_piece(
  354. m.c,
  355. C.int32_t(token),
  356. (*C.char)(unsafe.Pointer(&buf[0])),
  357. C.int32_t(tokenLen),
  358. C.int32_t(0),
  359. C.bool(true),
  360. ))
  361. if tokenLen < 0 {
  362. tokenLen = -tokenLen
  363. buf = make([]byte, tokenLen)
  364. C.llama_token_to_piece(
  365. m.c,
  366. C.int32_t(token),
  367. (*C.char)(unsafe.Pointer(&buf[0])),
  368. C.int32_t(tokenLen),
  369. C.int32_t(0),
  370. C.bool(true),
  371. )
  372. }
  373. return strings.TrimRight(string(buf), "\x00")
  374. }
  375. func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
  376. maxTokens := len(text) + 2
  377. cTokens := make([]C.llama_token, maxTokens)
  378. cText := C.CString(text)
  379. defer C.free(unsafe.Pointer(cText))
  380. result := C.llama_tokenize(
  381. m.c,
  382. cText,
  383. C.int32_t(len(text)),
  384. &cTokens[0],
  385. C.int32_t(maxTokens),
  386. C.bool(addSpecial),
  387. C.bool(parseSpecial),
  388. )
  389. // if the result is negative, reallocate and retry with the correct buffer size
  390. if result < 0 {
  391. maxTokens = int(-result)
  392. cTokens = make([]C.llama_token, maxTokens)
  393. result = C.llama_tokenize(
  394. m.c,
  395. cText,
  396. C.int32_t(len(text)),
  397. &cTokens[0],
  398. C.int32_t(maxTokens),
  399. C.bool(addSpecial),
  400. C.bool(parseSpecial),
  401. )
  402. if result < 0 {
  403. return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
  404. }
  405. }
  406. tokens := make([]int, result)
  407. for i := range result {
  408. tokens[i] = int(cTokens[i])
  409. }
  410. return tokens, nil
  411. }
// NEmbd returns the model's embedding dimension.
func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}
  415. func Quantize(infile, outfile string, ftype uint32) error {
  416. cinfile := C.CString(infile)
  417. defer C.free(unsafe.Pointer(cinfile))
  418. coutfile := C.CString(outfile)
  419. defer C.free(unsafe.Pointer(coutfile))
  420. params := C.llama_model_quantize_default_params()
  421. params.nthread = -1
  422. params.ftype = ftype
  423. if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
  424. return fmt.Errorf("llama_model_quantize: %d", rc)
  425. }
  426. return nil
  427. }
  428. // vision processing
  429. type ClipContext struct {
  430. c *C.struct_clip_ctx
  431. }
  432. func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
  433. mp := C.CString(modelPath)
  434. defer C.free(unsafe.Pointer(mp))
  435. c := C.clip_model_load(mp, 1)
  436. if c == nil {
  437. return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
  438. }
  439. projEmbedSize := int(C.clip_n_mmproj_embd(c))
  440. modelEmbedSize := llamaContext.Model().NEmbd()
  441. if projEmbedSize != modelEmbedSize {
  442. return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
  443. }
  444. return &ClipContext{c: c}, nil
  445. }
// Free releases the CLIP model.
func (c *ClipContext) Free() {
	C.clip_free(c.c)
}
  449. func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
  450. l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
  451. if l == nil {
  452. return nil, errors.New("unable to make llava embedding from image")
  453. }
  454. numTokens := int(l.n_image_pos)
  455. numEmbed := llamaContext.Model().NEmbd()
  456. s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)
  457. embed := make([][]float32, numTokens)
  458. rows := make([]float32, len(s))
  459. copy(rows, s)
  460. for i := range embed {
  461. embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
  462. }
  463. C.llava_image_embed_free(l)
  464. return embed, nil
  465. }
  466. type MllamaContext struct {
  467. c *C.struct_mllama_ctx
  468. }
  469. func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
  470. mp := C.CString(modelPath)
  471. defer C.free(unsafe.Pointer(mp))
  472. c := C.mllama_model_load(mp, 1)
  473. if c == nil {
  474. return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
  475. }
  476. projEmbedSize := int(C.mllama_n_embd(c))
  477. modelEmbedSize := llamaContext.Model().NEmbd()
  478. if projEmbedSize != modelEmbedSize {
  479. return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
  480. }
  481. return &MllamaContext{c: c}, nil
  482. }
// Free releases the mllama model.
func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}
  486. func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
  487. img := C.mllama_image_init()
  488. defer C.mllama_image_free(img)
  489. ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
  490. if !ok {
  491. return nil, errors.New("unable to load mllama image data")
  492. }
  493. rows := make([]float32, m.EmbedSize(llamaContext))
  494. ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
  495. if !ok {
  496. return nil, errors.New("unable to make mllama embedding from image")
  497. }
  498. embed := make([][]float32, 1)
  499. embed[0] = rows
  500. return embed, nil
  501. }
// EmbedSize returns the total number of floats produced by mllama image
// encoding: positions * tiles * the model's embedding width.
func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()
	return numTokens * numEmbed
}

// SetCrossAttention toggles the context's cross-attention state
// (used around mllama vision input — confirm semantics in llama.h).
func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

// Synchronize blocks until all pending backend work for this context
// completes.
func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}
// sampling

// SamplingParams configures the token sampling chain built by
// NewSamplingContext.
type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TfsZ           float32
	TypicalP       float32
	Temp           float32 // <= 0 selects greedy sampling
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int // 0 = off, 1 = mirostat, 2 = mirostat v2
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string // grammar text passed to llama_sampler_init_grammar
}

// SamplingContext holds the sampler chain plus a separate grammar sampler
// that Sample uses to validate (and if necessary re-sample) the chain's
// output.
type SamplingContext struct {
	chain   *C.struct_llama_sampler
	grammar *C.struct_llama_sampler
}
// NewSamplingContext builds a sampler chain from params. The grammar
// sampler is kept outside the chain so Sample can first test the chain's
// pick against the grammar and only fall back to full grammar evaluation
// when needed. A finalizer frees both samplers.
func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var s SamplingContext
	runtime.SetFinalizer(&s, func(s *SamplingContext) { s.free() })

	sparams := C.llama_sampler_chain_default_params()
	s.chain = C.llama_sampler_chain_init(sparams)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))
	root := C.CString("root")
	defer C.free(unsafe.Pointer(root))
	s.grammar = C.llama_sampler_init_grammar(model.c, grammar, root)

	// Repetition/frequency/presence penalties always lead the chain.
	C.llama_sampler_chain_add(s.chain,
		C.llama_sampler_init_penalties(
			C.llama_n_vocab(model.c),
			C.llama_token_eos(model.c),
			C.llama_token_nl(model.c),
			C.int32_t(params.RepeatLastN),
			C.float(params.PenaltyRepeat),
			C.float(params.PenaltyFreq),
			C.float(params.PenaltyPresent),
			C.bool(params.PenalizeNl),
			false))

	if params.Temp > 0 {
		switch params.Mirostat {
		case 0:
			// Standard chain: top-k -> tail-free -> typical -> top-p ->
			// min-p -> temperature -> softmax -> seeded distribution pick.
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_top_k(C.int32_t(params.TopK)))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_tail_free(C.float(params.TfsZ), 0))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_typical(C.float(params.TypicalP), 0))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_top_p(C.float(params.TopP), 0))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_min_p(C.float(params.MinP), 0))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_temp(C.float(params.Temp)))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_softmax())
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_dist(C.uint32_t(params.Seed)))
		case 1:
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_temp(C.float(params.Temp)))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_mirostat(C.llama_n_vocab(model.c),
				C.uint32_t(params.Seed), C.float(params.MirostatTau), C.float(params.MirostatEta), 100))
		case 2:
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_temp(C.float(params.Temp)))
			C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_mirostat_v2(C.uint32_t(params.Seed),
				C.float(params.MirostatTau), C.float(params.MirostatEta)))
		default:
			return nil, fmt.Errorf("sampling: unknown mirostat version: %v", params.Mirostat)
		}
	} else {
		// Non-positive temperature: deterministic greedy pick.
		C.llama_sampler_chain_add(s.chain, C.llama_sampler_init_greedy())
	}

	return &s, nil
}
// Sample applies the sampler chain to the logits at batch index idx and
// returns the selected token id. The chain's pick is then checked against
// the grammar; if rejected, the grammar is applied to the full candidate
// list and the chain re-samples from what remains (slow path).
func (s *SamplingContext) Sample(llamaContext *Context, idx int) (int, error) {
	logits, err := llamaContext.GetLogitsIth(idx)
	if err != nil {
		return 0, err
	}

	numVocab := llamaContext.Model().NumVocab()
	tokenData := make([]C.llama_token_data, numVocab)
	// Pin the array so the C samplers can safely hold its pointer while
	// they run.
	var tokenDataPin runtime.Pinner
	tokenDataPin.Pin(&tokenData[0])
	defer tokenDataPin.Unpin()

	for i := range tokenData {
		tokenData[i] = C.llama_token_data{id: C.llama_token(i), logit: C.float(logits[i])}
	}
	tokenDataArray := C.llama_token_data_array{data: &tokenData[0], size: C.size_t(len(tokenData)), selected: -1}

	C.llama_sampler_apply(s.chain, &tokenDataArray)
	id := tokenData[tokenDataArray.selected].id

	// Check if the selected token is allowed by the grammar
	// If it is allowed then return it, otherwise evaluate the grammar on all
	// tokens and resample (slow)
	tokenData[0] = C.llama_token_data{id: id, logit: 1}
	tokenDataArray = C.llama_token_data_array{data: &tokenData[0], size: 1, selected: -1}
	C.llama_sampler_apply(s.grammar, &tokenDataArray)
	// The grammar sampler flags a disallowed token by setting its logit
	// to -Inf.
	if !math.IsInf(float64(tokenData[0].logit), -1) {
		return int(id), nil
	}

	// Slow path: rebuild the full candidate list, mask it with the grammar,
	// then let the chain pick again.
	for i := range tokenData {
		tokenData[i] = C.llama_token_data{id: C.llama_token(i), logit: C.float(logits[i])}
	}
	tokenDataArray = C.llama_token_data_array{data: &tokenData[0], size: C.size_t(len(tokenData)), selected: -1}
	C.llama_sampler_apply(s.grammar, &tokenDataArray)
	C.llama_sampler_apply(s.chain, &tokenDataArray)

	return int(tokenData[tokenDataArray.selected].id), nil
}
// Accept informs the samplers that token id was chosen; applyGrammar also
// advances the grammar state with the token.
func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	if applyGrammar {
		C.llama_sampler_accept(s.grammar, C.llama_token(id))
	}
	C.llama_sampler_accept(s.chain, C.llama_token(id))
}

// free releases both samplers; invoked via the finalizer registered in
// NewSamplingContext.
func (s *SamplingContext) free() {
	if s != nil {
		C.llama_sampler_free(s.grammar)
		C.llama_sampler_free(s.chain)
	}
}