package llama

/*
#cgo CFLAGS: -O2 -std=c11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
#cgo CXXFLAGS: -O2 -std=c++11 -DGGML_BUILD=1 -DNDEBUG -DLOG_DISABLE_LOGS -DGGML_USE_LLAMAFILE
#cgo amd64,avx CFLAGS: -mavx
#cgo amd64,avx CXXFLAGS: -mavx
#cgo amd64,avx2 CFLAGS: -mavx2 -mfma
#cgo amd64,avx2 CXXFLAGS: -mavx2 -mfma
#cgo amd64,f16c CFLAGS: -mf16c
#cgo amd64,f16c CXXFLAGS: -mf16c
#cgo amd64,fma CFLAGS: -mfma
#cgo amd64,fma CXXFLAGS: -mfma
#cgo avx CFLAGS: -mavx
#cgo avx CXXFLAGS: -mavx
#cgo avx2 CFLAGS: -mavx2 -mfma -mf16c
#cgo avx2 CXXFLAGS: -mavx2 -mfma -mf16c
#cgo cuda CFLAGS: -fPIE -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo cuda_v11 LDFLAGS: -lggml_cuda_v11 -L/usr/local/cuda-11/lib64
#cgo cuda_v12 LDFLAGS: -lggml_cuda_v12 -L/usr/local/cuda-12/lib64
#cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
#cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
#cgo darwin,amd64 LDFLAGS: -framework Foundation
#cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
#cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
#cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DGGML_USE_BLAS
#cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux CXXFLAGS: -D_GNU_SOURCE
#cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/Linux/amd64
#cgo linux,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA -D__ARM_FEATURE_MATMUL_INT8
#cgo linux,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA -D__ARM_FEATURE_MATMUL_INT8
#cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/Linux/arm64
#cgo linux,arm64,sve CFLAGS: -march=armv8.6-a+sve
#cgo linux,arm64,sve CXXFLAGS: -march=armv8.6-a+sve
#cgo linux,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt -lresolv
#cgo linux,rocm LDFLAGS: -L/opt/rocm/lib -lpthread -ldl -lrt -lresolv
#cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
#cgo rocm LDFLAGS: -L${SRCDIR} -lggml_rocm -lhipblas -lamdhip64 -lrocblas
#cgo windows CFLAGS: -Wno-discarded-qualifiers -D_WIN32_WINNT=0x602
#cgo windows CXXFLAGS: -D_WIN32_WINNT=0x602
#cgo windows LDFLAGS: -lmsvcrt -static-libstdc++ -static-libgcc -static
#cgo windows,amd64 LDFLAGS: -L${SRCDIR}/build/Windows/amd64
#cgo windows,arm64 CFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo windows,arm64 CXXFLAGS: -D__aarch64__ -D__ARM_NEON -D__ARM_FEATURE_FMA
#cgo windows,arm64 LDFLAGS: -L${SRCDIR}/build/Windows/arm64
#cgo windows,cuda LDFLAGS: -lcuda -lcudart -lcublas -lcublasLt
#cgo windows,rocm LDFLAGS: -lggml_rocm -lhipblas -lamdhip64 -lrocblas
#include <stdlib.h>
#include "llama.h"
#include "clip.h"
#include "ggml.h"
#include "llava.h"
#include "mllama.h"
#include "sampling_ext.h"

bool llamaProgressCallback(float progress, void *user_data);
*/
import "C"

import (
	_ "embed"
	"errors"
	"fmt"
	"runtime"
	"runtime/cgo"
	"strings"
	"unsafe"
)

var CpuFeatures = ""

func BackendInit() {
	C.llama_backend_init()
}

func PrintSystemInfo() string {
	return C.GoString(C.llama_print_system_info())
}

type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	return ContextParams{c: params}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) Decode(batch *Batch) error {
	// A positive return value does not indicate a fatal error, only a warning:
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
	}

	return nil
}
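
// Usage sketch (illustrative only; "lc" is an initialized *Context and
// "batch" a filled *Batch): the only recoverable failure is the KV-slot case,
// which callers typically handle by shrinking the batch or growing the
// context.
//
//	if err := lc.Decode(batch); err != nil {
//		return fmt.Errorf("prompt processing: %w", err)
//	}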

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

// GetEmbeddingsSeq returns the pooled embeddings for a sequence id, or nil if
// none are available.
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}

// GetEmbeddingsIth returns the embeddings for the i-th token of the most
// recently decoded batch, or nil if none are available.
func (c *Context) GetEmbeddingsIth(i int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
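
// Usage sketch (illustrative only; assumes the context was created with
// embeddings enabled, as NewContextParams does, and that a batch for
// sequence 0 was just decoded):
//
//	if emb := lc.GetEmbeddingsSeq(0); emb != nil {
//		fmt.Println("embedding dims:", len(emb)) // == lc.Model().NEmbd()
//	}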

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		// the handle is pinned so the C side can safely dereference it for
		// the duration of the load
		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	m := Model{c: C.llama_load_model_from_file(mp, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}
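
// Usage sketch (illustrative only; the model path is hypothetical): the
// Progress callback crosses the cgo boundary via a pinned cgo.Handle, so it
// is only valid for the duration of this call.
//
//	model, err := LoadModelFromFile("/tmp/model.gguf", ModelParams{
//		NumGpuLayers: 99,
//		UseMmap:      true,
//		Progress:     func(p float32) { fmt.Printf("\rloading: %.0f%%", p*100) },
//	})
//	if err != nil {
//		return err
//	}
//	defer FreeModel(model)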

func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_new_context_with_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}
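
// End-to-end setup sketch (illustrative only; the path and parameter values
// are arbitrary):
//
//	BackendInit()
//	model, err := LoadModelFromFile(path, ModelParams{NumGpuLayers: 99})
//	if err != nil {
//		return err
//	}
//	defer FreeModel(model)
//
//	lc, err := NewContextWithModel(model, NewContextParams(4096, 512, 1, 8, false))
//	if err != nil {
//		return err
//	}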

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.c))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("error loading lora from file")
	}

	if rc := int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale))); rc != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	embedSize int
}

// NewBatch creates a new batch for either word tokens (if embed is 0) or
// image embeddings (if embed is non-zero). Batches cannot contain both types
// at the same time.
func NewBatch(nTokens int, embed int, maxSeq int) *Batch {
	return &Batch{
		c:         C.llama_batch_init(C.int(nTokens), C.int(embed), C.int(maxSeq)),
		batchSize: nTokens,
		embedSize: embed,
	}
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch, depending on
// the type the batch was initialized with; the other argument is ignored.
// The entry is added at the given position for the given sequence ids, and
// logits are optionally requested for it.
func (b *Batch) Add(token int, embed []float32, pos int, seqIds []int, logits bool) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.batchSize)[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.batchSize*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.batchSize)[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.batchSize)[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice(unsafe.Slice(b.c.seq_id, b.batchSize)[b.c.n_tokens], C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.batchSize)[b.c.n_tokens] = 1
	}

	b.c.n_tokens += 1
}
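
// Usage sketch (illustrative only): fill a token batch for sequence 0,
// requesting logits only for the final position. For an embedding batch,
// pass 0 for token and a row of length embedSize for embed instead.
//
//	batch := NewBatch(512, 0, 1)
//	defer batch.Free()
//	for i, t := range tokens {
//		batch.Add(t, nil, i, []int{0}, i == len(tokens)-1)
//	}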

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}
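
// TokenToPiece converts a single token id to the text fragment it encodes.
// It first tries a small fixed-size buffer; if the C side reports (via a
// negative return value) that the piece is longer, it retries with a buffer
// of the required size.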
func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.c,
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}

	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.c,
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
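
// Usage sketch (illustrative only): tokenize a prompt, honoring the model's
// BOS preference, then map tokens back to text for logging.
//
//	tokens, err := model.Tokenize("Why is the sky blue?", model.AddBOSToken(), true)
//	if err != nil {
//		return err
//	}
//	for _, t := range tokens {
//		fmt.Printf("%d => %q\n", t, model.TokenToPiece(t))
//	}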

func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}

func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = C.enum_llama_ftype(ftype)

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}
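
// Usage sketch (illustrative only; paths are hypothetical and ftype must be
// one of the llama_ftype enum values from llama.h, e.g. 2 for
// LLAMA_FTYPE_MOSTLY_Q4_0):
//
//	if err := Quantize("/tmp/model-f16.gguf", "/tmp/model-q4_0.gguf", 2); err != nil {
//		return err
//	}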

// llava

type ClipContext struct {
	c        *C.struct_clip_ctx
	m        *C.struct_mllama_ctx
	IsMllama bool
	embedPin runtime.Pinner
	pinned   bool
}

func getVisionArch(mp *C.char) (string, error) {
	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: nil})
	if gguf_ctx == nil {
		return "", errors.New("unable to load vision projector")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown vision model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}

func NewClipContext(modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	arch, err := getVisionArch(mp)
	if err != nil {
		return nil, err
	}

	var cc ClipContext
	if arch == "clip" {
		cc.c = C.clip_model_load(mp, 1)
	} else if arch == "mllama" {
		cc.m = C.mllama_model_load(mp, 1)
		cc.IsMllama = true
	} else {
		return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
	}

	// XXX: check embedding size?
	return &cc, nil
}

func (c *ClipContext) Free() {
	if c.c != nil {
		C.clip_free(c.c)
	}
	if c.m != nil {
		C.mllama_free(c.m)
	}
}

func NewLlavaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte) [][]float32 {
	c := C.llava_image_embed_make_with_bytes(clipContext.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))

	numTokens := int(c.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(c.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(c)

	return embed
}
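
// Usage sketch (illustrative only): encode an image with a clip projector and
// feed the rows to the model through an embedding batch.
//
//	embed := NewLlavaImageEmbed(lc, clipCtx, imageBytes)
//	batch := NewBatch(len(embed), lc.Model().NEmbd(), 1)
//	defer batch.Free()
//	for i, row := range embed {
//		batch.Add(0, row, pos+i, []int{0}, false)
//	}
//	if err := lc.Decode(batch); err != nil {
//		return err
//	}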

func NewMllamaImageEmbed(llamaContext *Context, clipContext *ClipContext, data []byte, aspectRatioId int) [][]float32 {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img)

	numTokens := int(C.mllama_n_positions(clipContext.m) * C.mllama_n_tiles(clipContext.m))
	numEmbed := llamaContext.Model().NEmbd()

	rows := make([]float32, numEmbed*numTokens)
	C.mllama_image_encode(clipContext.m, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0])))

	embed := make([][]float32, numTokens)
	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	return embed
}

// This really needs to be set on a batch instead
func MllamaSetCrossAttn(llamaContext *Context, clipContext *ClipContext, embed [][]float32) {
	if embed != nil {
		if clipContext.pinned {
			panic("Cross attention state already pinned")
		}

		embedData := &embed[0][0]
		clipContext.embedPin.Pin(embedData)
		clipContext.pinned = true

		C.llama_set_cross_attn_state(llamaContext.c, (*C.float)(unsafe.Pointer(embedData)))
	} else {
		C.llama_set_cross_attn_state(llamaContext.c, nil)

		if clipContext.pinned {
			clipContext.embedPin.Unpin()
			clipContext.pinned = false
		}
	}
}
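
// Usage sketch (illustrative only): the embedding rows stay pinned while the
// C side holds a pointer to them, so set the state before decoding and clear
// it (pass nil) once decoding for the image is finished.
//
//	embed := NewMllamaImageEmbed(lc, clipCtx, imageBytes, aspectRatioID)
//	MllamaSetCrossAttn(lc, clipCtx, embed)
//	// ... Decode calls that attend to the image ...
//	MllamaSetCrossAttn(lc, clipCtx, nil)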

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_gpt_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TfsZ           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) *SamplingContext {
	var cparams C.struct_gpt_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.tfs_z = C.float(params.TfsZ)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.penalize_nl = C.bool(params.PenalizeNl)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.gpt_sampler_cinit(model.c, &cparams)}
	runtime.SetFinalizer(context, func(s *SamplingContext) { C.gpt_sampler_cfree(s.c) })

	return context
}

func (s *SamplingContext) Reset() {
	C.gpt_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.gpt_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.gpt_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
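
// Generation loop sketch (illustrative only; assumes the prompt was already
// decoded with logits requested for its last token):
//
//	sc := NewSamplingContext(model, SamplingParams{Temp: 0.8, TopK: 40, Seed: 42})
//	batch := NewBatch(1, 0, 1)
//	defer batch.Free()
//	idx := promptLen - 1 // batch index of the last prompt token, whose logits were requested
//	for pos := promptLen; ; pos++ {
//		token := sc.Sample(lc, idx)
//		sc.Accept(token, true)
//		if model.TokenIsEog(token) {
//			break
//		}
//		fmt.Print(model.TokenToPiece(token))
//		batch.Clear()
//		batch.Add(token, nil, pos, []int{0}, true)
//		if err := lc.Decode(batch); err != nil {
//			return err
//		}
//		idx = 0 // single-token batches request logits at index 0
//	}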