modelpath.go 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. package server
  2. import (
  3. "errors"
  4. "fmt"
  5. "net/url"
  6. "os"
  7. "path/filepath"
  8. "regexp"
  9. "strings"
  10. )
  // ModelPath is a parsed model reference, split into the pieces that
  // ParseModelPath extracts from a name such as
  // "https://registry.ollama.ai/library/llama:latest".
  type ModelPath struct {
  	// ProtocolScheme is the text before "://" in the name, if present.
  	ProtocolScheme string
  	// Registry is the first of three slash-separated segments.
  	Registry string
  	// Namespace is the segment immediately before the repository.
  	Namespace string
  	// Repository is the last segment with any ":tag" suffix removed.
  	Repository string
  	// Tag is the text after the first ':' in the last segment.
  	Tag string
  }
  // Defaults applied by ParseModelPath to any component the input omits.
  const (
  	// DefaultProtocolScheme is used when the name has no "scheme://" prefix.
  	DefaultProtocolScheme = "https"
  	// DefaultRegistry is used unless the name has three path segments.
  	DefaultRegistry = "registry.ollama.ai"
  	// DefaultNamespace is used when the name has a single path segment.
  	DefaultNamespace = "library"
  	// DefaultTag is used when the repository segment has no ":tag" suffix.
  	DefaultTag = "latest"
  )
  // Exported sentinel errors; compare against them with errors.Is.
  var (
  	// ErrInvalidImageFormat signals a malformed image reference.
  	ErrInvalidImageFormat = errors.New("invalid image format")
  	// ErrInvalidProtocol signals an unsupported protocol scheme.
  	ErrInvalidProtocol = errors.New("invalid protocol scheme")
  	// ErrInsecureProtocol signals that plain http was used.
  	ErrInsecureProtocol = errors.New("insecure protocol http")
  	// ErrInvalidDigestFormat is returned by GetBlobsPath for a digest that is
  	// not "sha256:" or "sha256-" followed by 64 hex digits.
  	ErrInvalidDigestFormat = errors.New("invalid digest format")
  )
  30. func ParseModelPath(name string) ModelPath {
  31. mp := ModelPath{
  32. ProtocolScheme: DefaultProtocolScheme,
  33. Registry: DefaultRegistry,
  34. Namespace: DefaultNamespace,
  35. Repository: "",
  36. Tag: DefaultTag,
  37. }
  38. before, after, found := strings.Cut(name, "://")
  39. if found {
  40. mp.ProtocolScheme = before
  41. name = after
  42. }
  43. name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
  44. parts := strings.Split(name, "/")
  45. switch len(parts) {
  46. case 3:
  47. mp.Registry = parts[0]
  48. mp.Namespace = parts[1]
  49. mp.Repository = parts[2]
  50. case 2:
  51. mp.Namespace = parts[0]
  52. mp.Repository = parts[1]
  53. case 1:
  54. mp.Repository = parts[0]
  55. }
  56. if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
  57. mp.Repository = repo
  58. mp.Tag = tag
  59. }
  60. return mp
  61. }
  62. var errModelPathInvalid = errors.New("invalid model path")
  63. func (mp ModelPath) Validate() error {
  64. if mp.Repository == "" {
  65. return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
  66. }
  67. if strings.Contains(mp.Tag, ":") {
  68. return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
  69. }
  70. return nil
  71. }
  72. func (mp ModelPath) GetNamespaceRepository() string {
  73. return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
  74. }
  75. func (mp ModelPath) GetFullTagname() string {
  76. return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
  77. }
  78. func (mp ModelPath) GetShortTagname() string {
  79. if mp.Registry == DefaultRegistry {
  80. if mp.Namespace == DefaultNamespace {
  81. return fmt.Sprintf("%s:%s", mp.Repository, mp.Tag)
  82. }
  83. return fmt.Sprintf("%s/%s:%s", mp.Namespace, mp.Repository, mp.Tag)
  84. }
  85. return fmt.Sprintf("%s/%s/%s:%s", mp.Registry, mp.Namespace, mp.Repository, mp.Tag)
  86. }
  // modelsDir returns the directory where Ollama stores model files and
  // manifests.
  //
  // It returns the OLLAMA_MODELS environment variable when that is set to a
  // non-empty value. Otherwise it returns "<home>/.ollama/models", where
  // <home> comes from os.UserHomeDir; that lookup's error is returned as-is.
  func modelsDir() (string, error) {
  	// An empty OLLAMA_MODELS is treated as unset. Returning "" would make
  	// every derived manifest/blob path relative to the working directory.
  	if models := os.Getenv("OLLAMA_MODELS"); models != "" {
  		return models, nil
  	}
  	home, err := os.UserHomeDir()
  	if err != nil {
  		return "", err
  	}
  	return filepath.Join(home, ".ollama", "models"), nil
  }
  99. // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
  100. func (mp ModelPath) GetManifestPath() (string, error) {
  101. dir, err := modelsDir()
  102. if err != nil {
  103. return "", err
  104. }
  105. return filepath.Join(dir, "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
  106. }
  107. func (mp ModelPath) BaseURL() *url.URL {
  108. return &url.URL{
  109. Scheme: mp.ProtocolScheme,
  110. Host: mp.Registry,
  111. }
  112. }
  113. func GetManifestPath() (string, error) {
  114. dir, err := modelsDir()
  115. if err != nil {
  116. return "", err
  117. }
  118. path := filepath.Join(dir, "manifests")
  119. if err := os.MkdirAll(path, 0o755); err != nil {
  120. return "", err
  121. }
  122. return path, nil
  123. }
  124. func GetBlobsPath(digest string) (string, error) {
  125. dir, err := modelsDir()
  126. if err != nil {
  127. return "", err
  128. }
  129. // only accept actual sha256 digests
  130. pattern := "^sha256[:-][0-9a-fA-F]{64}$"
  131. re := regexp.MustCompile(pattern)
  132. if digest != "" && !re.MatchString(digest) {
  133. return "", ErrInvalidDigestFormat
  134. }
  135. digest = strings.ReplaceAll(digest, ":", "-")
  136. path := filepath.Join(dir, "blobs", digest)
  137. dirPath := filepath.Dir(path)
  138. if digest == "" {
  139. dirPath = path
  140. }
  141. if err := os.MkdirAll(dirPath, 0o755); err != nil {
  142. return "", err
  143. }
  144. return path, nil
  145. }