modelpath.go 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. package server
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/url"
  6. "os"
  7. "path/filepath"
  8. "regexp"
  9. "strings"
  10. "github.com/ollama/ollama/envconfig"
  11. )
  // ModelPath holds the individual components of a model reference: protocol
  // scheme, registry host, namespace, repository and tag.
  type ModelPath struct {
  	// ProtocolScheme is the URL scheme; BaseURL uses it as the URL's Scheme.
  	ProtocolScheme string

  	// Registry is the registry host; BaseURL uses it as the URL's Host.
  	Registry string

  	// Namespace is the path segment that groups repositories in a registry.
  	Namespace string

  	// Repository is the model name. Validate rejects an empty value.
  	Repository string

  	// Tag identifies a specific variant of the repository. Validate rejects
  	// tags that contain a colon.
  	Tag string
  }
  // Defaults that ParseModelPath applies to any component missing from the
  // name it is given.
  const (
  	// DefaultRegistry is the registry host used when the name has none.
  	DefaultRegistry = "registry.ollama.ai"

  	// DefaultNamespace is the namespace used when the name has none.
  	DefaultNamespace = "library"

  	// DefaultTag is the tag used when the name carries no ":tag" suffix.
  	DefaultTag = "latest"

  	// DefaultProtocolScheme is the scheme used when the name has no
  	// "scheme://" prefix.
  	DefaultProtocolScheme = "https"
  )
  // Sentinel errors for model path handling. Compare against them with
  // errors.Is.
  var (
  	// ErrInvalidImageFormat reports a malformed image reference.
  	ErrInvalidImageFormat = errors.New("invalid image format")

  	// ErrInvalidProtocol reports an unsupported protocol scheme.
  	ErrInvalidProtocol = errors.New("invalid protocol scheme")

  	// ErrInsecureProtocol reports use of plain http.
  	ErrInsecureProtocol = errors.New("insecure protocol http")

  	// ErrInvalidDigestFormat is returned by GetBlobsPath for a digest that
  	// is not a sha256 digest in the accepted form.
  	ErrInvalidDigestFormat = errors.New("invalid digest format")
  )
  31. func ParseModelPath(name string) ModelPath {
  32. mp := ModelPath{
  33. ProtocolScheme: DefaultProtocolScheme,
  34. Registry: DefaultRegistry,
  35. Namespace: DefaultNamespace,
  36. Repository: "",
  37. Tag: DefaultTag,
  38. }
  39. before, after, found := strings.Cut(name, "://")
  40. if found {
  41. mp.ProtocolScheme = before
  42. name = after
  43. }
  44. name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
  45. parts := strings.Split(name, "/")
  46. switch len(parts) {
  47. case 3:
  48. mp.Registry = parts[0]
  49. mp.Namespace = parts[1]
  50. mp.Repository = parts[2]
  51. case 2:
  52. mp.Namespace = parts[0]
  53. mp.Repository = parts[1]
  54. case 1:
  55. mp.Repository = parts[0]
  56. }
  57. if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
  58. mp.Repository = repo
  59. mp.Tag = tag
  60. }
  61. return mp
  62. }
  63. var errModelPathInvalid = errors.New("invalid model path")
  64. func (mp ModelPath) Validate() error {
  65. if mp.Repository == "" {
  66. return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
  67. }
  68. if strings.Contains(mp.Tag, ":") {
  69. return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
  70. }
  71. return nil
  72. }
  73. func (mp ModelPath) GetNamespaceRepository() string {
  74. return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
  75. }
  76. func (mp ModelPath) GetFullTagname() string {
  77. return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
  78. }
  79. func (mp ModelPath) GetShortTagname() string {
  80. if mp.Registry == DefaultRegistry {
  81. if mp.Namespace == DefaultNamespace {
  82. return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
  83. }
  84. return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
  85. }
  86. return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
  87. }
  88. // modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
  89. // The models directory is where Ollama stores its model files and manifests.
  90. func modelsDir() (string, error) {
  91. return envconfig.ModelsDir, nil
  92. }
  93. // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
  94. func (mp ModelPath) GetManifestPath() (string, error) {
  95. dir, err := modelsDir()
  96. if err != nil {
  97. return "", err
  98. }
  99. return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
  100. }
  101. func (mp ModelPath) BaseURL() *url.URL {
  102. return &url.URL{
  103. Scheme: mp.ProtocolScheme,
  104. Host: mp.Registry,
  105. }
  106. }
  107. func GetManifestPath() (string, error) {
  108. dir, err := modelsDir()
  109. if err != nil {
  110. return "", err
  111. }
  112. path := filepath.Join(dir, "manifests")
  113. if err := os.MkdirAll(path, 0o755); err != nil {
  114. return "", err
  115. }
  116. return path, nil
  117. }
  118. func GetBlobsPath(digest string) (string, error) {
  119. dir, err := modelsDir()
  120. if err != nil {
  121. return "", err
  122. }
  123. // only accept actual sha256 digests
  124. pattern := "^sha256[:-][0-9a-fA-F]{64}$"
  125. re := regexp.MustCompile(pattern)
  126. if digest != "" && !re.MatchString(digest) {
  127. return "", ErrInvalidDigestFormat
  128. }
  129. digest = strings.ReplaceAll(digest, ":", "-")
  130. path := filepath.Join(dir, "blobs", digest)
  131. dirPath := filepath.Dir(path)
  132. if digest == "" {
  133. dirPath = path
  134. }
  135. if err := os.MkdirAll(dirPath, 0o755); err != nil {
  136. return "", err
  137. }
  138. return path, nil
  139. }