payload.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. package llm
  2. import (
  3. "compress/gzip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "runtime"
  12. "strings"
  13. "golang.org/x/exp/slices"
  14. "golang.org/x/sync/errgroup"
  15. "github.com/ollama/ollama/gpu"
  16. )
  17. var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
  18. func Init() error {
  19. payloadsDir, err := gpu.PayloadsDir()
  20. if err != nil {
  21. return err
  22. }
  23. slog.Info("extracting embedded files", "dir", payloadsDir)
  24. binGlob := "build/*/*/*/bin/*"
  25. // extract server libraries
  26. err = extractFiles(payloadsDir, binGlob)
  27. if err != nil {
  28. return fmt.Errorf("extract binaries: %v", err)
  29. }
  30. var variants []string
  31. for v := range availableServers() {
  32. variants = append(variants, v)
  33. }
  34. slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
  35. slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
  36. return nil
  37. }
  38. // binary names may contain an optional variant separated by '_'
  39. // For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
  40. // Any library without a variant is the lowest common denominator
  41. func availableServers() map[string]string {
  42. payloadsDir, err := gpu.PayloadsDir()
  43. if err != nil {
  44. slog.Error("payload lookup error", "error", err)
  45. return nil
  46. }
  47. // glob payloadsDir for files that start with ollama_
  48. pattern := filepath.Join(payloadsDir, "*")
  49. files, err := filepath.Glob(pattern)
  50. if err != nil {
  51. slog.Debug("could not glob", "pattern", pattern, "error", err)
  52. return nil
  53. }
  54. servers := make(map[string]string)
  55. for _, file := range files {
  56. slog.Debug("availableServers : found", "file", file)
  57. servers[filepath.Base(file)] = file
  58. }
  59. return servers
  60. }
  61. // serversForGpu returns a list of compatible servers give the provided GPU
  62. // info, ordered by performance. assumes Init() has been called
  63. // TODO - switch to metadata based mapping
  64. func serversForGpu(info gpu.GpuInfo) []string {
  65. // glob workDir for files that start with ollama_
  66. availableServers := availableServers()
  67. requested := info.Library
  68. if info.Variant != "" {
  69. requested += "_" + info.Variant
  70. }
  71. servers := []string{}
  72. // exact match first
  73. for a := range availableServers {
  74. if a == requested {
  75. servers = []string{a}
  76. if a == "metal" {
  77. return servers
  78. }
  79. break
  80. }
  81. }
  82. alt := []string{}
  83. // Then for GPUs load alternates and sort the list for consistent load ordering
  84. if info.Library != "cpu" {
  85. for a := range availableServers {
  86. if info.Library == strings.Split(a, "_")[0] && a != requested {
  87. alt = append(alt, a)
  88. }
  89. }
  90. slices.Sort(alt)
  91. servers = append(servers, alt...)
  92. }
  93. // Load up the best CPU variant if not primary requested
  94. if info.Library != "cpu" {
  95. variant := gpu.GetCPUVariant()
  96. // If no variant, then we fall back to default
  97. // If we have a variant, try that if we find an exact match
  98. // Attempting to run the wrong CPU instructions will panic the
  99. // process
  100. if variant != "" {
  101. for cmp := range availableServers {
  102. if cmp == "cpu_"+variant {
  103. servers = append(servers, cmp)
  104. break
  105. }
  106. }
  107. } else {
  108. servers = append(servers, "cpu")
  109. }
  110. }
  111. if len(servers) == 0 {
  112. servers = []string{"cpu"}
  113. }
  114. return servers
  115. }
  116. // Return the optimal server for this CPU architecture
  117. func serverForCpu() string {
  118. if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
  119. return "metal"
  120. }
  121. variant := gpu.GetCPUVariant()
  122. availableServers := availableServers()
  123. if variant != "" {
  124. for cmp := range availableServers {
  125. if cmp == "cpu_"+variant {
  126. return cmp
  127. }
  128. }
  129. }
  130. return "cpu"
  131. }
  132. // extract extracts the embedded files to the target directory
  133. func extractFiles(targetDir string, glob string) error {
  134. files, err := fs.Glob(libEmbed, glob)
  135. if err != nil || len(files) == 0 {
  136. return errPayloadMissing
  137. }
  138. if err := os.MkdirAll(targetDir, 0o755); err != nil {
  139. return fmt.Errorf("extractFiles could not mkdir %s: %v", targetDir, err)
  140. }
  141. g := new(errgroup.Group)
  142. // build/$OS/$GOARCH/$VARIANT/{bin,lib}/$FILE
  143. for _, file := range files {
  144. filename := file
  145. variant := filepath.Base(filepath.Dir(filepath.Dir(filename)))
  146. slog.Debug("extracting", "variant", variant, "file", filename)
  147. g.Go(func() error {
  148. srcf, err := libEmbed.Open(filename)
  149. if err != nil {
  150. return err
  151. }
  152. defer srcf.Close()
  153. src := io.Reader(srcf)
  154. if strings.HasSuffix(filename, ".gz") {
  155. src, err = gzip.NewReader(src)
  156. if err != nil {
  157. return fmt.Errorf("decompress payload %s: %v", filename, err)
  158. }
  159. filename = strings.TrimSuffix(filename, ".gz")
  160. }
  161. variantDir := filepath.Join(targetDir, variant)
  162. if err := os.MkdirAll(variantDir, 0o755); err != nil {
  163. return fmt.Errorf("extractFiles could not mkdir %s: %v", variantDir, err)
  164. }
  165. base := filepath.Base(filename)
  166. destFilename := filepath.Join(variantDir, base)
  167. _, err = os.Stat(destFilename)
  168. switch {
  169. case errors.Is(err, os.ErrNotExist):
  170. destFile, err := os.OpenFile(destFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
  171. if err != nil {
  172. return fmt.Errorf("write payload %s: %v", filename, err)
  173. }
  174. defer destFile.Close()
  175. if _, err := io.Copy(destFile, src); err != nil {
  176. return fmt.Errorf("copy payload %s: %v", filename, err)
  177. }
  178. case err != nil:
  179. return fmt.Errorf("stat payload %s: %v", filename, err)
  180. }
  181. return nil
  182. })
  183. }
  184. err = g.Wait()
  185. if err != nil {
  186. // If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
  187. gpu.Cleanup()
  188. return err
  189. }
  190. return nil
  191. }