llm.go

package llm

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/jmorganca/ollama/api"
	"github.com/pbnjay/memory"
)

// LLM is the interface implemented by the model runners in this package.
type LLM interface {
	Predict(context.Context, []int, string, func(api.GenerateResponse)) error
	Embedding(context.Context, string) ([]float64, error)
	Encode(context.Context, string) ([]int, error)
	Decode(context.Context, []int) (string, error)
	SetOptions(api.Options)
	Close()
	Ping(context.Context) error
}

// New loads the GGML model at the given path, checks that the host has
// enough memory for it, and returns a runner for its model family.
func New(model string, adapters []string, opts api.Options) (LLM, error) {
	if _, err := os.Stat(model); err != nil {
		return nil, err
	}

	f, err := os.Open(model)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	ggml, err := DecodeGGML(f, ModelFamilyLlama)
	if err != nil {
		return nil, err
	}

	switch ggml.FileType().String() {
	case "F32", "Q5_0", "Q5_1", "Q8_0":
		if opts.NumGPU != 0 {
			// F32, Q5_0, Q5_1, and Q8_0 do not support the Metal API and
			// will cause the runner to segmentation fault, so disable GPU.
			log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
			opts.NumGPU = 0
		}
	}

	// memory.TotalMemory reports bytes, so the thresholds below are
	// expressed as multiples of a gigabyte to match the error messages.
	const gigabyte = 1024 * 1024 * 1024
	totalResidentMemory := memory.TotalMemory()
	switch ggml.ModelType() {
	case ModelType3B, ModelType7B:
		if ggml.FileType().String() == "F16" && totalResidentMemory < 16*gigabyte {
			return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
		} else if totalResidentMemory < 8*gigabyte {
			return nil, fmt.Errorf("model requires at least 8GB of memory")
		}
	case ModelType13B:
		if ggml.FileType().String() == "F16" && totalResidentMemory < 32*gigabyte {
			return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
		} else if totalResidentMemory < 16*gigabyte {
			return nil, fmt.Errorf("model requires at least 16GB of memory")
		}
	case ModelType30B, ModelType34B:
		if ggml.FileType().String() == "F16" && totalResidentMemory < 64*gigabyte {
			return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
		} else if totalResidentMemory < 32*gigabyte {
			return nil, fmt.Errorf("model requires at least 32GB of memory")
		}
	case ModelType65B:
		if ggml.FileType().String() == "F16" && totalResidentMemory < 128*gigabyte {
			return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
		} else if totalResidentMemory < 64*gigabyte {
			return nil, fmt.Errorf("model requires at least 64GB of memory")
		}
	}

	switch ggml.ModelFamily() {
	case ModelFamilyLlama:
		return newLlama(model, adapters, ggmlRunner(), opts)
	default:
		return nil, fmt.Errorf("unknown model family: %s", ggml.ModelFamily())
	}
}
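
A minimal usage sketch (not part of llm.go): it assumes this package is imported as github.com/jmorganca/ollama/llm, that api.DefaultOptions() supplies workable defaults (it exists in this repo's api package), and that "model.bin" is a placeholder path to a local GGML model file.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
)

func main() {
	// "model.bin" is a placeholder; point it at a real GGML model file.
	runner, err := llm.New("model.bin", nil, api.DefaultOptions())
	if err != nil {
		log.Fatal(err)
	}
	defer runner.Close()

	// Predict streams generation: the callback runs once per response chunk.
	prompt := "Why is the sky blue?"
	err = runner.Predict(context.Background(), nil, prompt, func(r api.GenerateResponse) {
		fmt.Print(r.Response)
	})
	if err != nil {
		log.Fatal(err)
	}
}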