llama.go

package llama

//go:generate make -j 8

/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++11
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include
#cgo darwin,arm64 CPPFLAGS: -DGGML_USE_METAL

#include <stdlib.h>
#include "llama.h"
#include "clip.h"
#include "ggml.h"
#include "llava.h"
#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

typedef enum { COMP_UNKNOWN, COMP_GCC, COMP_CLANG } COMPILER;

COMPILER inline get_compiler() {
#if defined(__clang__)
	return COMP_CLANG;
#elif defined(__GNUC__)
	return COMP_GCC;
#else
	// was UNKNOWN_COMPILER, an undeclared identifier that would not compile
	return COMP_UNKNOWN;
#endif
}
*/
import "C"

import (
	"bytes"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"sync/atomic"
	"unsafe"

	_ "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
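
// BackendInit initializes the llama.cpp backend. Call it once before loading
// models or creating contexts.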
func BackendInit() {
	C.llama_backend_init()
}
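
// PrintSystemInfo reports llama.cpp's system and feature information along
// with the compiler used to build the cgo bindings.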
func PrintSystemInfo() string {
	var compiler string
	switch C.get_compiler() {
	case C.COMP_UNKNOWN:
		compiler = "cgo(unknown_compiler)"
	case C.COMP_GCC:
		compiler = "cgo(gcc)"
	case C.COMP_CLANG:
		compiler = "cgo(clang)"
	}
	return C.GoString(C.llama_print_system_info()) + compiler
}

var logLevel atomic.Int32

func init() {
	logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
	C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
}
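
// EnableDebug lowers the log threshold so debug-level messages from llama.cpp
// are printed as well.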
func EnableDebug() {
	logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
}

//export llamaLog
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
	if level < logLevel.Load() {
		return
	}

	fmt.Print(C.GoString(text))
}
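
// GetModelArch reads the general.architecture key from a GGUF file's metadata
// without allocating tensor data.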
func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}

type ContextParams struct {
	c C.struct_llama_context_params
}
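
// NewContextParams builds llama context parameters from the given limits.
// kvCacheType selects the KV cache quantization: "q8_0" or "q4_0"; anything
// else falls back to f16.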
func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}

type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not mean a fatal error, but rather a warning:
	//   0 - success
	//   1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
	// < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return ErrKvCacheFull
	}

	return nil
}

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

// GetEmbeddingsSeq gets the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
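
// GetEmbeddingsIth gets the embeddings for the i-th token of the most recent
// batch, or nil if none are available at that position.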
func (c *Context) GetEmbeddingsIth(i int) []float32 {
	embeddings := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if embeddings == nil {
		return nil
	}

	return unsafe.Slice((*float32)(embeddings), c.Model().NEmbd())
}
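
// ModelParams controls model loading: GPU offload, memory mapping and locking,
// the per-GPU tensor split, and an optional load-progress callback.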
type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}
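
// LoadModelFromFile loads a GGUF model from disk; the returned Model must be
// released with FreeModel. A minimal usage sketch (the path is illustrative):
//
//	model, err := LoadModelFromFile("/path/to/model.gguf", ModelParams{NumGpuLayers: 99})
//	if err != nil {
//		return err
//	}
//	defer FreeModel(model)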
func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	// free the C path string after loading (the original passed C.CString
	// inline and leaked it)
	cpath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cpath))

	m := Model{c: C.llama_load_model_from_file(cpath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}

func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_new_context_with_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.c))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	// the adapter is known to be non-nil here, so apply it directly (the
	// original re-checked nil after already returning on that case)
	if err := int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale))); err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}

type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings (if
// embedSize is non-zero). Batches cannot contain both types at the same time.
// batchSize is the maximum number of entries that can be added per sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}
	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	// seq_id is an array of per-entry pointers; write each sequence id into
	// the inner array for the current entry
	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}

type Model struct {
	c *C.struct_llama_model
}
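
// TokenToPiece converts a token id to its text piece. It first decodes into a
// small fixed buffer; if llama_token_to_piece reports (as a negative length)
// that the piece is longer, it retries with a buffer of the required size.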
func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.c,
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}
	return strings.TrimRight(string(buf), "\x00")
}
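
// Tokenize converts text into token ids. addSpecial controls whether the
// model's special tokens (e.g. BOS) are added; parseSpecial controls whether
// special tokens appearing in the text are parsed rather than treated as
// plain text.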
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.c,
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}

func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}
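
// Quantize re-quantizes the model at infile to the requested ftype, writing
// the result to outfile. nthread of -1 is assumed to let llama.cpp choose the
// thread count.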
func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}

// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}
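
// NewClipContext loads a CLIP vision projector from modelPath and verifies
// that its embedding size matches the language model's.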
func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}
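
// NewEmbed encodes raw image bytes into per-position embeddings, returned as
// one slice of length NEmbd per image position.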
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}

type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}
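
// NewEmbed encodes image data into a single embedding row sized by EmbedSize.
// The hard-coded 560, 560, 3, 4 arguments are assumed to describe the tile
// width and height, channel count, and maximum tile count expected by
// mllama_image_load_from_data.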
func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	embed := make([][]float32, 1)
	embed[0] = rows

	return embed, nil
}

func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()

	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}

// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}
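
// NewSamplingContext creates a sampler from the given parameters. The
// underlying C sampler is freed automatically by a finalizer when the
// returned SamplingContext is garbage collected.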
func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	// was params.PenaltyFreq, a copy-paste error that dropped the presence penalty
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.penalize_nl = C.bool(params.PenalizeNl)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}

type JsonSchema struct {
	Defs       map[string]any `json:"$defs,omitempty"`
	Properties map[string]any `json:"properties,omitempty"`
	Required   []string       `json:"required,omitempty"`
	Title      string         `json:"title,omitempty"`
	Type       string         `json:"type,omitempty"`
}
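
// AsGrammar converts the JSON schema to a GBNF grammar string for constrained
// sampling, returning "" on failure. A usage sketch:
//
//	js := JsonSchema{
//		Type:       "object",
//		Properties: map[string]any{"name": map[string]any{"type": "string"}},
//		Required:   []string{"name"},
//	}
//	grammar := js.AsGrammar()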
func (js JsonSchema) AsGrammar() string {
	var b bytes.Buffer
	if err := json.NewEncoder(&b).Encode(js); err != nil {
		return ""
	}

	cStr := C.CString(b.String())
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	length := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if length == 0 {
		slog.Warn("unable to convert schema to grammar")
	}

	return string(buf[:length])
}