registry.go 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401
  1. package registry
  2. import (
  3. "cmp"
  4. "context"
  5. "encoding/json"
  6. "encoding/xml"
  7. "errors"
  8. "fmt"
  9. "io"
  10. "iter"
  11. "log/slog"
  12. "net/http"
  13. "net/url"
  14. "os"
  15. "sync"
  16. "github.com/ollama/ollama/client/ollama"
  17. "github.com/ollama/ollama/client/registry/apitype"
  18. "github.com/ollama/ollama/types/model"
  19. "golang.org/x/exp/constraints"
  20. "golang.org/x/sync/errgroup"
  21. )
  // ErrLayerNotFound is the sentinel error for a layer that could not be
  // found. Callers should test for it with errors.Is.
  var ErrLayerNotFound = errors.New("layer not found")
  26. type Client struct {
  27. BaseURL string
  28. Logger *slog.Logger
  29. }
  30. func (c *Client) logger() *slog.Logger {
  31. return cmp.Or(c.Logger, slog.Default())
  32. }
  33. func (c *Client) oclient() *ollama.Client {
  34. return &ollama.Client{
  35. BaseURL: c.BaseURL,
  36. }
  37. }
  // ReadAtSeekCloser is the handle type returned by Cache.OpenLayer: it
  // supports random-access reads, seeking, and closing.
  type ReadAtSeekCloser interface {
  	io.Closer
  	io.ReaderAt
  	io.Seeker
  }
  43. type Cache interface {
  44. // LayerFile returns the absolute file path to the layer file for
  45. // the given model digest.
  46. //
  47. // If the digest is invalid, or the layer does not exist, the empty
  48. // string is returned.
  49. LayerFile(model.Digest) string
  50. // OpenLayer opens the layer file for the given model digest and
  51. // returns it, or an if any. The caller is responsible for closing
  52. // the returned file.
  53. OpenLayer(model.Digest) (ReadAtSeekCloser, error)
  54. // PutLayerFile moves the layer file at fromPath to the cache for
  55. // the given model digest. It is a hack intended to short circuit a
  56. // file copy operation.
  57. //
  58. // The file returned is expected to exist for the lifetime of the
  59. // cache.
  60. //
  61. // TODO(bmizerany): remove this; find a better way. Once we move
  62. // this into a build package, we should be able to get rid of this.
  63. PutLayerFile(_ model.Digest, fromPath string) error
  64. // SetManifestData sets the provided manifest data for the given
  65. // model name. If the manifest data is empty, the manifest is
  66. // removed. If the manifeest exists, it is overwritten.
  67. //
  68. // It is an error to call SetManifestData with a name that is not
  69. // complete.
  70. SetManifestData(model.Name, []byte) error
  71. // ManifestData returns the manifest data for the given model name.
  72. //
  73. // If the name incomplete, or the manifest does not exist, the empty
  74. // string is returned.
  75. ManifestData(name model.Name) []byte
  76. }
  77. // Pull pulls the manifest for name, and downloads any of its required
  78. // layers that are not already in the cache. It returns an error if any part
  79. // of the process fails, specifically:
  80. func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
  81. mn := model.ParseName(name)
  82. if !mn.IsFullyQualified() {
  83. return fmt.Errorf("ollama: pull: invalid name: %s", name)
  84. }
  85. log := c.logger().With("name", name)
  86. pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
  87. if err != nil {
  88. return fmt.Errorf("ollama: pull: %w: %s", err, name)
  89. }
  90. if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
  91. return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
  92. }
  93. // download required layers we do not already have
  94. for _, l := range pr.Manifest.Layers {
  95. d, err := model.ParseDigest(l.Digest)
  96. if err != nil {
  97. return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
  98. }
  99. if cache.LayerFile(d) != "" {
  100. continue
  101. }
  102. err = func() error {
  103. log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
  104. log.Debug("starting download")
  105. // TODO(bmizerany): stop using temp which might not
  106. // be on same device as cache.... instead let cache
  107. // give us a place to store parts...
  108. tmpFile, err := os.CreateTemp("", "ollama-download-")
  109. if err != nil {
  110. return err
  111. }
  112. defer func() {
  113. tmpFile.Close()
  114. os.Remove(tmpFile.Name()) // in case we fail before committing
  115. }()
  116. g, ctx := errgroup.WithContext(ctx)
  117. g.SetLimit(8) // TODO(bmizerany): make this configurable
  118. // TODO(bmizerany): make chunk size configurable
  119. const chunkSize = 50 * 1024 * 1024 // 50MB
  120. chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
  121. g.Go(func() (err error) {
  122. defer func() {
  123. if err == nil {
  124. return
  125. }
  126. safeURL := redactAmzSignature(l.URL)
  127. err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
  128. }()
  129. log.Debug("downloading", "range", rng)
  130. // TODO(bmizerany): retry
  131. // TODO(bmizerany): use real http client
  132. // TODO(bmizerany): resumable
  133. // TODO(bmizerany): multipart download
  134. req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
  135. if err != nil {
  136. return err
  137. }
  138. req.Header.Set("Range", "bytes="+rng.String())
  139. res, err := http.DefaultClient.Do(req)
  140. if err != nil {
  141. return err
  142. }
  143. defer res.Body.Close()
  144. if res.StatusCode/100 != 2 {
  145. log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
  146. return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
  147. }
  148. if res.ContentLength != rng.Size() {
  149. return fmt.Errorf("unexpected content length: %d", res.ContentLength)
  150. }
  151. w := io.NewOffsetWriter(tmpFile, rng.Start)
  152. _, err = io.Copy(w, res.Body)
  153. return err
  154. })
  155. return true
  156. })
  157. if err := g.Wait(); err != nil {
  158. return err
  159. }
  160. tmpFile.Close() // release our hold on the file before moving it
  161. return cache.PutLayerFile(d, tmpFile.Name())
  162. }()
  163. if err != nil {
  164. return fmt.Errorf("ollama: pull: %w", err)
  165. }
  166. }
  167. // do not store the presigned URLs in the cache
  168. for i := range pr.Manifest.Layers {
  169. pr.Manifest.Layers[i].URL = ""
  170. }
  171. data, err := json.Marshal(pr.Manifest)
  172. if err != nil {
  173. return err
  174. }
  175. // TODO(bmizerany): remove dep on model.Name
  176. return cache.SetManifestData(mn, data)
  177. }
  // nopSeeker wraps an io.Reader and adds a Seek method that does nothing.
  type nopSeeker struct{ io.Reader }

  // Seek ignores its arguments and always reports offset 0 with a nil error.
  // It never repositions the wrapped reader.
  func (nopSeeker) Seek(_ int64, _ int) (int64, error) {
  	return 0, nil
  }
  184. // Push pushes a manifest to the server and responds to the server's
  185. // requests for layer uploads, if any, and finally commits the manifest for
  186. // name. It returns an error if any part of the process fails, specifically:
  187. //
  188. // If the server requests layers not found in the cache, ErrLayerNotFound is
  189. // returned.
  190. func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
  191. mn := model.ParseName(name)
  192. if !mn.IsFullyQualified() {
  193. return fmt.Errorf("ollama: push: invalid name: %s", name)
  194. }
  195. manifest := cache.ManifestData(mn)
  196. if len(manifest) == 0 {
  197. return fmt.Errorf("manifest not found: %s", name)
  198. }
  199. var mu sync.Mutex
  200. var completed []*apitype.CompletePart
  201. push := func() (*apitype.PushResponse, error) {
  202. v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
  203. Name: name,
  204. Manifest: manifest,
  205. CompleteParts: completed,
  206. })
  207. if err != nil {
  208. return nil, fmt.Errorf("Do: %w", err)
  209. }
  210. return v, nil
  211. }
  212. pr, err := push()
  213. if err != nil {
  214. return err
  215. }
  216. var g errgroup.Group
  217. for _, need := range pr.Needs {
  218. g.Go(func() error {
  219. nd, err := model.ParseDigest(need.Digest)
  220. if err != nil {
  221. return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
  222. }
  223. f, err := cache.OpenLayer(nd)
  224. if err != nil {
  225. return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
  226. }
  227. defer f.Close()
  228. cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
  229. if err != nil {
  230. return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
  231. }
  232. mu.Lock()
  233. completed = append(completed, cp)
  234. mu.Unlock()
  235. return nil
  236. })
  237. }
  238. if err := g.Wait(); err != nil {
  239. return fmt.Errorf("Push: Required: %w", err)
  240. }
  241. if len(completed) > 0 {
  242. pr, err := push()
  243. if err != nil {
  244. return err
  245. }
  246. if len(pr.Needs) > 0 {
  247. var errs []error
  248. for _, r := range pr.Needs {
  249. errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
  250. }
  251. return errors.Join(errs...)
  252. }
  253. }
  254. return cache.SetManifestData(mn, manifest)
  255. }
  256. func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
  257. if start < 0 || end < start {
  258. return nil, errors.New("start must satisfy 0 <= start <= end")
  259. }
  260. file := io.NewSectionReader(body, start, end-start+1)
  261. req, err := http.NewRequest("PUT", url, file)
  262. if err != nil {
  263. return nil, err
  264. }
  265. req.ContentLength = end - start + 1
  266. // TODO(bmizerany): take content type param
  267. req.Header.Set("Content-Type", "text/plain")
  268. if start != 0 || end != 0 {
  269. req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
  270. }
  271. res, err := http.DefaultClient.Do(req)
  272. if err != nil {
  273. return nil, err
  274. }
  275. defer res.Body.Close()
  276. if res.StatusCode != 200 {
  277. e := parseS3Error(res)
  278. return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
  279. }
  280. cp := &apitype.CompletePart{
  281. URL: url,
  282. ETag: res.Header.Get("ETag"),
  283. // TODO(bmizerany): checksum
  284. }
  285. return cp, nil
  286. }
  // s3Error is the decoded form of an XML <Error> document; each field is
  // filled from the child element named in its tag.
  type s3Error struct {
  	XMLName xml.Name `xml:"Error"`

  	Code      string `xml:"Code"`
  	Message   string `xml:"Message"`
  	Resource  string `xml:"Resource"`
  	RequestId string `xml:"RequestId"`
  }

  // Error implements the error interface, formatting the request id,
  // resource, code, and message in that order.
  func (e *s3Error) Error() string {
  	const layout = "S3 (%s): %s: %s: %s"
  	return fmt.Sprintf(layout, e.RequestId, e.Resource, e.Code, e.Message)
  }
  297. // parseS3Error parses an XML error response from S3.
  298. func parseS3Error(res *http.Response) error {
  299. var se *s3Error
  300. if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
  301. return err
  302. }
  303. return se
  304. }
  305. // TODO: replace below by using upload pkg after we have rangefunc; until
  306. // then, we need to keep this free of rangefunc for now.
  307. type chunkRange[I constraints.Integer] struct {
  308. // Start is the byte offset of the chunk.
  309. Start I
  310. // End is the byte offset of the last byte in the chunk.
  311. End I
  312. }
  313. func (c chunkRange[I]) Size() I {
  314. return c.End - c.Start + 1
  315. }
  316. func (c chunkRange[I]) String() string {
  317. return fmt.Sprintf("%d-%d", c.Start, c.End)
  318. }
  319. func (c chunkRange[I]) LogValue() slog.Value {
  320. return slog.StringValue(c.String())
  321. }
  322. // Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
  323. // and size of the chunk. The last chunk may be smaller than chunkSize if size is
  324. // not a multiple of chunkSize.
  325. //
  326. // The first part number is 1 and increases monotonically.
  327. func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
  328. return func(yield func(int, chunkRange[I]) bool) {
  329. var n int
  330. for off := I(0); off < size; off += chunkSize {
  331. n++
  332. if !yield(n, chunkRange[I]{
  333. Start: off,
  334. End: off + min(chunkSize, size-off) - 1,
  335. }) {
  336. return
  337. }
  338. }
  339. }
  340. }
  // redactAmzSignature returns s with the value of its X-Amz-Signature query
  // parameter replaced by "REDACTED", so the URL can appear in logs and
  // error messages.
  //
  // A URL without that parameter is returned without any change to its
  // query. If s cannot be parsed as a URL, the empty string is returned so
  // that nothing unredacted leaks.
  func redactAmzSignature(s string) string {
  	const sigParam = "X-Amz-Signature"

  	u, err := url.Parse(s)
  	if err != nil {
  		return ""
  	}

  	q := u.Query()
  	if !q.Has(sigParam) {
  		// Nothing to redact; do not invent a signature parameter.
  		return u.String()
  	}
  	q.Set(sigParam, "REDACTED")
  	u.RawQuery = q.Encode() // Encode sorts the parameters by key
  	return u.String()
  }