llm.go

package llm

import (
	"context"
	"fmt"
	"log"
	"os"
	"runtime"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/gpu"
)
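
// LLM is the interface implemented by a loaded model runner: prediction via
// a streaming callback, embedding generation, and tokenizer encode/decode.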
type LLM interface {
	Predict(context.Context, PredictOpts, func(PredictResult)) error
	Embedding(context.Context, string) ([]float64, error)
	Encode(context.Context, string) ([]int, error)
	Decode(context.Context, []int) (string, error)
	Close()
}
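
// New opens and decodes the GGML model at the given path, estimates its
// memory requirements, chooses between GPU and CPU execution (and how many
// layers to offload), and starts an LLM server with the resulting options.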
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, err := DecodeGGML(f)
	if err != nil {
		return nil, err
	}

	if opts.NumCtx < 4 {
		opts.NumCtx = 4
	}

	vram, _ := gpu.CheckVRAM()
	size := ggml.Size

	// fp16 k,v matrices require n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each, times 2 for key and value
	kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())

	// this amount is the overhead + tensors in memory
	// TODO: get this from llama.cpp's graph calculations instead of
	// estimating it as 1/6 * kv_cache_size * num_gqa
	graph := int64(ggml.NumGQA()) * kv / 6
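	// Illustrative example (assumed figures, not computed by this code): a
	// Llama-2-7B-style model (n_layer=32, n_embd=4096, n_head=n_head_kv=32)
	// at a 2048-token context gives kv = 2*2*2048*32*4096 bytes ≈ 1 GiB, and
	// with num_gqa = 1 the graph estimate is roughly kv/6 ≈ 170 MiB.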

	info := gpu.GetGPUInfo()
	switch runtime.GOOS {
	case "darwin":
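		// macOS offload is all-or-nothing: either the entire model
		// (weights + kv cache + compute graph) fits in available memory
		// and runs on the GPU, or we fall back to the CPU.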
		if opts.NumGPU == 0 {
			break
		}
		if size+kv+graph > vram {
			log.Println("not enough vram available, falling back to CPU only")
			info.Library = "cpu"
			info.Variant = gpu.GetCPUVariant()
			opts.NumGPU = 0
			break
		}
		opts.NumGPU = 1
	default:
		if info.Library == "cpu" {
			log.Println("GPU not available, falling back to CPU")
			opts.NumGPU = 0
			break
		}

		// don't use the GPU at all if no layers are to be loaded
		if opts.NumGPU == 0 {
			info.Library = "cpu"
			info.Variant = gpu.GetCPUVariant()
			break
		}

		// respect a user-defined GPU layer count (-1 means auto-detect)
		if opts.NumGPU != -1 {
			break
		}

		// the "main" GPU needs the most memory and determines the limit
		// of how many layers can be loaded. It needs to fit:
		//   1. the full compute graph allocation for all devices (graph)
		//   2. the proportional kv cache for all devices (kv * % layers)
		//   3. the proportional model (size * % layers / # devices)
		// This estimates the number of layers
		maxlayers := int64(ggml.NumLayers()) + 1
		devices := int64(info.DeviceCount)
		avg := vram / devices
		layers := maxlayers * (avg - graph) / (kv + size/devices)
		if layers > maxlayers {
			layers = maxlayers
		}
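		// Illustrative example (assumed numbers, not computed here): a
		// ~3.8 GiB q4_0 7B model with 33 layers, kv ≈ 1 GiB and
		// graph ≈ 0.17 GiB on a single 8 GiB GPU gives
		// layers = 33*(8-0.17)/(1+3.8) ≈ 53, clamped to 33, so every
		// layer is offloaded.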

		// 1 + 2 must fit on the main gpu
		min := graph + kv*layers/maxlayers
		if layers <= 0 || min > avg {
			log.Printf("not enough vram available, falling back to CPU only")
			info.Library = "cpu"
			info.Variant = gpu.GetCPUVariant()
			opts.NumGPU = 0
			break
		}

		opts.NumGPU = int(layers)
	}
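
	// A value of 0 makes the backend fall back to the rope frequency
	// settings stored in the model file (llama.cpp convention).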
	opts.RopeFrequencyBase = 0.0
	opts.RopeFrequencyScale = 0.0
	return newLlmServer(info, model, adapters, projectors, opts)
}

// Init gives any native cgo implementations an opportunity to initialize
func Init(workdir string) error {
	return nativeInit(workdir)
}
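
// newLlmServer picks the dynamic llm libraries for the detected GPU (or an
// explicit OLLAMA_LLM_LIBRARY override) and starts the external server,
// trying each candidate library until one loads.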
func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	dynLibs := getDynLibs(gpuInfo)

	// Check to see if the user has requested a specific library instead of auto-detecting
	demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
	if demandLib != "" {
		libPath := availableDynLibs[demandLib]
		if libPath == "" {
			log.Printf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib)
		} else {
			log.Printf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib)
			dynLibs = []string{libPath}
		}
	}
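
	// Try each candidate library in order; the first one that loads wins.
	// If none load, return the last error seen.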
	err2 := fmt.Errorf("unable to locate suitable llm library")
	for _, dynLib := range dynLibs {
		srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
		if err == nil {
			return srv, nil
		}
		log.Printf("Failed to load dynamic library %s %s", dynLib, err)
		err2 = err
	}
	return nil, err2
}