modelpath.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185
  1. package server
  2. import (
  3. "errors"
  4. "fmt"
  5. "io/fs"
  6. "log/slog"
  7. "net/url"
  8. "os"
  9. "path/filepath"
  10. "regexp"
  11. "strings"
  12. "github.com/ollama/ollama/envconfig"
  13. )
  14. type ModelPath struct {
  15. ProtocolScheme string
  16. Registry string
  17. Namespace string
  18. Repository string
  19. Tag string
  20. }
  21. const (
  22. DefaultRegistry = "ollama.com"
  23. DefaultNamespace = "library"
  24. DefaultTag = "latest"
  25. DefaultProtocolScheme = "https"
  26. )
  27. var (
  28. ErrInvalidImageFormat = errors.New("invalid image format")
  29. ErrInvalidProtocol = errors.New("invalid protocol scheme")
  30. ErrInsecureProtocol = errors.New("insecure protocol http")
  31. ErrInvalidDigestFormat = errors.New("invalid digest format")
  32. )
  33. func ParseModelPath(name string) ModelPath {
  34. mp := ModelPath{
  35. ProtocolScheme: DefaultProtocolScheme,
  36. Registry: DefaultRegistry,
  37. Namespace: DefaultNamespace,
  38. Repository: "",
  39. Tag: DefaultTag,
  40. }
  41. before, after, found := strings.Cut(name, "://")
  42. if found {
  43. mp.ProtocolScheme = before
  44. name = after
  45. }
  46. parts := strings.Split(filepath.ToSlash(name), "/")
  47. switch len(parts) {
  48. case 3:
  49. mp.Registry = parts[0]
  50. mp.Namespace = parts[1]
  51. mp.Repository = parts[2]
  52. case 2:
  53. mp.Namespace = parts[0]
  54. mp.Repository = parts[1]
  55. case 1:
  56. mp.Repository = parts[0]
  57. }
  58. if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
  59. mp.Repository = repo
  60. mp.Tag = tag
  61. }
  62. return mp
  63. }
  64. var errModelPathInvalid = errors.New("invalid model path")
  65. func (mp ModelPath) Validate() error {
  66. if mp.Repository == "" {
  67. return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
  68. }
  69. if strings.Contains(mp.Tag, ":") {
  70. return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
  71. }
  72. return nil
  73. }
  74. func (mp ModelPath) GetNamespaceRepository() string {
  75. return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
  76. }
  77. func (mp ModelPath) GetFullTagname() string {
  78. return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
  79. }
  80. func (mp ModelPath) GetShortTagname() string {
  81. if mp.Registry == DefaultRegistry {
  82. if mp.Namespace == DefaultNamespace {
  83. return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
  84. }
  85. return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
  86. }
  87. return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
  88. }
  89. // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
  90. func (mp ModelPath) GetManifestPath() (string, error) {
  91. dir := envconfig.ModelsDir
  92. return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
  93. }
  94. func (mp ModelPath) BaseURL() *url.URL {
  95. return &url.URL{
  96. Scheme: mp.ProtocolScheme,
  97. Host: mp.Registry,
  98. }
  99. }
  100. func GetManifestPath() (string, error) {
  101. dir := envconfig.ModelsDir
  102. path := filepath.Join(dir, "manifests")
  103. if err := os.MkdirAll(path, 0o755); err != nil {
  104. return "", err
  105. }
  106. return path, nil
  107. }
  108. func GetBlobsPath(digest string) (string, error) {
  109. dir := envconfig.ModelsDir
  110. // only accept actual sha256 digests
  111. pattern := "^sha256[:-][0-9a-fA-F]{64}$"
  112. re := regexp.MustCompile(pattern)
  113. if digest != "" && !re.MatchString(digest) {
  114. return "", ErrInvalidDigestFormat
  115. }
  116. digest = strings.ReplaceAll(digest, ":", "-")
  117. path := filepath.Join(dir, "blobs", digest)
  118. dirPath := filepath.Dir(path)
  119. if digest == "" {
  120. dirPath = path
  121. }
  122. if err := os.MkdirAll(dirPath, 0o755); err != nil {
  123. return "", err
  124. }
  125. return path, nil
  126. }
  127. func migrateRegistryDomain() error {
  128. manifests, err := GetManifestPath()
  129. if err != nil {
  130. return err
  131. }
  132. olddomainpath := filepath.Join(manifests, "registry.ollama.ai")
  133. newdomainpath := filepath.Join(manifests, DefaultRegistry)
  134. return filepath.Walk(olddomainpath, func(path string, info fs.FileInfo, err error) error {
  135. if err != nil {
  136. return err
  137. }
  138. if !info.IsDir() {
  139. slog.Info("migrating registry domain", "path", path)
  140. newpath := filepath.Join(newdomainpath, strings.TrimPrefix(path, olddomainpath))
  141. if err := os.MkdirAll(filepath.Dir(newpath), 0o755); err != nil {
  142. return err
  143. }
  144. if err := os.Rename(path, newpath); err != nil {
  145. return err
  146. }
  147. }
  148. return nil
  149. })
  150. }