payload.go 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211
  1. package llm
  2. import (
  3. "compress/gzip"
  4. "errors"
  5. "fmt"
  6. "io"
  7. "io/fs"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "golang.org/x/exp/slices"
  13. "golang.org/x/sync/errgroup"
  14. "ollama.com/gpu"
  15. )
  // errPayloadMissing is returned by extractFiles when the embedded file system
  // contains no files matching the expected payload layout.
  var errPayloadMissing = errors.New("expected payloads not included in this build of ollama")
  17. func Init() error {
  18. payloadsDir, err := gpu.PayloadsDir()
  19. if err != nil {
  20. return err
  21. }
  22. slog.Info("extracting embedded files", "dir", payloadsDir)
  23. binGlob := "build/*/*/*/bin/*"
  24. // extract server libraries
  25. err = extractFiles(payloadsDir, binGlob)
  26. if err != nil {
  27. return fmt.Errorf("extract binaries: %v", err)
  28. }
  29. var variants []string
  30. for v := range availableServers() {
  31. variants = append(variants, v)
  32. }
  33. slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
  34. slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
  35. return nil
  36. }
  37. // binary names may contain an optional variant separated by '_'
  38. // For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
  39. // Any library without a variant is the lowest common denominator
  40. func availableServers() map[string]string {
  41. payloadsDir, err := gpu.PayloadsDir()
  42. if err != nil {
  43. slog.Error("payload lookup error", "error", err)
  44. return nil
  45. }
  46. // glob payloadsDir for files that start with ollama_
  47. pattern := filepath.Join(payloadsDir, "*")
  48. files, err := filepath.Glob(pattern)
  49. if err != nil {
  50. slog.Debug("could not glob", "pattern", pattern, "error", err)
  51. return nil
  52. }
  53. servers := make(map[string]string)
  54. for _, file := range files {
  55. slog.Debug("availableServers : found", "file", file)
  56. servers[filepath.Base(file)] = file
  57. }
  58. return servers
  59. }
  60. // serversForGpu returns a list of compatible servers give the provided GPU
  61. // info, ordered by performance. assumes Init() has been called
  62. // TODO - switch to metadata based mapping
  63. func serversForGpu(info gpu.GpuInfo) []string {
  64. // glob workDir for files that start with ollama_
  65. availableServers := availableServers()
  66. requested := info.Library
  67. if info.Variant != "" {
  68. requested += "_" + info.Variant
  69. }
  70. servers := []string{}
  71. // exact match first
  72. for a := range availableServers {
  73. if a == requested {
  74. servers = []string{a}
  75. if a == "metal" {
  76. return servers
  77. }
  78. break
  79. }
  80. }
  81. alt := []string{}
  82. // Then for GPUs load alternates and sort the list for consistent load ordering
  83. if info.Library != "cpu" {
  84. for a := range availableServers {
  85. if info.Library == strings.Split(a, "_")[0] && a != requested {
  86. alt = append(alt, a)
  87. }
  88. }
  89. slices.Sort(alt)
  90. servers = append(servers, alt...)
  91. }
  92. // Load up the best CPU variant if not primary requested
  93. if info.Library != "cpu" {
  94. variant := gpu.GetCPUVariant()
  95. // If no variant, then we fall back to default
  96. // If we have a variant, try that if we find an exact match
  97. // Attempting to run the wrong CPU instructions will panic the
  98. // process
  99. if variant != "" {
  100. for cmp := range availableServers {
  101. if cmp == "cpu_"+variant {
  102. servers = append(servers, cmp)
  103. break
  104. }
  105. }
  106. } else {
  107. servers = append(servers, "cpu")
  108. }
  109. }
  110. if len(servers) == 0 {
  111. servers = []string{"cpu"}
  112. }
  113. return servers
  114. }
  115. // extract extracts the embedded files to the target directory
  116. func extractFiles(targetDir string, glob string) error {
  117. files, err := fs.Glob(libEmbed, glob)
  118. if err != nil || len(files) == 0 {
  119. return errPayloadMissing
  120. }
  121. if err := os.MkdirAll(targetDir, 0o755); err != nil {
  122. return fmt.Errorf("extractFiles could not mkdir %s: %v", targetDir, err)
  123. }
  124. g := new(errgroup.Group)
  125. // build/$OS/$GOARCH/$VARIANT/{bin,lib}/$FILE
  126. for _, file := range files {
  127. filename := file
  128. variant := filepath.Base(filepath.Dir(filepath.Dir(filename)))
  129. slog.Debug("extracting", "variant", variant, "file", filename)
  130. g.Go(func() error {
  131. srcf, err := libEmbed.Open(filename)
  132. if err != nil {
  133. return err
  134. }
  135. defer srcf.Close()
  136. src := io.Reader(srcf)
  137. if strings.HasSuffix(filename, ".gz") {
  138. src, err = gzip.NewReader(src)
  139. if err != nil {
  140. return fmt.Errorf("decompress payload %s: %v", filename, err)
  141. }
  142. filename = strings.TrimSuffix(filename, ".gz")
  143. }
  144. variantDir := filepath.Join(targetDir, variant)
  145. if err := os.MkdirAll(variantDir, 0o755); err != nil {
  146. return fmt.Errorf("extractFiles could not mkdir %s: %v", variantDir, err)
  147. }
  148. base := filepath.Base(filename)
  149. destFilename := filepath.Join(variantDir, base)
  150. _, err = os.Stat(destFilename)
  151. switch {
  152. case errors.Is(err, os.ErrNotExist):
  153. destFile, err := os.OpenFile(destFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
  154. if err != nil {
  155. return fmt.Errorf("write payload %s: %v", filename, err)
  156. }
  157. defer destFile.Close()
  158. if _, err := io.Copy(destFile, src); err != nil {
  159. return fmt.Errorf("copy payload %s: %v", filename, err)
  160. }
  161. case err != nil:
  162. return fmt.Errorf("stat payload %s: %v", filename, err)
  163. }
  164. return nil
  165. })
  166. }
  167. err = g.Wait()
  168. if err != nil {
  169. // If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
  170. gpu.Cleanup()
  171. return err
  172. }
  173. return nil
  174. }