package llama

//go:generate make -j 8

/*
#cgo CFLAGS: -O2 -std=c11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
#cgo CXXFLAGS: -O2 -std=c++11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
#cgo amd64,avx CFLAGS: -mavx
#cgo amd64,avx CXXFLAGS: -mavx
#cgo amd64,avx2 CFLAGS: -mavx2 -mfma
#cgo amd64,avx2 CXXFLAGS: -mavx2 -mfma
#cgo amd64,f16c CFLAGS: -mf16c
#cgo amd64,f16c CXXFLAGS: -mf16c
#cgo amd64,fma CFLAGS: -mfma
#cgo amd64,fma CXXFLAGS: -mfma
#cgo avx CFLAGS: -mavx
#cgo avx CXXFLAGS: -mavx
#cgo avx2 CFLAGS: -mavx2 -mfma -mf16c
#cgo avx2 CXXFLAGS: -mavx2 -mfma -mf16c
#cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda_v11 LDFLAGS: -lggml_cuda_v11 -L/usr/local/cuda-11/lib64
#cgo cuda_v12 LDFLAGS: -lggml_cuda_v12 -L/usr/local/cuda-12/lib64
#cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
#cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
#cgo darwin,amd64 LDFLAGS: -framework Foundation
#cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
#cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
#cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
#cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux CXXFLAGS: -D_GNU_SOURCE
#cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
#cgo linux,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA -D__ARM_FEATURE_MATMUL_INT8
#cgo linux,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA -D__ARM_FEATURE_MATMUL_INT8
#cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/Linux/arm64
#cgo linux,arm64,sve CFLAGS: -march=armv8.6-a+sve
#cgo linux,arm64,sve CXXFLAGS: -march=armv8.6-a+sve
#cgo linux,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt -lresolv
#cgo linux,rocm LDFLAGS: -L/opt/rocm/lib -lpthread -ldl -lrt -lresolv
#cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo rocm LDFLAGS: -L${SRCDIR} -lggml_rocm -lhipblas -lamdhip64 -lrocblas
#cgo windows CFLAGS: -Wno-discarded-qualifiers -D_WIN32_WINNT=0x602
#cgo windows CXXFLAGS: -D_WIN32_WINNT=0x602
#cgo windows LDFLAGS: -lmsvcrt -static-libstdc++ -static-libgcc -static
#cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/Windows/amd64
#cgo windows,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo windows,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/Windows/arm64
#cgo windows,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt
#cgo windows,rocm LDFLAGS: -lggml_rocm -lhipblas -lamdhip64 -lrocblas

#include <stdlib.h>
#include "llama.h"
#include "clip.h"
#include "ggml.h"
#include "llava.h"
#include "mllama.h"
#include "sampling_ext.h"

bool llamaProgressCallback(float progress, void *user_data);

typedef enum { COMP_UNKNOWN, COMP_GCC, COMP_CLANG } COMPILER;

COMPILER inline get_compiler() {
#if defined(__clang__)
	return COMP_CLANG;
#elif defined(__GNUC__)
	return COMP_GCC;
#else
	return COMP_UNKNOWN;
#endif
}
*/
import "C"

import (
	_ "embed"
	"errors"
	"fmt"
	"runtime"
	"runtime/cgo"
	"strings"
	"unsafe"
)

var CpuFeatures = ""

func BackendInit() {
	C.llama_backend_init()
}

func PrintSystemInfo() string {
	var compiler string
	switch C.get_compiler() {
	case C.COMP_UNKNOWN:
		compiler = "cgo(unknown_compiler)"
	case C.COMP_GCC:
		compiler = "cgo(gcc)"
	case C.COMP_CLANG:
		compiler = "cgo(clang)"
	}
	return C.GoString(C.llama_print_system_info()) + compiler
}

func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}

type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	return ContextParams{c: params}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not indicate a fatal error, only a warning:
	//   0   - success
	//   1   - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	//   < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}
	if code > 0 {
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
	}
	return nil
}
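
// Usage sketch (illustrative, not part of this package's API): callers
// typically treat the "no KV slot" case as recoverable and everything else
// as fatal. ctx and batch are assumed to have been set up elsewhere.
//
//	if err := ctx.Decode(batch); err != nil {
//		// e.g. retry with a smaller batch or a larger context, or fail the request
//		return err
//	}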

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
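
// Usage sketch (illustrative): reading embeddings after a Decode call.
// Assumes embeddings were enabled via NewContextParams and that logits were
// requested for the positions of interest.
//
//	embd := ctx.GetEmbeddingsSeq(0) // sequence-level embedding for sequence 0
//	if embd == nil {
//		// fall back to the per-token embedding of the last batch entry
//		embd = ctx.GetEmbeddingsIth(batch.NumTokens() - 1)
//	}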

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))

	m := Model{c: C.llama_load_model_from_file(cPath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}
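
// Usage sketch (illustrative; the path and layer count are placeholders):
// loading a model with mmap and a progress callback.
//
//	model, err := LoadModelFromFile("model.gguf", ModelParams{
//		NumGpuLayers: 32,
//		UseMmap:      true,
//		Progress: func(p float32) {
//			fmt.Printf("loading: %3.0f%%\n", p*100)
//		},
//	})
//	if err != nil {
//		return err
//	}
//	defer FreeModel(model)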

func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_new_context_with_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.c))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
	err := -1
	if loraAdapter != nil {
		err = int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale)))
	}
	if err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings
// (if embedSize is non-zero). Batches cannot contain both types at the same
// time. batchSize is the maximum number of entries that can be added per
// sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) *Batch {
	return &Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}
}
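
// Usage sketch (illustrative): filling a token batch from a tokenized
// prompt and decoding it, requesting logits only for the final position.
// model and ctx are assumed to have been created as above.
//
//	tokens, err := model.Tokenize(prompt, true, true)
//	if err != nil {
//		return err
//	}
//	batch := NewBatch(len(tokens), 1, 0) // embedSize == 0: a token batch
//	defer batch.Free()
//	for i, t := range tokens {
//		batch.Add(t, nil, i, i == len(tokens)-1, 0)
//	}
//	if err := ctx.Decode(batch); err != nil {
//		return err
//	}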

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch, depending on
// the type the batch was initialized with; the other argument is ignored.
// The entry is added at the given position for the given sequence ids, and
// optionally marked as requiring logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen
		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.c,
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.c,
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
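
// Usage sketch (illustrative): a Tokenize/TokenToPiece round trip, which is
// how generated token ids are turned back into text.
//
//	tokens, err := model.Tokenize("The quick brown fox", true, true)
//	if err != nil {
//		return err
//	}
//	var sb strings.Builder
//	for _, t := range tokens {
//		sb.WriteString(model.TokenToPiece(t))
//	}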

func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}

func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}
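
// Usage sketch (illustrative): ftype is a llama.cpp llama_ftype value; for
// example, 2 selects Q4_0 in current llama.h headers (verify against the
// vendored header before relying on the constant).
//
//	if err := Quantize("model-f16.gguf", "model-q4_0.gguf", 2); err != nil {
//		return err
//	}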

// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %s", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) [][]float32 {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed
}
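
// Usage sketch (illustrative): embedding an image with a CLIP projector and
// submitting the rows as an embedding batch. imageBytes holds an encoded
// image and "mmproj.gguf" is a placeholder path; ctx is an existing *Context.
//
//	clip, err := NewClipContext(ctx, "mmproj.gguf")
//	if err != nil {
//		return err
//	}
//	defer clip.Free()
//	embeds := clip.NewEmbed(ctx, imageBytes)
//	batch := NewBatch(len(embeds), 1, ctx.Model().NEmbd())
//	defer batch.Free()
//	for i, e := range embeds {
//		batch.Add(0, e, i, i == len(embeds)-1, 0)
//	}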

type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %s", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) [][]float32 {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)

	rows := make([]float32, m.EmbedSize(llamaContext))
	C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))

	embed := make([][]float32, 1)
	embed[0] = rows

	return embed
}

func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()

	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_gpt_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TfsZ           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) *SamplingContext {
	var cparams C.struct_gpt_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.tfs_z = C.float(params.TfsZ)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.penalize_nl = C.bool(params.PenalizeNl)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.gpt_sampler_cinit(model.c, &cparams)}
	runtime.SetFinalizer(context, func(s *SamplingContext) { C.gpt_sampler_cfree(s.c) })

	return context
}

func (s *SamplingContext) Reset() {
	C.gpt_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.gpt_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
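
// End-to-end sketch (illustrative): a minimal generation loop combining the
// pieces above. Assumes model, ctx, and a decoded prompt batch prepared as
// in the earlier sketches, with promptLen prompt tokens already consumed.
//
//	sc := NewSamplingContext(model, SamplingParams{Temp: 0.8, TopK: 40, TopP: 0.9, Seed: 42})
//	pos := promptLen
//	for {
//		token := sc.Sample(ctx, batch.NumTokens()-1)
//		sc.Accept(token, true)
//		if model.TokenIsEog(token) {
//			break
//		}
//		fmt.Print(model.TokenToPiece(token))
//		batch.Clear()
//		batch.Add(token, nil, pos, true, 0)
//		pos++
//		if err := ctx.Decode(batch); err != nil {
//			return err
//		}
//	}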