llm.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. package llm
  2. import (
  3. "fmt"
  4. "log"
  5. "os"
  6. "github.com/pbnjay/memory"
  7. "github.com/jmorganca/ollama/api"
  8. )
  9. type LLM interface {
  10. Predict([]int, string, func(api.GenerateResponse)) error
  11. Embedding(string) ([]float64, error)
  12. Encode(string) []int
  13. Decode(...int) string
  14. SetOptions(api.Options)
  15. Close()
  16. }
  17. func New(model string, opts api.Options) (LLM, error) {
  18. if _, err := os.Stat(model); err != nil {
  19. return nil, err
  20. }
  21. f, err := os.Open(model)
  22. if err != nil {
  23. return nil, err
  24. }
  25. ggml, err := DecodeGGML(f, ModelFamilyLlama)
  26. if err != nil {
  27. return nil, err
  28. }
  29. switch ggml.FileType {
  30. case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
  31. if opts.NumGPU != 0 {
  32. // Q5_0, Q5_1, and Q8_0 do not support Metal API and will
  33. // cause the runner to segmentation fault so disable GPU
  34. log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
  35. opts.NumGPU = 0
  36. }
  37. }
  38. totalResidentMemory := memory.TotalMemory()
  39. switch ggml.ModelType {
  40. case ModelType3B, ModelType7B:
  41. if totalResidentMemory < 8*1024*1024 {
  42. return nil, fmt.Errorf("model requires at least 8GB of memory")
  43. }
  44. case ModelType13B:
  45. if totalResidentMemory < 16*1024*1024 {
  46. return nil, fmt.Errorf("model requires at least 16GB of memory")
  47. }
  48. case ModelType30B:
  49. if totalResidentMemory < 32*1024*1024 {
  50. return nil, fmt.Errorf("model requires at least 32GB of memory")
  51. }
  52. case ModelType65B:
  53. if totalResidentMemory < 64*1024*1024 {
  54. return nil, fmt.Errorf("model requires at least 64GB of memory")
  55. }
  56. }
  57. switch ggml.ModelFamily {
  58. case ModelFamilyLlama:
  59. return newLlama(model, opts)
  60. default:
  61. return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
  62. }
  63. }