package llama

//go:generate make -j 8

/*
#cgo CFLAGS: -O3 -std=c17 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64
#cgo CXXFLAGS: -O3 -std=c++17 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE -DGGML_USE_CPU -DGGML_USE_CPU_AARCH64
#cgo amd64,avx CFLAGS: -mavx
#cgo amd64,avx CXXFLAGS: -mavx
#cgo amd64,avx2 CFLAGS: -mavx2 -mfma -mf16c
#cgo amd64,avx2 CXXFLAGS: -mavx2 -mfma -mf16c
#cgo amd64,avx512 CFLAGS: -mavx512f -mavx512dq -mavx512bw
#cgo amd64,avx512 CXXFLAGS: -mavx512f -mavx512dq -mavx512bw
#cgo amd64,avx512bf16 CFLAGS: -mavx512bf16 -D__AVX512BF16__
#cgo amd64,avx512bf16 CXXFLAGS: -mavx512bf16 -D__AVX512BF16__
#cgo amd64,avx512vbmi CFLAGS: -mavx512vbmi -D__AVX512VBMI__
#cgo amd64,avx512vbmi CXXFLAGS: -mavx512vbmi -D__AVX512VBMI__
#cgo amd64,avx512vnni CFLAGS: -mavx512vnni -D__AVX512VNNI__
#cgo amd64,avx512vnni CXXFLAGS: -mavx512vnni -D__AVX512VNNI__
#cgo amd64,f16c CFLAGS: -mf16c
#cgo amd64,f16c CXXFLAGS: -mf16c
#cgo amd64,fma CFLAGS: -mfma
#cgo amd64,fma CXXFLAGS: -mfma
#cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda_jetpack5 LDFLAGS: -lggml_cuda_jetpack5
#cgo cuda_jetpack6 LDFLAGS: -lggml_cuda_jetpack6
#cgo cuda_v11 LDFLAGS: -lggml_cuda_v11
#cgo cuda_v12 LDFLAGS: -lggml_cuda_v12
#cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
#cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
#cgo darwin,amd64 LDFLAGS: -framework Foundation
#cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
#cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE
#cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS -DGGML_BLAS_USE_ACCELERATE
#cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux CXXFLAGS: -D_GNU_SOURCE
#cgo linux LDFLAGS: -ldl
#cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux-amd64
#cgo linux,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo linux,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux-arm64
#cgo linux,arm64,sve CFLAGS: -march=armv8.6-a+sve
#cgo linux,arm64,sve CXXFLAGS: -march=armv8.6-a+sve
#cgo linux,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt -lpthread -lrt -lresolv
#cgo linux,rocm LDFLAGS: -lpthread -lrt -lresolv
#cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIP -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIP -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo rocm LDFLAGS: -L${SRCDIR} -lggml_rocm -lhipblas -lamdhip64 -lrocblas
#cgo windows CFLAGS: -Wno-discarded-qualifiers -D_WIN32_WINNT=0x602
#cgo windows CXXFLAGS: -D_WIN32_WINNT=0x602
#cgo windows LDFLAGS: -lmsvcrt -static-libstdc++ -static-libgcc -static
#cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/windows-amd64
#cgo windows,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo windows,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/windows-arm64
#cgo windows,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt
#cgo windows,rocm LDFLAGS: -lggml_rocm -lhipblas -lamdhip64 -lrocblas
#include <stdlib.h>
#include "llama.h"
#include "clip.h"
#include "ggml.h"
#include "llava.h"
#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

typedef enum {COMP_UNKNOWN, COMP_GCC, COMP_CLANG} COMPILER;

COMPILER inline get_compiler() {
#if defined(__clang__)
	return COMP_CLANG;
#elif defined(__GNUC__)
	return COMP_GCC;
#else
	return COMP_UNKNOWN;
#endif
}
*/
import "C"

import (
	_ "embed"
	"errors"
	"fmt"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"sync/atomic"
	"unsafe"
)

func BackendInit() {
	C.llama_backend_init()
}

func PrintSystemInfo() string {
	var compiler string
	switch C.get_compiler() {
	case C.COMP_UNKNOWN:
		compiler = "cgo(unknown_compiler)"
	case C.COMP_GCC:
		compiler = "cgo(gcc)"
	case C.COMP_CLANG:
		compiler = "cgo(clang)"
	}
	return C.GoString(C.llama_print_system_info()) + compiler
}

var logLevel atomic.Int32

func init() {
	logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
	C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
}

func EnableDebug() {
	logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
}

//export llamaLog
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
	if level < logLevel.Load() {
		return
	}
	fmt.Fprint(os.Stderr, C.GoString(text))
}

func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}

type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// A positive return value is a warning rather than a fatal error:
	//   0   - success
	//   1   - could not find a KV slot for the batch (try reducing the batch size or increasing the context)
	//   < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}
	if code > 0 {
		return ErrKvCacheFull
	}
	return nil
}
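
// A caller-side sketch, not part of this package's API: lc and batch are
// hypothetical values of *Context and *Batch. Because ErrKvCacheFull is a
// warning rather than a fatal error, callers typically shrink the batch or
// free KV cache space and retry instead of aborting.
//
//	if err := lc.Decode(batch); err != nil {
//		if errors.Is(err, ErrKvCacheFull) {
//			// reduce the batch size or clear KV cache entries, then retry
//		} else {
//			return err
//		}
//	}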

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

// GetEmbeddingsSeq returns the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if embeddings == nil {
		return nil
	}
	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if embeddings == nil {
		return nil
	}
	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}

// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
	logits := unsafe.Pointer(C.llama_get_logits(c.c))
	if logits == nil {
		return nil
	}
	// Get the number of vocabulary tokens to determine array size
	vocabSize := c.Model().NumVocab()
	return unsafe.Slice((*float32)(logits), vocabSize)
}

func (m *Model) Detokenize(tokens []int) (string, error) {
	var text string
	for _, token := range tokens {
		piece := m.TokenToPiece(token)
		if piece == "" {
			return "", fmt.Errorf("failed to convert token %d to piece", token)
		}
		text += piece
	}
	return text, nil
}

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	// free the C copy of the path once loading returns
	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))

	m := Model{c: C.llama_load_model_from_file(cPath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}

func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_new_context_with_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}
	return &c, nil
}
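
// An end-to-end setup sketch with hypothetical paths and values; error
// handling is elided for brevity.
//
//	BackendInit()
//	model, _ := LoadModelFromFile("/path/to/model.gguf", ModelParams{UseMmap: true})
//	defer FreeModel(model)
//	params := NewContextParams(2048, 512, 1, runtime.NumCPU(), false, "")
//	lc, _ := NewContextWithModel(model, params)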

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.c))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	if rc := int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale))); rc != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings (if
// embedSize is non-zero). Batches cannot contain both types at the same time.
// batchSize is the maximum number of entries that can be added per sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)
	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the
// type when the batch was initialized. The other argument will be ignored. It
// adds to the batch with the given position for the given sequence ids, and
// optionally instructs to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}

	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}
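
// A token-batch sketch: tokens, lc, and pos are hypothetical. Fill the batch,
// request logits only for the last prompt token, decode, then Clear for reuse.
//
//	batch, _ := NewBatch(512, 1, 0) // embedSize == 0: a token batch
//	defer batch.Free()
//	for i, tok := range tokens {
//		batch.Add(tok, nil, pos+i, i == len(tokens)-1, 0)
//	}
//	_ = lc.Decode(batch)
//	batch.Clear()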

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		// a negative return value is the required buffer size; retry with it
		tokenLen = -tokenLen
		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.c,
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.c,
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
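
// A round-trip sketch (model is a hypothetical *Model): Tokenize and
// Detokenize compose back to roughly the original text, though special-token
// handling may alter the exact bytes.
//
//	toks, _ := model.Tokenize("hello world", true, true)
//	text, _ := model.Detokenize(toks)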

func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}

func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype
	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}
	return nil
}
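
// A usage sketch with hypothetical paths. The ftype value is a llama.cpp
// llama_ftype enum constant; for example, 2 corresponds to
// LLAMA_FTYPE_MOSTLY_Q4_0 in current llama.cpp.
//
//	err := Quantize("model-f16.gguf", "model-q4_0.gguf", 2)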

// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}
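
// A sketch of feeding the returned rows into a decode batch: clipCtx, lc,
// imageBytes, and pos are hypothetical. Each row is one embedding position
// added to an embedding batch sized to the model's embedding width.
//
//	rows, _ := clipCtx.NewEmbed(lc, imageBytes)
//	batch, _ := NewBatch(len(rows), 1, lc.Model().NEmbd())
//	defer batch.Free()
//	for i, row := range rows {
//		batch.Add(0, row, pos+i, false, 0)
//	}
//	_ = lc.Decode(batch)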

type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	embed := make([][]float32, 1)
	embed[0] = rows
	return embed, nil
}

func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()
	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.penalize_nl = C.bool(params.PenalizeNl)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))
	cparams.grammar = grammar

	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
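
// A generation-loop sketch: model, lc, sc, batch, and pos are hypothetical.
// Sample from the last decoded position, Accept so penalty and grammar state
// advance, then feed the token back for the next decode step.
//
//	for {
//		tok := sc.Sample(lc, batch.NumTokens()-1)
//		sc.Accept(tok, true)
//		if model.TokenIsEog(tok) {
//			break
//		}
//		batch.Clear()
//		batch.Add(tok, nil, pos, true, 0)
//		pos++
//		_ = lc.Decode(batch)
//	}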

// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}
	return buf[:n]
}
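
// A usage sketch (model is a hypothetical *Model): a grammar produced here can
// be passed to sampling via SamplingParams.Grammar to constrain generation to
// the schema.
//
//	g := SchemaToGrammar([]byte(`{"type":"object","properties":{"name":{"type":"string"}}}`))
//	if g == nil {
//		// invalid schema
//	}
//	sc, _ := NewSamplingContext(model, SamplingParams{Temp: 0.8, Grammar: string(g)})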