llama.go

package llama

// #cgo darwin,arm64 CFLAGS: -std=c11 -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 CXXFLAGS: -std=c++11 -DGGML_USE_METAL -DGGML_METAL_EMBED_LIBRARY -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
// #cgo darwin,arm64 LDFLAGS: -ld_classic ${SRCDIR}/ggml-metal.o -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
// #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 CXXFLAGS: -std=c++11 -Wno-incompatible-pointer-types-discards-qualifiers
// #cgo darwin,amd64 LDFLAGS: -ld_classic -framework Foundation -framework Accelerate
// #cgo windows LDFLAGS: -lmsvcrt
// #cgo avx CFLAGS: -mavx
// #cgo avx CXXFLAGS: -mavx
// #cgo avx2 CFLAGS: -mavx2 -mfma
// #cgo avx2 CXXFLAGS: -mavx2 -mfma
// #cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_MULTIPLATFORM -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo cuda CXXFLAGS: -std=c++11 -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_MULTIPLATFORM -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo rocm CXXFLAGS: -std=c++11 -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_MULTIPLATFORM -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
// #cgo windows,cuda LDFLAGS: -L. -L"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.3/lib/x64" -lggml-cuda -lcuda -lcudart -lcublas -lcublasLt
// #cgo windows,rocm LDFLAGS: -L. -L"C:/Program Files/AMD/ROCm/5.7/lib" -lggml-hipblas -lhipblas -lamdhip64 -lrocblas
// #include <stdlib.h>
// #include "llama.h"
import "C"

import (
	"fmt"
	"runtime"
	"strings"
	"unsafe"

	"github.com/ollama/ollama/llm"
)
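
// Token, Pos and SeqId mirror llama.cpp's llama_token, llama_pos and
// llama_seq_id integer types.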
type Token int32
type Pos int32
type SeqId int32

// PrintSystemInfo is an unused example of calling llama.cpp functions using CGo
func PrintSystemInfo() string {
	return C.GoString(C.llama_print_system_info())
}
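
// BackendInit initializes the llama.cpp backend; call it once before loading
// models or creating contexts.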
func BackendInit() {
	C.llama_backend_init()
}
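
// ContextParams wraps llama_context_params.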
type ContextParams struct {
	c C.struct_llama_context_params
}
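
// NewContextParams returns default context parameters with a fixed seed, a
// 2048-token context window, and one thread per CPU core.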
func NewContextParams() ContextParams {
	params := C.llama_context_default_params()
	params.seed = C.uint(1234)
	params.n_ctx = C.uint(2048)
	params.n_threads = C.uint(runtime.NumCPU())
	params.n_threads_batch = params.n_threads
	return ContextParams{c: params}
}
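
// ModelParams wraps llama_model_params.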
type ModelParams struct {
	c C.struct_llama_model_params
}
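
// NewModelParams returns default model parameters with n_gpu_layers set high
// enough to offload every layer to the GPU.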
func NewModelParams() ModelParams {
	params := C.llama_model_default_params()
	params.n_gpu_layers = 999
	return ModelParams{c: params}
}
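
// Context wraps a llama_context, the state used for decoding and sampling.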
type Context struct {
	c *C.struct_llama_context
}
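
// Decode runs llama_decode on the batch and converts its return code into an
// error.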
func (c *Context) Decode(batch Batch) error {
	// A positive return value is a warning, not a fatal error:
	//   0   - success
	//   1   - could not find a KV slot for the batch (try reducing the batch size or increasing the context)
	//   < 0 - error
	code := int(C.llama_decode(c.c, batch.c))
	if code < 0 {
		return fmt.Errorf("llama_decode failed with code %d", code)
	}

	if code > 0 {
		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increasing the context (code %d)", code)
	}

	return nil
}

func (c *Context) getModel() *Model {
	return &Model{c: C.llama_get_model(c.c)}
}
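
// SampleTokenGreedy builds a candidate list from the logits of the batch's
// last token and returns the token with the highest logit.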
func (c *Context) SampleTokenGreedy(batch Batch) Token {
	nv := c.getModel().NumVocab()

	// TODO(jmorganca): split this up into different functions
	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
	defer C.free(unsafe.Pointer(candidates))

	// get most recent logits
	logits := C.llama_get_logits_ith(c.c, C.int(batch.NumTokens()-1))
	for i := 0; i < int(nv); i++ {
		ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
		ptr.id = C.int(i)
		ptr.logit = unsafe.Slice(logits, nv)[i]
		ptr.p = 0.0
	}

	return Token(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
		data:   candidates,
		size:   C.size_t(nv),
		sorted: C.bool(false),
	}))
}
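
// LoadModelFromFile loads a model from the given path using the supplied
// parameters.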
func LoadModelFromFile(modelPath string, params ModelParams) *Model {
	cPath := C.CString(modelPath)
	defer C.free(unsafe.Pointer(cPath))
	return &Model{c: C.llama_load_model_from_file(cPath, params.c)}
}
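
// NewContextWithModel creates a new inference context for the model.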
func NewContextWithModel(model *Model, params ContextParams) *Context {
	return &Context{c: C.llama_new_context_with_model(model.c, params.c)}
}

func (m *Model) NumVocab() int {
	return int(C.llama_n_vocab(m.c))
}

func (m *Model) TokenIsEog(token Token) bool {
	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
}
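
// Batch wraps llama_batch, the set of tokens submitted to Decode.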
type Batch struct {
	c C.struct_llama_batch
}
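
// NewBatch allocates a new batch via llama_batch_init.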
func NewBatch(nTokens int, nSeqs int, nCtx int) Batch {
	return Batch{c: C.llama_batch_init(C.int(nTokens), C.int(nSeqs), C.int(nCtx))}
}

func (b *Batch) NumTokens() int {
	return int(b.c.n_tokens)
}
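
// Add appends a token at the given position for the given sequence ids,
// optionally requesting logits for it. The unsafe.Slice views assume the
// batch was allocated with capacity for at least 512 entries.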
func (b *Batch) Add(token Token, pos Pos, seqIds []SeqId, logits bool) {
	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))

	for i, s := range seqIds {
		unsafe.Slice((unsafe.Slice(b.c.seq_id, 512)[b.c.n_tokens]), C.int(len(seqIds)))[i] = C.int32_t(s)
	}

	if logits {
		unsafe.Slice(b.c.logits, 512)[b.c.n_tokens] = 1
	}

	b.c.n_tokens += 1
}

func (b *Batch) Clear() {
	b.c.n_tokens = 0
}
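
// Model wraps a llama_model.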
type Model struct {
	c *C.struct_llama_model
}
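
// TokenToPiece converts a single token to its text piece. It uses a fixed
// 12-byte buffer, so longer pieces are not returned in full.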
func (m *Model) TokenToPiece(token Token) string {
	buf := make([]byte, 12)
	C.llama_token_to_piece(
		m.c,
		C.int32_t(token),
		(*C.char)(unsafe.Pointer(&buf[0])),
		C.int32_t(12),
		C.bool(true),
	)
	return strings.TrimRight(string(buf), "\x00")
}
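
// Tokenize converts text into at most maxTokens tokens. If maxTokens is too
// small, it returns an error reporting how many tokens are required.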
func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]Token, error) {
	cTokens := make([]C.llama_token, maxTokens)
	cText := C.CString(text)
	defer C.free(unsafe.Pointer(cText))

	result := C.llama_tokenize(
		m.c,
		cText,
		C.int32_t(len(text)),
		&cTokens[0],
		C.int32_t(maxTokens),
		C.bool(addSpecial),
		C.bool(parseSpecial),
	)

	if result < 0 {
		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
	}

	tokens := make([]Token, result)
	for i := 0; i < int(result); i++ {
		tokens[i] = Token(cTokens[i])
	}

	return tokens, nil
}
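
// Quantize quantizes the model at infile to the requested file type and
// writes the result to outfile. nthread = -1 lets llama.cpp pick the thread
// count.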
func Quantize(infile, outfile string, ftype llm.FileType) error {
	cinfile := C.CString(infile)
	defer C.free(unsafe.Pointer(cinfile))

	coutfile := C.CString(outfile)
	defer C.free(unsafe.Pointer(coutfile))

	params := C.llama_model_quantize_default_params()
	params.nthread = -1
	params.ftype = ftype.Value()

	if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
		return fmt.Errorf("llama_model_quantize: %d", rc)
	}

	return nil
}