llm.go

package llm

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jmorganca/ollama/api"
	"github.com/pbnjay/memory"
)

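// LLM is the interface implemented by model runners: it covers prediction,
// embeddings, tokenization (Encode/Decode), runtime option updates, and
// lifecycle management (Ping/Close).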
type LLM interface {
	Predict(context.Context, []int, string, func(api.GenerateResponse)) error
	Embedding(context.Context, string) ([]float64, error)
	Encode(context.Context, string) ([]int, error)
	Decode(context.Context, []int) (string, error)
	SetOptions(api.Options)
	Close()
	Ping(context.Context) error
}

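// New opens the model file at the given path, decodes its GGML metadata,
// verifies that the host has enough memory for the detected model size, and
// returns a runner matching the model's container format. A minimal usage
// sketch (the model path is a placeholder; api.DefaultOptions is assumed to
// supply default generation options):
//
//	runner, err := New(workDir, "/path/to/model.bin", nil, api.DefaultOptions())
//	if err != nil {
//		return err
//	}
//	defer runner.Close()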
func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, err := DecodeGGML(f)
	if err != nil {
		return nil, err
	}

	switch ggml.FileType() {
	case "Q8_0":
		if ggml.Name() != "gguf" && opts.NumGPU != 0 {
			// GGML-format Q8_0 does not support the Metal API and will
			// cause the runner to segfault, so disable the GPU
			log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
			opts.NumGPU = 0
		}
	case "F32", "Q5_0", "Q5_1":
		if opts.NumGPU != 0 {
			// F32, Q5_0, Q5_1, and Q8_0 do not support the Metal API and will
			// cause the runner to segfault, so disable the GPU
			log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
			opts.NumGPU = 0
		}
	}

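	// Reject models that clearly will not fit in system memory, using coarse
	// thresholds keyed on parameter count and file type (F16 weights need
	// more memory than quantized ones).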
	totalResidentMemory := memory.TotalMemory()
	switch ggml.ModelType() {
	case "3B", "7B":
		if ggml.FileType() == "F16" && totalResidentMemory < 16*1024*1024*1024 {
			return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
		} else if totalResidentMemory < 8*1024*1024*1024 {
			return nil, fmt.Errorf("model requires at least 8GB of memory")
		}
	case "13B":
		if ggml.FileType() == "F16" && totalResidentMemory < 32*1024*1024*1024 {
			return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
		} else if totalResidentMemory < 16*1024*1024*1024 {
			return nil, fmt.Errorf("model requires at least 16GB of memory")
		}
	case "30B", "34B", "40B":
		if ggml.FileType() == "F16" && totalResidentMemory < 64*1024*1024*1024 {
			return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
		} else if totalResidentMemory < 32*1024*1024*1024 {
			return nil, fmt.Errorf("model requires at least 32GB of memory")
		}
	case "65B", "70B":
		if ggml.FileType() == "F16" && totalResidentMemory < 128*1024*1024*1024 {
			return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
		} else if totalResidentMemory < 64*1024*1024*1024 {
			return nil, fmt.Errorf("model requires at least 64GB of memory")
		}
	case "180B":
		if ggml.FileType() == "F16" && totalResidentMemory < 512*1024*1024*1024 {
			return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
		} else if totalResidentMemory < 128*1024*1024*1024 {
			return nil, fmt.Errorf("model requires at least 128GB of memory")
		}
	}

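	// Dispatch to a llama.cpp runner based on the container format reported
	// by the GGML decoder: GGUF models use the gguf runner; older GGML-family
	// formats fall back to the ggml runner.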
	switch ggml.Name() {
	case "gguf":
		opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
		return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
	case "ggml", "ggmf", "ggjt", "ggla":
		return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
	default:
		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
	}
}