package llama

/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/examples/llava
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "clip.h"
#include "llava.h"
#include "gguf.h"
#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

typedef enum { COMP_UNKNOWN, COMP_GCC, COMP_CLANG } COMPILER;
COMPILER inline get_compiler() {
#if defined(__clang__)
	return COMP_CLANG;
#elif defined(__GNUC__)
	return COMP_GCC;
#else
	return COMP_UNKNOWN;
#endif
}
*/
import "C"

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

func init() {
	C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
	// slog levels are multiples of 4, with INFO at 0, so map the GGML level onto that scale
	if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
		fmt.Fprint(os.Stderr, C.GoString(text))
	}
}
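
// BackendInit initializes the GGML backends and the llama.cpp backend. It
// should be called once before models or contexts are created.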
func BackendInit() {
	ggml.OnceLoad()
	C.llama_backend_init()
}

func PrintSystemInfo() string {
	var compiler string
	switch C.get_compiler() {
	case C.COMP_UNKNOWN:
		compiler = "cgo(unknown_compiler)"
	case C.COMP_GCC:
		compiler = "cgo(gcc)"
	case C.COMP_CLANG:
		compiler = "cgo(clang)"
	}
	return C.GoString(C.llama_print_system_info()) + compiler
}
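
// GetModelArch returns the value of the general.architecture key from the
// GGUF header of the file at modelPath, without loading any tensor data.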
func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}

type ContextParams struct {
	c C.struct_llama_context_params
}
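
// NewContextParams builds llama_context_params from the given context length,
// batch size, maximum sequence count, thread count, flash attention flag, and
// KV cache quantization type ("q8_0", "q4_0", or anything else for f16).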
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")
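
// Decode processes a batch with llama_decode. It returns ErrKvCacheFull when
// no KV cache slot is available and a regular error for fatal failures.
//
// Typical handling (sketch; lctx and batch are placeholders):
//
//	if err := lctx.Decode(batch); errors.Is(err, ErrKvCacheFull) {
//		// shrink the batch or increase the context, then retry
//	} else if err != nil {
//		return err
//	}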
func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not indicate a fatal error, only a warning:
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}
	if code > 0 {
		return ErrKvCacheFull
	}
	return nil
}

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}
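
// LoadModelFromFile loads a GGUF model from modelPath. If params.Progress is
// set, it is called with load progress values as the model loads.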
func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}

func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_new_context_with_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.Vocab()))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_adapter_lora_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	err := -1
	if loraAdapter != nil {
		err = int(C.llama_set_adapter_lora(context.c, loraAdapter, C.float(scale)))
	}
	if err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

func (m *Model) Vocab() *C.struct_llama_vocab {
	return C.llama_model_get_vocab(m.c)
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// Creates a new batch for either word tokens or image embeddings (if embedSize is non-zero).
// Batches cannot contain both types at the same time. batchSize is the maximum number of entries
// that can be added per sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
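//
// Typical fill loop (sketch; batch, tokens, pos, and seq are placeholders):
//
//	for i, tok := range tokens {
//		batch.Add(tok, nil, pos+i, i == len(tokens)-1, seq)
//	}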
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}

func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.Vocab(),
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.Vocab(),
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}
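
// Tokenize converts text to token ids. If the initial estimate of
// len(text)+2 tokens is too small, the buffer is grown to the size reported
// by llama_tokenize and the call is retried once.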
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.Vocab(),
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.Vocab(),
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}

func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}
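
// Quantize re-quantizes the model at infile to the quantization type ftype
// and writes the result to outfile.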
func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}

// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}

type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	embed := make([][]float32, 1)
	embed[0] = rows

	return embed, nil
}

func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()

	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}
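
// NewSamplingContext creates a common_sampler for the model with the given
// parameters. A typical generation step (sketch; sampler, lctx, and idx are
// placeholders for the sampling context, llama context, and the logits index
// within the batch):
//
//	id := sampler.Sample(lctx, idx)
//	sampler.Accept(id, true)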
func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
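//
// Example (sketch):
//
//	g := SchemaToGrammar([]byte(`{"type": "object"}`))
//	if g == nil {
//		// invalid schema
//	}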
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}
	return buf[:n]
}