llm.go

package llm

import (
	"context"
	"fmt"
	"log"
	"os"
	"runtime"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/gpu"
)
type LLM interface {
	// Predict runs generation with the given options, streaming results through the callback.
	Predict(context.Context, PredictOpts, func(PredictResult)) error
	// Embedding returns the embedding vector for the given prompt.
	Embedding(context.Context, string) ([]float64, error)
	// Encode tokenizes a string; Decode converts tokens back to text.
	Encode(context.Context, string) ([]int, error)
	Decode(context.Context, []int) (string, error)
	// Close releases the resources held by the underlying runner.
	Close()
}
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, err := DecodeGGML(f)
	if err != nil {
		return nil, err
	}

	if opts.NumCtx < 4 {
		opts.NumCtx = 4
	}

	vram, _ := gpu.CheckVRAM()
	size := ggml.Size

	// fp16 k,v matrices require n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each, * 2 for key and value
	kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())

	// this amount is the overhead + tensors in memory
	// TODO: get this from llama.cpp's graph calculations instead of
	// estimating it as 1/6 * kv_cache_size * num_gqa
	graph := int64(ggml.NumGQA()) * kv / 6
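
	// As a rough illustration (hypothetical 7B-class model with n_ctx=2048,
	// n_layer=32, n_embd=4096, n_head=n_head_kv=32, i.e. num_gqa=1), the
	// estimates above work out to roughly:
	//   kv    = 2 * 2 * 2048 * 32 * 4096 * 32/32 = 1 GiB
	//   graph = 1 * 1 GiB / 6                    ≈ 171 MiB
	// The actual values are read from the metadata of the model being loaded.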
	info := gpu.GetGPUInfo()
	library := info.Library
	switch runtime.GOOS {
	case "darwin":
		if opts.NumGPU == 0 {
			break
		}

		if size+kv+graph > vram {
			log.Println("not enough vram available, falling back to CPU only")
			opts.NumGPU = 0
			break
		}

		opts.NumGPU = 1
	default:
		if library == "cpu" || library == "default" {
			log.Println("GPU not available, falling back to CPU")
			opts.NumGPU = 0
			break
		}

		// don't use GPU at all if no layers are loaded
		if opts.NumGPU == 0 {
			library = "cpu"
			break
		}

		// user-defined GPU count
		if opts.NumGPU != -1 {
			break
		}

		// the "main" GPU needs the most memory and determines the limit
		// of how many layers can be loaded. It needs to fit:
		// 1. the full compute graph allocation for all devices (graph)
		// 2. the proportional kv cache for all devices (kv * % layers)
		// 3. the proportional model (size * % layers / # devices)
		// This estimates the number of layers
		maxlayers := int64(ggml.NumLayers()) + 1
		devices := int64(info.DeviceCount)
		avg := vram / devices
		layers := maxlayers * (avg - graph) / (kv + size/devices)
		if layers > maxlayers {
			layers = maxlayers
		}

		// 1 + 2 must fit on the main gpu
		min := graph + kv*layers/maxlayers
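
		// Worked example, continuing the hypothetical 7B model above (kv = 1 GiB,
		// graph ≈ 0.17 GiB) with a ~3.8 GiB quantized weights file on a single
		// 8 GiB GPU: layers = 33 * (8 - 0.17) / (1 + 3.8) ≈ 53, clamped to 33,
		// and min = 0.17 + 1 * 33/33 ≈ 1.17 GiB, which fits within avg, so
		// every layer is offloaded (opts.NumGPU = 33).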
		if layers <= 0 || min > avg {
			log.Printf("not enough vram available, falling back to CPU only")
			library = "cpu"
			opts.NumGPU = 0
			break
		}

		opts.NumGPU = int(layers)
	}

	opts.RopeFrequencyBase = 0.0
	opts.RopeFrequencyScale = 0.0
	gpuInfo := gpu.GetGPUInfo()
	return newLlmServer(gpuInfo, model, adapters, projectors, opts)
}
// Init gives any native cgo implementations an opportunity to initialize
func Init(workdir string) error {
	return nativeInit(workdir)
}
func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
	dynLibs := getDynLibs(gpuInfo)

	// Check to see if the user has requested a specific library instead of auto-detecting
	demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
	if demandLib != "" {
		libPath := availableDynLibs[demandLib]
		if libPath == "" {
			log.Printf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib)
		} else {
			log.Printf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib)
			dynLibs = []string{libPath}
		}
	}

	err2 := fmt.Errorf("unable to locate suitable llm library")
	for _, dynLib := range dynLibs {
		srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
		if err == nil {
			return srv, nil
		}
		log.Printf("Failed to load dynamic library %s  %s", dynLib, err)
		err2 = err
	}

	return nil, err2
}
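
// The sketch below (hypothetical, not part of this file) shows how a caller in
// another package might drive New and the LLM interface; the paths and option
// values are placeholders, and api.DefaultOptions is assumed from the api package.
//
//	opts := api.DefaultOptions()
//	opts.NumCtx = 2048
//	opts.NumGPU = -1 // let New estimate how many layers fit in VRAM
//
//	runner, err := llm.New("/tmp/ollama", "/path/to/model.gguf", nil, nil, opts)
//	if err != nil {
//		log.Fatal(err)
//	}
//	defer runner.Close()
//
//	// Tokenize a prompt with the loaded model's vocabulary.
//	tokens, err := runner.Encode(context.Background(), "why is the sky blue?")
//	if err != nil {
//		log.Fatal(err)
//	}
//	log.Printf("prompt is %d tokens", len(tokens))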