server_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465
  1. package registry
  2. import (
  3. "bufio"
  4. "bytes"
  5. "cmp"
  6. "context"
  7. "crypto/sha256"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "net"
  13. "net/http"
  14. "net/http/httptest"
  15. "net/url"
  16. "os"
  17. "os/exec"
  18. "strconv"
  19. "strings"
  20. "syscall"
  21. "testing"
  22. "time"
  23. "bllamo.com/registry/apitype"
  24. "bllamo.com/utils/backoff"
  25. "bllamo.com/utils/upload"
  26. "github.com/minio/minio-go/v7"
  27. "github.com/minio/minio-go/v7/pkg/credentials"
  28. "kr.dev/diff"
  29. )
  const (
	// abc is the payload alphabet used by testPush: it is sent as the body
	// of every layer upload, and each stored blob is expected to equal the
	// prefix abc[:size] for its layer size.
	abc = "abcdefghijklmnopqrstuvwxyz"
)
  // testPush runs a subtest that pushes a three-layer manifest to a Server
// backed by a fresh MinIO instance. It uploads every part the server asks
// for, completes the push, and then checks the objects stored in the "test"
// bucket: the object keys, the stored manifest, and each blob's contents.
//
// NOTE(review): chunkSize only labels the subtest; the Server is always
// configured with UploadChunkSize = 5MB. Confirm whether it was meant to be
// wired through.
func testPush(t *testing.T, chunkSize int64) {
	t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
		mc := startMinio(t, true)

		const MB = 1024 * 1024

		// Upload two small layers and one large layer that will
		// trigger a multipart upload.
		manifest := []byte(`{
"layers": [
{"digest": "sha256-1", "size": 1},
{"digest": "sha256-2", "size": 2},
{"digest": "sha256-3", "size": 11000000}
]
}`)

		const ref = "registry.ollama.ai/x/y:latest+Z"
		const manifestKey = "manifests/registry.ollama.ai/x/y/latest/Z"

		hs := httptest.NewServer(&Server{
			minioClient:     mc,
			UploadChunkSize: 5 * MB,
		})
		t.Cleanup(hs.Close)
		c := &Client{BaseURL: hs.URL}

		requirements, err := c.Push(context.Background(), ref, manifest, nil)
		if err != nil {
			t.Fatal(err)
		}
		if len(requirements) < 3 {
			// Log before failing: Fatalf stops the test goroutine, so
			// nothing after it would run.
			t.Logf("requirements: %v", requirements)
			t.Fatalf("expected at least 3 requirements; got %d", len(requirements))
		}

		var uploaded []apitype.CompletePart
		for i, r := range requirements {
			t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
			body := strings.NewReader(abc)
			etag, err := PushLayer(context.Background(), r.URL, r.Offset, r.Size, body)
			if err != nil {
				t.Fatal(err)
			}
			uploaded = append(uploaded, apitype.CompletePart{
				URL:  r.URL,
				ETag: etag,
			})
		}

		// Second push reports the uploaded parts; nothing should remain.
		requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
			Uploaded: uploaded,
		})
		if err != nil {
			t.Fatal(err)
		}
		if len(requirements) != 0 {
			t.Fatalf("unexpected requirements: %v", requirements)
		}

		var paths []string
		keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
			Recursive: true,
		})
		for k := range keys {
			// ListObjects reports failures in-band through ObjectInfo.Err.
			// Keep draining the channel (rather than Fatal) so the
			// producer is not left blocked.
			if k.Err != nil {
				t.Errorf("listing objects: %v", k.Err)
				continue
			}
			paths = append(paths, k.Key)
		}
		t.Logf("paths: %v", paths)
		diff.Test(t, t.Errorf, paths, []string{
			"blobs/sha256-1",
			"blobs/sha256-2",
			"blobs/sha256-3",
			manifestKey,
		})

		obj, err := mc.GetObject(context.Background(), "test", manifestKey, minio.GetObjectOptions{})
		if err != nil {
			t.Fatal(err)
		}
		defer obj.Close()
		var gotM apitype.Manifest
		if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
			t.Fatal(err)
		}
		diff.Test(t, t.Errorf, gotM, apitype.Manifest{
			Layers: []apitype.Layer{
				{Digest: "sha256-1", Size: 1},
				{Digest: "sha256-2", Size: 2},
				{Digest: "sha256-3", Size: 3},
			},
		})

		// checksum the blobs
		for i, l := range gotM.Layers {
			// Use a func literal so each blob is closed at the end of
			// its iteration instead of when the subtest returns.
			func() {
				blob, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
				if err != nil {
					t.Fatal(err)
				}
				defer blob.Close()

				info, err := blob.Stat()
				if err != nil {
					t.Fatal(err)
				}
				t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)

				data, err := io.ReadAll(blob)
				if err != nil {
					t.Fatal(err)
				}

				// Expected data is a prefix of abc. A size outside
				// [0, len(abc)] cannot be checked this way, and slicing
				// abc with it would panic.
				size := int64(l.Size)
				if size < 0 || size > int64(len(abc)) {
					t.Errorf("[%d] layer size %d outside checkable range [0, %d]", i, size, len(abc))
					return
				}
				got := string(data)
				want := abc[:size]
				if got != want {
					t.Errorf("[%d] got layer data = %q; want %q", i, got, want)
				}
			}()
		}
	})
}
  // TestPush runs the push flow once for each chunk size in the list below.
func TestPush(t *testing.T) {
	for _, chunkSize := range []int64{0, 1} {
		testPush(t, chunkSize)
	}
}
  // pushLayer sends the n bytes of body that start at offset off to url with
// an HTTP PUT and returns the part's URL and ETag (with surrounding double
// quotes removed). In this file it is used with presigned S3 part-upload
// URLs.
//
// off must be >= 0. Any status other than 200 is returned as an error that
// wraps the value produced by parseS3Error. The request goes through
// http.DefaultClient, so it has no timeout.
func pushLayer(body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
	var zero apitype.CompletePart
	if off < 0 {
		return zero, errors.New("off must be >= 0")
	}

	req, err := http.NewRequest(http.MethodPut, url, io.NewSectionReader(body, off, n))
	if err != nil {
		return zero, err
	}
	req.ContentLength = n
	// TODO(bmizerany): take content type param
	req.Header.Set("Content-Type", "text/plain")
	// Only describe a range when it covers at least one byte: for n == 0
	// the end offset would be off-1, which is not a valid byte range.
	// NOTE(review): x-amz-copy-source-range is an UploadPartCopy header;
	// confirm it is needed for a plain part upload.
	if n > 0 {
		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return zero, err
	}
	defer res.Body.Close()
	if res.StatusCode != http.StatusOK {
		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, parseS3Error(res))
	}

	return apitype.CompletePart{
		URL:  url,
		ETag: strings.Trim(res.Header.Get("ETag"), `"`),
		// TODO(bmizerany): checksum
	}, nil
}
  // TestBasicPresignS3MultipartReferenceDoNotDelete is a "back to basics"
// reference for the presigned S3 multipart flow. It starts a multipart
// upload, presigns a PUT for each part, uploads the parts, completes the
// upload, and compares the stored object's SHA-256 against the bytes that
// were sent. The flow is tricky, so keep this test as a known-good
// reference to consult when the server gets it wrong. It is skipped by
// default.
func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
	t.Skip("skipping reference test; unskip when needed")

	const (
		bucket          = "test"
		key             = "theKey"
		size      int64 = 10 * 1024 * 1024
		chunkSize       = 5 * 1024 * 1024
	)
	ctx := context.Background()

	client := startMinio(t, true)
	core := &minio.Core{Client: client}

	uploadID, err := core.NewMultipartUpload(ctx, bucket, key, minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}

	var parts []minio.CompletePart
	for partNumber, chunk := range upload.Chunks(size, chunkSize) {
		query := url.Values{
			"partNumber": {strconv.Itoa(partNumber)},
			"uploadId":   {uploadID},
		}
		presigned, err := core.Presign(ctx, "PUT", bucket, key, 15*time.Minute, query)
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("[partNumber=%d]: %v", partNumber, presigned)

		var src abcReader
		part, err := pushLayer(&src, presigned.String(), chunk.Offset, chunk.N)
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("completed part: %v", part)

		// Act like the server would: recover the part number from the
		// presigned URL instead of trusting the loop variable.
		fromURL, err := strconv.Atoi(presigned.Query().Get("partNumber"))
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		parts = append(parts, minio.CompletePart{
			PartNumber: fromURL,
			ETag:       part.ETag,
		})
	}

	// On the way out, fail if any upload for the key was left incomplete.
	defer func() {
		for pending := range core.ListIncompleteUploads(ctx, bucket, key, true) {
			t.Errorf("incomplete: %v", pending)
		}
	}()

	info, err := core.CompleteMultipartUpload(ctx, bucket, key, uploadID, parts, minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("completed: %v", info)

	// Check key in bucket
	obj, err := client.GetObject(ctx, bucket, key, minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	stored := sha256.New()
	if _, err := io.Copy(stored, obj); err != nil {
		t.Fatal(err)
	}
	sent := sha256.New()
	var expected abcReader
	if _, err := io.CopyN(sent, &expected, size); err != nil {
		t.Fatal(err)
	}
	if gotSum, wantSum := stored.Sum(nil), sent.Sum(nil); !bytes.Equal(gotSum, wantSum) {
		t.Errorf("got sum = %x; want %x", gotSum, wantSum)
	}
}
  // availableAddr asks the OS for a free TCP port on localhost and returns
// the resulting "host:port". The probe listener is closed before returning,
// so another process may claim the port before the caller binds it. It
// panics if the listener cannot be opened.
func availableAddr() string {
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	addr := ln.Addr().String()
	ln.Close()
	return addr
}
  // startMinio runs "minio server" (found on PATH) on a free localhost
// address. It waits until the server answers an authenticated request,
// creates the "test" bucket, and returns a client using the
// minioadmin/minioadmin credentials. During test cleanup the server is sent
// SIGQUIT and waited for. An unexpected exit is reported with t.Errorf,
// together with the server's captured stderr.
//
// Tracing is enabled when trace is true or OLLAMA_MINIO_TRACE is non-empty.
// In that case "mc admin trace --verbose" is run against the server, and
// each line it prints is written to the test log. Tracing is
// "experimental" and may be removed in the future; it does not work
// consistently, but is left in for now.
func startMinio(t *testing.T, trace bool) *minio.Client {
	t.Helper()

	trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")

	// Use a sibling of t.TempDir rather than the directory itself, so the
	// testing package does not delete the server's data after the test.
	// TODO(bmizerany): trim temp dir based on dates so that future runs may
	// be able to inspect results for some time.
	dir := t.TempDir() + "-keep"

	// waitAndMaybeLogError waits for cmd and reports unexpected exits with
	// t.Errorf. Exits caused by our own shutdown are not reported: that is
	// termination by a signal, or a canceled or expired context. stderr
	// must be the buffer assigned to cmd.Stderr. ExitError.Stderr is only
	// populated by Cmd.Output, so after Wait it is always empty.
	waitAndMaybeLogError := func(cmd *exec.Cmd, stderr *bytes.Buffer) {
		err := cmd.Wait()
		if err == nil {
			return
		}
		var e *exec.ExitError
		if errors.As(err, &e) {
			if !e.Exited() {
				// died due to our signal
				return
			}
			t.Errorf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
			t.Errorf("startMinio: %s stderr: %s", cmd.Path, stderr.Bytes())
			return
		}
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return
		}
		t.Errorf("startMinio: %s exit error: %v", cmd.Path, err)
	}

	// Stop the processes shortly before the test deadline, if there is one,
	// so cleanup still has time to run.
	ctx, cancel := context.WithCancel(context.Background())
	deadline, ok := t.Deadline()
	if ok {
		ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
	}

	t.Logf(">> minio: minio server %s", dir)
	addr := availableAddr()
	var serverStderr bytes.Buffer
	cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
	cmd.Env = os.Environ()
	cmd.Stderr = &serverStderr
	cmd.WaitDelay = 3 * time.Second
	// On cancellation, ask the server to stop with SIGQUIT instead of the
	// default os.Process.Kill.
	cmd.Cancel = func() error {
		return cmd.Process.Signal(syscall.SIGQUIT)
	}
	if err := cmd.Start(); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	// Cleanups run in LIFO order. Registering the server's cancel+wait
	// first makes it run last, after anything that depends on the server
	// (such as the trace below).
	t.Cleanup(func() {
		cancel()
		waitAndMaybeLogError(cmd, &serverStderr)
	})

	mc, err := minio.New(addr, &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		t.Fatalf("startMinio: %v", err)
	}

	// Wait for the server to answer requests, with exponential backoff.
	// mc.IsOnline cannot be used here: unless HealthCheck has been started
	// it always reports true, so it would not wait at all. Probe with a
	// cheap authenticated request instead.
	for _, err := range backoff.Upto(ctx, 1*time.Second) {
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		if _, err := mc.ListBuckets(ctx); err == nil {
			break
		}
	}

	if trace {
		traceCmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
		traceCmd.Env = append(os.Environ(),
			"MC_HOST_test=http://minioadmin:minioadmin@"+addr,
		)
		var traceStderr bytes.Buffer
		traceCmd.Stderr = &traceStderr
		traceCmd.WaitDelay = 3 * time.Second
		traceCmd.Cancel = func() error {
			return traceCmd.Process.Signal(syscall.SIGQUIT)
		}
		stdout, err := traceCmd.StdoutPipe()
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		if err := traceCmd.Start(); err != nil {
			t.Fatalf("startMinio: %v", err)
		}

		doneLogging := make(chan struct{})
		go func() {
			defer close(doneLogging)
			// Scan lines until the process exits and its stdout closes.
			sc := bufio.NewScanner(stdout)
			for sc.Scan() {
				t.Logf("startMinio: mc trace: %s", sc.Text())
			}
			_ = sc.Err() // ignore (not important)
		}()
		t.Cleanup(func() {
			cancel()
			// Finish reading stdout before Wait. Wait closes the pipe, so
			// calling it first can drop trailing output (see
			// Cmd.StdoutPipe). Waiting here also guarantees that nothing
			// is logged after the test finishes, which would panic. If mc
			// ignores SIGQUIT, WaitDelay has exec kill it, which ends the
			// scan.
			<-doneLogging
			waitAndMaybeLogError(traceCmd, &traceStderr)
		})
	}

	if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	return mc
}
  // contextForTest returns a context for work done by the test. If the test
// has a deadline, the context's deadline is 100ms earlier, so work can wind
// down before the testing package times out. The context is also canceled
// during test cleanup. If the test has no deadline, the function returns
// context.Background() and a no-op doneLogging.
//
// When there is a deadline, doneLogging must be called after the test's
// last Log/Error/Fatal call. The cleanup registered here cancels the
// context and then blocks until doneLogging has been called. doneLogging
// may safely be called more than once.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	deadline, ok := t.Deadline()
	if !ok {
		return context.Background(), func() {}
	}
	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))

	done := make(chan struct{})
	t.Cleanup(func() {
		cancel()
		<-done
	})

	// token holds a single value. Only the first doneLogging call can take
	// it, so done is closed exactly once, even if doneLogging is called
	// repeatedly or concurrently.
	token := make(chan struct{}, 1)
	token <- struct{}{}
	return ctx, func() {
		select {
		case <-token:
			close(done)
		default:
		}
	}
}
  379. // abcReader repeats the string s infinitely.
  380. type abcReader struct {
  381. pos int
  382. }
  383. const theABCs = "abcdefghijklmnopqrstuvwxyz"
  384. func (r *abcReader) Read(p []byte) (n int, err error) {
  385. for i := range p {
  386. p[i] = theABCs[r.pos]
  387. r.pos++
  388. if r.pos == len(theABCs) {
  389. r.pos = 0
  390. }
  391. }
  392. return len(p), nil
  393. }
  394. func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
  395. for i := range p {
  396. p[i] = theABCs[(off+int64(i))%int64(len(theABCs))]
  397. }
  398. return len(p), nil
  399. }