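// Package llama provides cgo bindings to llama.cpp, plus the clip and llava
// helpers used for image inputs. The #cgo lines below that begin with a tag
// such as avx, avx2, cuda, or rocm are custom build tags: those flags are only
// applied when the binary is built with the matching -tags value.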
package llama

// #cgo CFLAGS: -std=c11 -DNDEBUG -DLOG_DISABLE_LOGS
// #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
// #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
// #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
// #cgo linux CFLAGS: -D_GNU_SOURCE
// #cgo linux CXXFLAGS: -D_GNU_SOURCE
// #cgo windows LDFLAGS: -lmsvcrt
// #cgo avx CFLAGS: -mavx
// #cgo avx CXXFLAGS: -mavx
// #cgo avx2 CFLAGS: -mavx2 -mfma
// #cgo avx2 CXXFLAGS: -mavx2 -mfma
// #cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm LDFLAGS: -L${SRCDIR} -lggml-hipblas -lhipblas -lamdhip64 -lrocblas
// #cgo windows,cuda LDFLAGS: -L. -L"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.3/lib/x64" -lggml-cuda -lcuda -lcudart -lcublas -lcublasLt
// #cgo windows,rocm LDFLAGS: -L. -L"C:/Program Files/AMD/ROCm/5.7/lib"
// #cgo linux,cuda LDFLAGS: -L${SRCDIR} -L/usr/local/cuda/lib64 -lggml-cuda -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt
// #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib
// #include <stdlib.h>
// #include "llama.h"
// #include "clip.h"
// #include "llava.h"
import "C"

import (
	"fmt"
	"runtime"
	"strings"
	"unsafe"

	"github.com/ollama/ollama/llm"
)

func BackendInit() {
	C.llama_backend_init()
}

func PrintSystemInfo() string {
	return C.GoString(C.llama_print_system_info())
}

type ContextParams struct {
	c C.struct_llama_context_params
}

// NewContextParams returns context parameters with a fixed seed, a 2048-token
// context window, and one thread per CPU.
func NewContextParams() ContextParams {
	params := C.llama_context_default_params()
	params.seed = C.uint(1234)
	params.n_ctx = C.uint(2048)
	params.n_threads = C.uint(runtime.NumCPU())
	params.n_threads_batch = params.n_threads
	return ContextParams{c: params}
}

type ModelParams struct {
	c C.struct_llama_model_params
}

// NewModelParams returns model parameters that request offloading all layers
// to the GPU.
func NewModelParams() ModelParams {
	params := C.llama_model_default_params()
	params.n_gpu_layers = 999
	return ModelParams{c: params}
}

type Context struct {
	c *C.struct_llama_context
}

func (c *Context) KvCacheClear() {
	C.llama_kv_cache_clear(c.c)
}

func (c *Context) Decode(batch Batch) error {
	// Positive return values do not mean a fatal error, but rather a warning:
	//   0   - success
	//   1   - could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)
	//   < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increasing the context. code: %d", code)
	}

	return nil
}
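
// A minimal end-to-end sketch using the wrappers in this file (error handling
// and stop conditions omitted; the prompt, model, and ctx values are assumed
// to come from the caller):
//
//	tokens, _ := model.Tokenize(prompt, 2048, true, true)
//	batch := NewBatch(512, 0, 1)
//	for i, t := range tokens {
//		batch.Add(t, i, []int{0}, i == len(tokens)-1)
//	}
//	_ = ctx.Decode(batch)
//	next := ctx.SampleTokenGreedy(batch, batch.NumTokens()-1)
//	piece := model.TokenToPiece(next)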

func (c *Context) Model() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}

// SampleTokenGreedy returns the id of the highest-logit token from the logits
// at batch position i.
//
// TODO(jmorganca): split this up into different functions
func (c *Context) SampleTokenGreedy(batch Batch, i int) int {
	nv := c.Model().NumVocab()

	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
	defer C.free(unsafe.Pointer(candidates))

	// get most recent logits
	logits := C.llama_get_logits_ith(c.c, C.int(i))
	for j := 0; j < int(nv); j++ {
		ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(j)*unsafe.Sizeof(C.struct_llama_token_data{})))
		ptr.id = C.int(j)
		ptr.logit = unsafe.Slice(logits, nv)[j]
		ptr.p = 0.0
	}

	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
		data:   candidates,
		size:   C.size_t(nv),
		sorted: C.bool(false),
	}))
}

func (c *Context) KvCacheSeqRm(seqId int, p0 int, p1 int) bool {
	return bool(C.llama_kv_cache_seq_rm(c.c, C.int(seqId), C.int(p0), C.int(p1)))
}

func LoadModelFromFile(modelPath string, params ModelParams) *Model {
	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	return &Model{c: C.llama_load_model_from_file(cPath, params.c)}
}

func NewContextWithModel(model *Model, params ContextParams) *Context {
	return &Context{c: C.llama_new_context_with_model(model.c, params.c)}
}

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token int) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}

type Batch struct {
	c C.struct_llama_batch
}

// NewBatch wraps llama_batch_init: nTokens is the maximum number of tokens the
// batch can hold, embd is the embedding size (0 for token batches), and maxSeq
// is the maximum number of sequence ids per token.
func NewBatch(nTokens int, embd int, maxSeq int) Batch {
	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(embd), C.int(maxSeq))}
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}

// Add appends a token to the batch. The 512 used as the slice bound below is
// an assumed upper limit on batch capacity; callers must not add more tokens
// than the batch was initialized to hold.
func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}

func (b *Batch) Free() {
	C.llama_batch_free(b.c)
}

// LLAMA_API struct llama_batch llama_batch_get_one(
//
//	llama_token * tokens,
//	int32_t       n_tokens,
//	llama_pos     pos_0,
//	llama_seq_id  seq_id);
//
// The token ids are copied into a C-allocated buffer (never freed here) because
// Go's int does not match llama_token (int32) and the batch keeps a pointer to it.
func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
	cTokens := (*C.llama_token)(C.malloc(C.size_t(len(tokens)) * C.size_t(unsafe.Sizeof(C.llama_token(0)))))
	s := unsafe.Slice(cTokens, len(tokens))
	for i, t := range tokens {
		s[i] = C.llama_token(t)
	}
	return Batch{c: C.llama_batch_get_one(cTokens, C.int32_t(len(tokens)), C.llama_pos(pos0), C.llama_seq_id(seqId))}
}

type Model struct {
	c *C.struct_llama_model
}

// TokenToPiece returns the text for a single token; the piece may be a partial
// UTF-8 sequence for byte-fallback tokens, so callers should accumulate bytes.
func (m *Model) TokenToPiece(token int) string {
	buf := make([]byte, 12)
	C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(len(buf)),
		C.bool(true),
	)
	return strings.TrimRight(string(buf), "\x00")
}

func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	if result < 0 {
		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
	}

	tokens := make([]int, result)
	for i := 0; i < int(result); i++ {
		tokens[i] = int(cTokens[i])
	}

	return tokens, nil
}

func Quantize(infile, outfile string, ftype llm.FileType) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype.Value()

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}

type ClipContext struct {
	c *C.struct_clip_ctx
}

func NewClipContext(modelPath string) *ClipContext {
	mp := C.CString(modelPath)
	defer C.free(unsafe.Pointer(mp))
	cc := C.clip_model_load(mp, 1)
	return &ClipContext{c: cc}
}

type LlavaContext struct {
	c *C.struct_llava_context
}

type LlavaImageEmbed struct {
	c *C.struct_llava_image_embed
}

func NewLlavaImageEmbed(clipContext *ClipContext, data []byte) *LlavaImageEmbed {
	return &LlavaImageEmbed{c: C.llava_image_embed_make_with_bytes(clipContext.c, C.int(runtime.NumCPU()), (*C.uchar)(unsafe.Pointer(&data[0])), C.int(len(data)))}
}

// LlavaEvalImageEmbed evaluates an image embedding into the context and
// advances *nPast past the image tokens. A local C.int is used because Go's
// int and C's int may differ in size.
func LlavaEvalImageEmbed(llamaContext *Context, embed *LlavaImageEmbed, nBatch int, nPast *int) {
	cNPast := C.int(*nPast)
	C.llava_eval_image_embed(llamaContext.c, embed.c, C.int(nBatch), &cNPast)
	*nPast = int(cNPast)
}
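
// Sketch of the multimodal path using the helpers above (the mmproj path and
// image bytes are placeholders supplied by the caller):
//
//	clipCtx := NewClipContext("/path/to/mmproj.gguf")
//	embed := NewLlavaImageEmbed(clipCtx, imageBytes)
//	nPast := 0
//	LlavaEvalImageEmbed(ctx, embed, 512, &nPast)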