payload.go 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. package llm
  2. import (
  3. "compress/gzip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "runtime"
  12. "strings"
  13. "golang.org/x/exp/slices"
  14. "golang.org/x/sync/errgroup"
  15. "github.com/ollama/ollama/gpu"
  16. )
  // errPayloadMissing is the sentinel error returned by extractFiles when
  // globbing the embedded file system fails or matches no files.
  17. var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
  18. func Init() error {
  19. payloadsDir, err := gpu.PayloadsDir()
  20. if err != nil {
  21. return err
  22. }
  23. if runtime.GOOS != "windows" {
  24. slog.Info("extracting embedded files", "dir", payloadsDir)
  25. binGlob := "build/*/*/*/bin/*"
  26. // extract server libraries
  27. err = extractFiles(payloadsDir, binGlob)
  28. if err != nil {
  29. return fmt.Errorf("extract binaries: %v", err)
  30. }
  31. }
  32. var variants []string
  33. for v := range availableServers() {
  34. variants = append(variants, v)
  35. }
  36. slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
  37. slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
  38. return nil
  39. }
  40. // binary names may contain an optional variant separated by '_'
  41. // For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
  42. // Any library without a variant is the lowest common denominator
  43. func availableServers() map[string]string {
  44. payloadsDir, err := gpu.PayloadsDir()
  45. if err != nil {
  46. slog.Error("payload lookup error", "error", err)
  47. return nil
  48. }
  49. // glob payloadsDir for files that start with ollama_
  50. pattern := filepath.Join(payloadsDir, "*")
  51. files, err := filepath.Glob(pattern)
  52. if err != nil {
  53. slog.Debug("could not glob", "pattern", pattern, "error", err)
  54. return nil
  55. }
  56. servers := make(map[string]string)
  57. for _, file := range files {
  58. slog.Debug("availableServers : found", "file", file)
  59. servers[filepath.Base(file)] = file
  60. }
  61. return servers
  62. }
  63. // serversForGpu returns a list of compatible servers give the provided GPU
  64. // info, ordered by performance. assumes Init() has been called
  65. // TODO - switch to metadata based mapping
  66. func serversForGpu(info gpu.GpuInfo) []string {
  67. // glob workDir for files that start with ollama_
  68. availableServers := availableServers()
  69. requested := info.Library
  70. if info.Variant != "" {
  71. requested += "_" + info.Variant
  72. }
  73. servers := []string{}
  74. // exact match first
  75. for a := range availableServers {
  76. if a == requested {
  77. servers = []string{a}
  78. if a == "metal" {
  79. return servers
  80. }
  81. break
  82. }
  83. }
  84. alt := []string{}
  85. // Then for GPUs load alternates and sort the list for consistent load ordering
  86. if info.Library != "cpu" {
  87. for a := range availableServers {
  88. if info.Library == strings.Split(a, "_")[0] && a != requested {
  89. alt = append(alt, a)
  90. }
  91. }
  92. slices.Sort(alt)
  93. servers = append(servers, alt...)
  94. }
  95. // Load up the best CPU variant if not primary requested
  96. if info.Library != "cpu" {
  97. variant := gpu.GetCPUVariant()
  98. // If no variant, then we fall back to default
  99. // If we have a variant, try that if we find an exact match
  100. // Attempting to run the wrong CPU instructions will panic the
  101. // process
  102. if variant != "" {
  103. for cmp := range availableServers {
  104. if cmp == "cpu_"+variant {
  105. servers = append(servers, cmp)
  106. break
  107. }
  108. }
  109. } else {
  110. servers = append(servers, "cpu")
  111. }
  112. }
  113. if len(servers) == 0 {
  114. servers = []string{"cpu"}
  115. }
  116. return servers
  117. }
  118. // Return the optimal server for this CPU architecture
  119. func serverForCpu() string {
  120. if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
  121. return "metal"
  122. }
  123. variant := gpu.GetCPUVariant()
  124. availableServers := availableServers()
  125. if variant != "" {
  126. for cmp := range availableServers {
  127. if cmp == "cpu_"+variant {
  128. return cmp
  129. }
  130. }
  131. }
  132. return "cpu"
  133. }
  134. // extract extracts the embedded files to the target directory
  135. func extractFiles(targetDir string, glob string) error {
  136. files, err := fs.Glob(libEmbed, glob)
  137. if err != nil || len(files) == 0 {
  138. return errPayloadMissing
  139. }
  140. if err := os.MkdirAll(targetDir, 0o755); err != nil {
  141. return fmt.Errorf("extractFiles could not mkdir %s: %v", targetDir, err)
  142. }
  143. g := new(errgroup.Group)
  144. // build/$OS/$GOARCH/$VARIANT/{bin,lib}/$FILE
  145. for _, file := range files {
  146. filename := file
  147. variant := filepath.Base(filepath.Dir(filepath.Dir(filename)))
  148. slog.Debug("extracting", "variant", variant, "file", filename)
  149. g.Go(func() error {
  150. srcf, err := libEmbed.Open(filename)
  151. if err != nil {
  152. return err
  153. }
  154. defer srcf.Close()
  155. src := io.Reader(srcf)
  156. if strings.HasSuffix(filename, ".gz") {
  157. src, err = gzip.NewReader(src)
  158. if err != nil {
  159. return fmt.Errorf("decompress payload %s: %v", filename, err)
  160. }
  161. filename = strings.TrimSuffix(filename, ".gz")
  162. }
  163. variantDir := filepath.Join(targetDir, variant)
  164. if err := os.MkdirAll(variantDir, 0o755); err != nil {
  165. return fmt.Errorf("extractFiles could not mkdir %s: %v", variantDir, err)
  166. }
  167. base := filepath.Base(filename)
  168. destFilename := filepath.Join(variantDir, base)
  169. _, err = os.Stat(destFilename)
  170. switch {
  171. case errors.Is(err, os.ErrNotExist):
  172. destFile, err := os.OpenFile(destFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
  173. if err != nil {
  174. return fmt.Errorf("write payload %s: %v", filename, err)
  175. }
  176. defer destFile.Close()
  177. if _, err := io.Copy(destFile, src); err != nil {
  178. return fmt.Errorf("copy payload %s: %v", filename, err)
  179. }
  180. case err != nil:
  181. return fmt.Errorf("stat payload %s: %v", filename, err)
  182. }
  183. return nil
  184. })
  185. }
  186. err = g.Wait()
  187. if err != nil {
  188. // If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
  189. gpu.Cleanup()
  190. return err
  191. }
  192. return nil
  193. }