llama.go

package llama

/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/examples/llava
#cgo CPPFLAGS: -I${SRCDIR}/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}/../ml/backend/ggml/ggml/include

#include <stdlib.h>
#include "ggml.h"
#include "llama.h"
#include "clip.h"
#include "llava.h"
#include "mllama.h"
#include "sampling_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

typedef enum { COMP_UNKNOWN, COMP_GCC, COMP_CLANG } COMPILER;

COMPILER inline get_compiler() {
#if defined(__clang__)
	return COMP_CLANG;
#elif defined(__GNUC__)
	return COMP_GCC;
#else
	return COMP_UNKNOWN;
#endif
}
*/
import "C"

import (
	_ "embed"
	"errors"
	"fmt"
	"os"
	"runtime"
	"runtime/cgo"
	"slices"
	"strings"
	"sync/atomic"
	"unsafe"

	_ "github.com/ollama/ollama/llama/llama.cpp/common"
	_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
	_ "github.com/ollama/ollama/llama/llama.cpp/src"
	ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
// BackendInit initializes the GGML and llama.cpp backends.
func BackendInit() {
	ggml.OnceLoad()
	C.llama_backend_init()
}

// PrintSystemInfo returns llama.cpp's system information string along with
// the compiler used to build the cgo bindings.
func PrintSystemInfo() string {
	var compiler string
	switch C.get_compiler() {
	case C.COMP_UNKNOWN:
		compiler = "cgo(unknown_compiler)"
	case C.COMP_GCC:
		compiler = "cgo(gcc)"
	case C.COMP_CLANG:
		compiler = "cgo(clang)"
	}
	return C.GoString(C.llama_print_system_info()) + compiler
}
var logLevel atomic.Int32

func init() {
	logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
	C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
}

// EnableDebug enables debug-level log output.
func EnableDebug() {
	logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
}

//export llamaLog
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
	if level < logLevel.Load() {
		return
	}

	fmt.Fprint(os.Stderr, C.GoString(text))
}
// GetModelArch reads the "general.architecture" key from a GGUF model file.
func GetModelArch(modelPath string) (string, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))

	gguf_ctx := C.gguf_init_from_file(mp, C.struct_gguf_init_params{no_alloc: true, ctx: (**C.struct_ggml_context)(C.NULL)})
	if gguf_ctx == nil {
		return "", errors.New("unable to load model file")
	}
	defer C.gguf_free(gguf_ctx)

	key := C.CString("general.architecture")
	defer C.free(unsafe.Pointer(key))

	arch_index := C.gguf_find_key(gguf_ctx, key)
	if int(arch_index) < 0 {
		return "", errors.New("unknown model architecture")
	}

	arch := C.gguf_get_val_str(gguf_ctx, arch_index)
	return C.GoString(arch), nil
}
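
// Illustrative sketch (assumed, not part of the original file): reading a
// model's architecture before deciding how to load it. The path is a placeholder.
func exampleGetModelArch() {
	arch, err := GetModelArch("/path/to/model.gguf")
	if err != nil {
		fmt.Fprintln(os.Stderr, "could not read architecture:", err)
		return
	}
	fmt.Println("model architecture:", arch)
}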
type ContextParams struct {
	c C.struct_llama_context_params
}

func NewContextParams(numCtx int, batchSize int, numSeqMax int, threads int, flashAttention bool, kvCacheType string) ContextParams {
	params := C.llama_context_default_params()
	params.n_ctx = C.uint(numCtx)
	params.n_batch = C.uint(batchSize)
	params.n_seq_max = C.uint(numSeqMax)
	params.n_threads = C.int(threads)
	params.n_threads_batch = params.n_threads
	params.embeddings = C.bool(true)
	params.flash_attn = C.bool(flashAttention)
	params.type_k = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	params.type_v = kvCacheTypeFromStr(strings.ToLower(kvCacheType))
	return ContextParams{c: params}
}

// kvCacheTypeFromStr converts a string cache type to the corresponding GGML type value.
// Empty or unrecognized strings fall back to F16.
func kvCacheTypeFromStr(s string) C.enum_ggml_type {
	if s == "" {
		return C.GGML_TYPE_F16
	}

	switch s {
	case "q8_0":
		return C.GGML_TYPE_Q8_0
	case "q4_0":
		return C.GGML_TYPE_Q4_0
	default:
		return C.GGML_TYPE_F16
	}
}
type Context struct {
	c          *C.struct_llama_context
	numThreads int
}

var ErrKvCacheFull = errors.New("could not find a kv cache slot")

func (c *Context) Decode(batch *Batch) error {
	// Positive return values do not indicate a fatal error, only a warning:
	//   0   - success
	//   1   - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	//   < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return ErrKvCacheFull
	}

	return nil
}
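
// Illustrative sketch (assumed, not part of the original file): ErrKvCacheFull
// is a recoverable condition for callers; one possible strategy, shown here, is
// to request a KV cache defragmentation and retry the decode once before giving up.
func exampleDecodeWithRetry(lc *Context, batch *Batch) error {
	err := lc.Decode(batch)
	if errors.Is(err, ErrKvCacheFull) {
		lc.KvCacheDefrag()
		err = lc.Decode(batch)
	}
	return err
}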
func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

func (c *Context) KvCacheSeqAdd(seqId int, p0 int, p1 int, delta int) {
	C.llama_kv_cache_seq_add(c.c, C.int(seqId), C.int(p0), C.int(p1), C.int(delta))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func (c *Context) KvCacheSeqCp(srcSeqId int, dstSeqId int, p0 int, p1 int) {
	C.llama_kv_cache_seq_cp(c.c, C.int(srcSeqId), C.int(dstSeqId), C.int(p0), C.int(p1))
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) KvCacheDefrag() {
	C.llama_kv_cache_defrag(c.c)
}

// GetEmbeddingsSeq returns the embeddings for a sequence id.
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

func (c *Context) GetEmbeddingsIth(i int) []float32 {
	e := unsafe.Pointer(C.llama_get_embeddings_ith(c.c, C.int32_t(i)))
	if e == nil {
		return nil
	}

	embeddings := make([]float32, c.Model().NEmbd())
	_ = copy(embeddings, unsafe.Slice((*float32)(e), c.Model().NEmbd()))
	return embeddings
}

// GetLogits returns the logits from the last decode operation.
// The returned slice has length equal to the vocabulary size.
func (c *Context) GetLogits() []float32 {
	logits := unsafe.Pointer(C.llama_get_logits(c.c))
	if logits == nil {
		return nil
	}

	// Get the number of vocabulary tokens to determine array size
	vocabSize := c.Model().NumVocab()
	return unsafe.Slice((*float32)(logits), vocabSize)
}
type ModelParams struct {
	NumGpuLayers int
	MainGpu      int
	UseMmap      bool
	UseMlock     bool
	TensorSplit  []float32
	Progress     func(float32)
	VocabOnly    bool
}

//export llamaProgressCallback
func llamaProgressCallback(progress C.float, userData unsafe.Pointer) C.bool {
	handle := *(*cgo.Handle)(userData)
	callback := handle.Value().(func(float32))
	callback(float32(progress))
	return true
}

func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
	cparams := C.llama_model_default_params()
	cparams.n_gpu_layers = C.int(params.NumGpuLayers)
	cparams.main_gpu = C.int32_t(params.MainGpu)
	cparams.use_mmap = C.bool(params.UseMmap)
	cparams.use_mlock = C.bool(params.UseMlock)
	cparams.vocab_only = C.bool(params.VocabOnly)

	if len(params.TensorSplit) > 0 {
		tensorSplitData := &params.TensorSplit[0]

		var tensorSplitPin runtime.Pinner
		tensorSplitPin.Pin(tensorSplitData)
		defer tensorSplitPin.Unpin()

		cparams.tensor_split = (*C.float)(unsafe.Pointer(tensorSplitData))
	}

	if params.Progress != nil {
		handle := cgo.NewHandle(params.Progress)
		defer handle.Delete()

		var handlePin runtime.Pinner
		handlePin.Pin(&handle)
		defer handlePin.Unpin()

		cparams.progress_callback = C.llama_progress_callback(C.llamaProgressCallback)
		cparams.progress_callback_user_data = unsafe.Pointer(&handle)
	}

	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))

	m := Model{c: C.llama_load_model_from_file(cPath, cparams)}
	if m.c == nil {
		return nil, fmt.Errorf("unable to load model: %s", modelPath)
	}

	return &m, nil
}
func FreeModel(model *Model) {
	C.llama_free_model(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
	c := Context{
		c:          C.llama_new_context_with_model(model.c, params.c),
		numThreads: int(params.c.n_threads),
	}
	if c.c == nil {
		return nil, errors.New("unable to create llama context")
	}

	return &c, nil
}
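
// Illustrative sketch (assumed, not part of the original file): loading a model
// and creating a context with it. The path and parameter values are placeholders.
func exampleLoadModelAndContext() (*Model, *Context, error) {
	model, err := LoadModelFromFile("/path/to/model.gguf", ModelParams{
		NumGpuLayers: 0,
		UseMmap:      true,
		Progress:     func(p float32) { fmt.Fprintf(os.Stderr, "loading: %.0f%%\r", p*100) },
	})
	if err != nil {
		return nil, nil, err
	}

	// 2048-token context, batch size 512, one sequence, four threads,
	// no flash attention, default (f16) KV cache type.
	ctxParams := NewContextParams(2048, 512, 1, 4, false, "")
	lc, err := NewContextWithModel(model, ctxParams)
	if err != nil {
		FreeModel(model)
		return nil, nil, err
	}
	return model, lc, nil
}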
func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
	return bool(C.llama_add_bos_token(m.c))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
	cLoraPath := C.CString(loraPath)
	defer C.free(unsafe.Pointer(cLoraPath))

	loraAdapter := C.llama_lora_adapter_init(m.c, cLoraPath)
	if loraAdapter == nil {
		return errors.New("unable to load lora")
	}

	if err := int(C.llama_lora_adapter_set(context.c, loraAdapter, C.float(scale))); err != 0 {
		return errors.New("error applying lora from file")
	}

	return nil
}
type Batch struct {
	c         C.struct_llama_batch
	batchSize int
	maxSeq    int
	embedSize int
}

// NewBatch creates a new batch for either word tokens or image embeddings (if embedSize is
// non-zero). Batches cannot contain both types at the same time. batchSize is the maximum
// number of entries that can be added per sequence.
func NewBatch(batchSize int, maxSeq int, embedSize int) (*Batch, error) {
	b := Batch{
		c:         C.llama_batch_init(C.int(batchSize*maxSeq), C.int(embedSize), C.int(maxSeq)),
		batchSize: batchSize,
		maxSeq:    maxSeq,
		embedSize: embedSize,
	}

	// Check to see if any of the allocations in llama_batch_init() failed
	nilPointer := (embedSize == 0 && b.c.token == nil) || (embedSize != 0 && b.c.embd == nil) ||
		b.c.pos == nil || b.c.n_seq_id == nil || b.c.seq_id == nil || b.c.logits == nil ||
		slices.Contains(unsafe.Slice(b.c.seq_id, b.allocSize()), nil)

	if nilPointer {
		C.llama_batch_free(b.c)
		return nil, fmt.Errorf("unable to allocate batch (batchSize=%v maxSeq=%v embedSize=%v)", batchSize, maxSeq, embedSize)
	}

	return &b, nil
}

func (b *Batch) Size() int {
	return b.batchSize
}

func (b *Batch) allocSize() int {
	return b.batchSize * b.maxSeq
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

func (b *Batch) IsEmbedding() bool {
	return b.embedSize != 0
}

// Add adds either a token or an image embedding to the batch depending on the type
// when the batch was initialized. The other argument will be ignored. Adds to the
// batch with the given position for the given sequence ids, and optionally instructs
// to include logits.
func (b *Batch) Add(token int, embed []float32, pos int, logits bool, seqIds ...int) {
	if !b.IsEmbedding() {
		unsafe.Slice(b.c.token, b.allocSize())[b.c.n_tokens] = C.llama_token(token)
	} else {
		copy(unsafe.Slice((*float32)(b.c.embd), b.allocSize()*b.embedSize)[int(b.c.n_tokens)*b.embedSize:], embed)
	}

	unsafe.Slice(b.c.pos, b.allocSize())[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, b.allocSize())[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, b.allocSize())[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 1
	} else {
		unsafe.Slice(b.c.logits, b.allocSize())[b.c.n_tokens] = 0
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	b.batchSize = 0
	C.llama_batch_free(b.c)
}
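
// Illustrative sketch (assumed, not part of the original file): filling a token
// batch and decoding it. Only the last token requests logits, which is the usual
// setup before sampling the next token.
func exampleDecodeTokens(lc *Context, batch *Batch, tokens []int, startPos int, seqId int) error {
	batch.Clear()
	for i, t := range tokens {
		last := i == len(tokens)-1
		batch.Add(t, nil, startPos+i, last, seqId)
	}
	return lc.Decode(batch)
}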
type Model struct {
	c *C.struct_llama_model
}

// TokenToPiece converts a single token id into its text piece. If the initial
// 12-byte buffer is too small, the call is retried with the size reported by
// llama_token_to_piece.
func (m *Model) TokenToPiece(token int) string {
	tokenLen := 12
	buf := make([]byte, tokenLen)
	tokenLen = int(C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(tokenLen),
		C.int32_t(0),
		C.bool(true),
	))
	if tokenLen < 0 {
		tokenLen = -tokenLen

		buf = make([]byte, tokenLen)
		C.llama_token_to_piece(
			m.c,
			C.int32_t(token),
			(*C.char)(unsafe.Pointer(&buf[0])),
			C.int32_t(tokenLen),
			C.int32_t(0),
			C.bool(true),
		)
	}

	return strings.TrimRight(string(buf), "\x00")
}
func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int, error) {
	maxTokens := len(text) + 2
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	// if the result is negative, reallocate and retry with the correct buffer size
	if result < 0 {
		maxTokens = int(-result)
		cTokens = make([]C.llama_token, maxTokens)
		result = C.llama_tokenize(
			m.c,
			cText,
			C.int32_t(len(text)),
			&cTokens[0],
			C.int32_t(maxTokens),
			C.bool(addSpecial),
			C.bool(parseSpecial),
		)
		if result < 0 {
			return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
		}
	}

	tokens := make([]int, result)
	for i := range result {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}
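
// Illustrative sketch (assumed, not part of the original file): tokenizing a
// prompt and converting the tokens back to text pieces with TokenToPiece.
func exampleTokenizeRoundTrip(model *Model, prompt string) ([]int, string, error) {
	tokens, err := model.Tokenize(prompt, true, true)
	if err != nil {
		return nil, "", err
	}

	var sb strings.Builder
	for _, t := range tokens {
		sb.WriteString(model.TokenToPiece(t))
	}
	return tokens, sb.String(), nil
}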
func (m *Model) NEmbd() int {
	return int(C.llama_n_embd(m.c))
}

// Quantize re-quantizes the model at infile and writes the result to outfile
// using the given ftype.
func Quantize(infile, outfile string, ftype uint32) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}
// vision processing
type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(llamaContext *Context, modelPath string) (*ClipContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.clip_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load clip model: %v", modelPath)
	}

	projEmbedSize := int(C.clip_n_mmproj_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &ClipContext{c: c}, nil
}

func (c *ClipContext) Free() {
	C.clip_free(c.c)
}

// NewEmbed encodes raw image bytes into one embedding vector per image token.
func (c *ClipContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) {
	l := C.llava_image_embed_make_with_bytes(c.c, C.int(llamaContext.numThreads), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))
	if l == nil {
		return nil, errors.New("unable to make llava embedding from image")
	}

	numTokens := int(l.n_image_pos)
	numEmbed := llamaContext.Model().NEmbd()

	s := unsafe.Slice((*float32)(l.embed), numEmbed*numTokens)

	embed := make([][]float32, numTokens)
	rows := make([]float32, len(s))
	copy(rows, s)

	for i := range embed {
		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}

	C.llava_image_embed_free(l)

	return embed, nil
}
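
// Illustrative sketch (assumed, not part of the original file): turning image
// bytes into CLIP embeddings and decoding them through an embedding batch.
// Positions continue from pos; the sequence id is a placeholder.
func exampleAddImage(lc *Context, clip *ClipContext, imageData []byte, pos int, seqId int) (int, error) {
	embeds, err := clip.NewEmbed(lc, imageData)
	if err != nil {
		return pos, err
	}

	batch, err := NewBatch(len(embeds), 1, lc.Model().NEmbd())
	if err != nil {
		return pos, err
	}
	defer batch.Free()

	for _, e := range embeds {
		batch.Add(0, e, pos, false, seqId)
		pos++
	}
	return pos, lc.Decode(batch)
}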
type MllamaContext struct {
	c *C.struct_mllama_ctx
}

func NewMllamaContext(llamaContext *Context, modelPath string) (*MllamaContext, error) {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	c := C.mllama_model_load(mp, 1)
	if c == nil {
		return nil, fmt.Errorf("unable to load mllama model: %v", modelPath)
	}

	projEmbedSize := int(C.mllama_n_embd(c))
	modelEmbedSize := llamaContext.Model().NEmbd()
	if projEmbedSize != modelEmbedSize {
		return nil, fmt.Errorf("projector embedding size (%d) does not match model (%d)", projEmbedSize, modelEmbedSize)
	}

	return &MllamaContext{c: c}, nil
}

func (m *MllamaContext) Free() {
	C.mllama_free(m.c)
}

func (m *MllamaContext) NewEmbed(llamaContext *Context, data []byte, aspectRatioId int) ([][]float32, error) {
	img := C.mllama_image_init()
	defer C.mllama_image_free(img)

	ok := bool(C.mllama_image_load_from_data(unsafe.Pointer(&data[0]), C.int(len(data)), 560, 560, 3, 4, C.int(aspectRatioId), img))
	if !ok {
		return nil, errors.New("unable to load mllama image data")
	}

	rows := make([]float32, m.EmbedSize(llamaContext))
	ok = bool(C.mllama_image_encode(m.c, C.int(llamaContext.numThreads), img, (*C.float)(unsafe.Pointer(&rows[0]))))
	if !ok {
		return nil, errors.New("unable to make mllama embedding from image")
	}

	embed := make([][]float32, 1)
	embed[0] = rows
	return embed, nil
}

// EmbedSize returns the total number of floats in an mllama image embedding
// (positions * tiles * model embedding size).
func (m *MllamaContext) EmbedSize(llamaContext *Context) int {
	numTokens := int(C.mllama_n_positions(m.c) * C.mllama_n_tiles(m.c))
	numEmbed := llamaContext.Model().NEmbd()

	return numTokens * numEmbed
}

func (c *Context) SetCrossAttention(state bool) {
	C.llama_set_cross_attention(c.c, C.bool(state))
}

func (c *Context) Synchronize() {
	C.llama_synchronize(c.c)
}
// sampling
// TODO: this is a temporary wrapper to allow calling C++ code from CGo
type SamplingContext struct {
	c *C.struct_common_sampler
}

type SamplingParams struct {
	TopK           int
	TopP           float32
	MinP           float32
	TypicalP       float32
	Temp           float32
	RepeatLastN    int
	PenaltyRepeat  float32
	PenaltyFreq    float32
	PenaltyPresent float32
	Mirostat       int
	MirostatTau    float32
	MirostatEta    float32
	PenalizeNl     bool
	Seed           uint32
	Grammar        string
}

func NewSamplingContext(model *Model, params SamplingParams) (*SamplingContext, error) {
	var cparams C.struct_common_sampler_cparams
	cparams.top_k = C.int32_t(params.TopK)
	cparams.top_p = C.float(params.TopP)
	cparams.min_p = C.float(params.MinP)
	cparams.typical_p = C.float(params.TypicalP)
	cparams.temp = C.float(params.Temp)
	cparams.penalty_last_n = C.int32_t(params.RepeatLastN)
	cparams.penalty_repeat = C.float(params.PenaltyRepeat)
	cparams.penalty_freq = C.float(params.PenaltyFreq)
	cparams.penalty_present = C.float(params.PenaltyPresent)
	cparams.mirostat = C.int32_t(params.Mirostat)
	cparams.mirostat_tau = C.float(params.MirostatTau)
	cparams.mirostat_eta = C.float(params.MirostatEta)
	cparams.seed = C.uint32_t(params.Seed)

	grammar := C.CString(params.Grammar)
	defer C.free(unsafe.Pointer(grammar))

	cparams.grammar = grammar
	context := &SamplingContext{c: C.common_sampler_cinit(model.c, &cparams)}
	if context.c == nil {
		return nil, errors.New("unable to create sampling context")
	}

	runtime.SetFinalizer(context, func(s *SamplingContext) { C.common_sampler_cfree(s.c) })

	return context, nil
}

func (s *SamplingContext) Reset() {
	C.common_sampler_creset(s.c)
}

func (s *SamplingContext) Sample(llamaContext *Context, idx int) int {
	return int(C.common_sampler_csample(s.c, llamaContext.c, C.int(idx)))
}

func (s *SamplingContext) Accept(id int, applyGrammar bool) {
	C.common_sampler_caccept(s.c, C.llama_token(id), C.bool(applyGrammar))
}
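
// Illustrative sketch (assumed, not part of the original file): a minimal
// generation loop that decodes the prompt, then samples one token at a time
// from the last logits until an end-of-generation token or maxTokens is reached.
func exampleGenerate(lc *Context, sc *SamplingContext, promptTokens []int, maxTokens int) (string, error) {
	model := lc.Model()

	batch, err := NewBatch(len(promptTokens)+1, 1, 0)
	if err != nil {
		return "", err
	}
	defer batch.Free()

	// Decode the prompt; only the last entry requests logits.
	batch.Clear()
	for i, t := range promptTokens {
		batch.Add(t, nil, i, i == len(promptTokens)-1, 0)
	}
	if err := lc.Decode(batch); err != nil {
		return "", err
	}

	var out strings.Builder
	pos := len(promptTokens)
	for i := 0; i < maxTokens; i++ {
		token := sc.Sample(lc, batch.NumTokens()-1)
		sc.Accept(token, true)
		if model.TokenIsEog(token) {
			break
		}
		out.WriteString(model.TokenToPiece(token))

		batch.Clear()
		batch.Add(token, nil, pos, true, 0)
		pos++
		if err := lc.Decode(batch); err != nil {
			return out.String(), err
		}
	}
	return out.String(), nil
}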
// SchemaToGrammar converts the provided JSON schema to a grammar. It returns
// nil if the provided schema is invalid JSON or an invalid JSON schema.
func SchemaToGrammar(schema []byte) []byte {
	cStr := C.CString(string(schema))
	defer C.free(unsafe.Pointer(cStr))

	// Allocate buffer for grammar output with reasonable size
	const maxLen = 32768 // 32KB
	buf := make([]byte, maxLen)

	// Call C function to convert schema to grammar
	n := C.schema_to_grammar(cStr, (*C.char)(unsafe.Pointer(&buf[0])), C.size_t(maxLen))
	if n == 0 {
		// preserve nil
		return nil
	}

	return buf[:n]
}
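
// Illustrative sketch (assumed, not part of the original file): constraining
// sampling to a JSON schema by converting it to a grammar and passing that
// grammar to the sampler. The schema literal and parameter values are placeholders.
func exampleSchemaConstrainedSampler(model *Model) (*SamplingContext, error) {
	schema := []byte(`{"type": "object", "properties": {"answer": {"type": "string"}}}`)

	grammar := SchemaToGrammar(schema)
	if grammar == nil {
		return nil, errors.New("invalid JSON schema")
	}

	return NewSamplingContext(model, SamplingParams{
		Temp:    0.7,
		TopK:    40,
		TopP:    0.9,
		Seed:    1234,
		Grammar: string(grammar),
	})
}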