package llama

/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/examples/llava
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "clip.h"
#include "llava.h"
#include "gguf.h"
#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);
*/
import "C"

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

func init() {
	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
	// slog levels: INFO is 0 and each level is a multiple of 4, so scale the
	// offset from GGML_LOG_LEVEL_INFO by 4 to get the equivalent slog level
	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
		fmt.Fprint(os.Stderr, C.GoString(text))
	}
}

func BackendInit() {
	ggml.OnceLoad()
	C.llama_backend_init()
}

func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}

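// A minimal usage sketch for GetModelArch (the path is illustrative):
//
//	arch, err := GetModelArch("/models/llama3.gguf")
//	if err != nil {
//		log.Fatal(err)
//	}
//	fmt.Println(arch) // e.g. "llama"
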
type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML
// type value. Empty or unrecognized strings fall back to F16.
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not indicate a fatal error, only a warning:
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}
	if code > 0 {
		return ErrKvCacheFull
	}
	return nil
}

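// A minimal decode-loop sketch (all identifiers illustrative): tokens are
// appended to a batch with their positions and sequence id, the last token
// requests logits, and ErrKvCacheFull signals that the caller should shrink
// the batch or free cache space before retrying:
//
//	batch, err := NewBatch(512, 1, 0)
//	if err != nil {
//		return err
//	}
//	defer batch.Free()
//	for i, t := range tokens {
//		batch.Add(t, nil, i, i == len(tokens)-1, 0)
//	}
//	if err := ctx.Decode(batch); err != nil {
//		if errors.Is(err, ErrKvCacheFull) {
//			// defragment or evict sequences, then retry
//		}
//		return err
//	}
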
func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

// GetEmbeddingsSeq gets the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

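// Illustrative read-back sketch: after decoding with embeddings enabled,
// a pooled per-sequence embedding is tried first, falling back to the
// embedding at a specific batch index (identifiers hypothetical):
//
//	embd := ctx.GetEmbeddingsSeq(0)
//	if embd == nil {
//		embd = ctx.GetEmbeddingsIth(0)
//	}
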
type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	m := Model{c: C.llama_model_load_from_file(cPath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}

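// A minimal load sketch (paths and values illustrative):
//
//	model, err := LoadModelFromFile("/models/llama3.gguf", ModelParams{
//		NumGpuLayers: 99,
//		UseMmap:      true,
//		Progress: func(p float32) {
//			fmt.Printf("\rloading: %3.0f%%", p*100)
//		},
//	})
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer FreeModel(model)
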
func LoadVocabFromFile(path string) (*Vocab, error) {
	mp := C.CString(path)
	defer C.free(unsafe.Pointer(mp))
	v := Vocab{c: C.llama_load_vocab_from_file(mp)}
	if v.c == nil {
		return nil, fmt.Errorf("unable to load vocab: %s", path)
	}
	return &v, nil
}

func FreeVocab(vocab *Vocab) {
	C.llama_free_vocab(vocab.c)
}

func FreeModel(model *Model) {
	C.llama_model_free(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_init_from_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}
	return &c, nil
}

func (m *Model) NumVocab() int {
	return int(C.llama_vocab_n_tokens(m.Vocab()))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	if err := int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale))); err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Vocab struct {
	c *C.struct_llama_vocab
}

func (m *Model) Vocab() *C.struct_llama_vocab {
	return C.llama_model_get_vocab(m.c)
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings
// (if embedSize is non-zero). Batches cannot contain both types at the same
// time. batchSize is the maximum number of entries that can be added per
// sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)
	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.Vocab(),
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		// a negative return value is the required buffer size; reallocate and retry
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.Vocab(),
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.Vocab(),
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.Vocab(),
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}

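// A short tokenize-and-detokenize round trip (input string illustrative):
//
//	toks, err := model.Tokenize("Hello, world!", true, true)
//	if err != nil {
//		return err
//	}
//	var sb strings.Builder
//	for _, t := range toks {
//		sb.WriteString(model.TokenToPiece(t))
//	}
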
func (m *Model) NEmbd() int {
	return int(C.llama_model_n_embd(m.c))
}

func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}

// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}

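// Feeding image embeddings into a batch, sketched (identifiers and positions
// illustrative): each row returned by NewEmbed is one image token and goes
// into an embedding-mode batch created with a non-zero embedSize. The token
// argument to Add is ignored for embedding batches:
//
//	rows, err := clipCtx.NewEmbed(ctx, imageBytes)
//	if err != nil {
//		return err
//	}
//	batch, _ := NewBatch(len(rows), 1, ctx.Model().NEmbd())
//	defer batch.Free()
//	for i, row := range rows {
//		batch.Add(0, row, pos+i, false, 0)
//	}
//	err = ctx.Decode(batch)
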
type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	embed := make([][]float32, 1)
	embed[0] = rows
	return embed, nil
}

func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()
	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

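// A minimal generation-loop sketch (identifiers and parameter values
// illustrative): sample the next token from the logits of the last decoded
// position (idx -1, per the llama.cpp convention), accept it into the
// sampler state, and stop on an end-of-generation token:
//
//	sc, err := NewSamplingContext(model, SamplingParams{Temp: 0.8, TopK: 40, Seed: 42})
//	if err != nil {
//		return err
//	}
//	for {
//		tok := sc.Sample(ctx, -1)
//		sc.Accept(tok, true)
//		if model.TokenIsEog(tok) {
//			break
//		}
//		// append tok to the next batch and call ctx.Decode(...)
//	}
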
func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}

	return buf[:n]
}

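// Usage sketch (schema illustrative): a JSON schema constrained to an object
// with a single string field becomes a grammar, or nil on invalid input:
//
//	g := SchemaToGrammar([]byte(`{"type":"object","properties":{"name":{"type":"string"}}}`))
//	if g == nil {
//		return errors.New("invalid JSON schema")
//	}
//	sampler := NewGrammarSampler(vocab, string(g))
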
type Sampler struct {
	c *C.struct_llama_sampler
}

func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
	cGrammar := C.CString(grammar)
	cRoot := C.CString("root")
	defer C.free(unsafe.Pointer(cGrammar))
	defer C.free(unsafe.Pointer(cRoot))

	sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}

	return sampler
}

func (s *Sampler) Accept(token int32) {
	C.llama_sampler_accept(s.c, C.llama_token(token))
}

type TokenData struct {
	Id    int32
	Logit float32
}

// Apply runs the sampler over the candidate tokens, writing the updated
// logits back into the slice.
func (s *Sampler) Apply(tokens []TokenData) {
	tds := make([]C.struct_llama_token_data, len(tokens))
	for i, token := range tokens {
		tds[i] = C.struct_llama_token_data{
			id:    C.int32_t(token.Id),
			logit: C.float(token.Logit),
			p:     C.float(0.0),
		}
	}
	tda := &C.llama_token_data_array{
		data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
		size:     C.size_t(len(tokens)),
		selected: C.int64_t(-1),
		sorted:   C.bool(false),
	}

	var pinner runtime.Pinner
	pinner.Pin(&tds[0])
	defer pinner.Unpin()

	C.llama_sampler_apply(s.c, tda)
	for i := range tokens {
		tokens[i].Logit = float32(tds[i].logit)
	}
}

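// End-to-end grammar-constrained sampling sketch (identifiers illustrative;
// pickHighestLogit is a hypothetical argmax helper returning an int32):
// build TokenData from the model's logits, let the grammar sampler rewrite
// the logits of invalid continuations, then accept the chosen token so the
// grammar state advances:
//
//	cands := make([]TokenData, model.NumVocab())
//	for i := range cands {
//		cands[i] = TokenData{Id: int32(i), Logit: logits[i]}
//	}
//	sampler.Apply(cands)
//	tok := pickHighestLogit(cands) // hypothetical helper
//	sampler.Accept(tok)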