server_test.go 11 KB


  1. package registry
  2. import (
  3. "bufio"
  4. "bytes"
  5. "cmp"
  6. "context"
  7. "crypto/sha256"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "net"
  13. "net/http/httptest"
  14. "net/url"
  15. "os"
  16. "os/exec"
  17. "strconv"
  18. "syscall"
  19. "testing"
  20. "time"
  21. "bllamo.com/registry/apitype"
  22. "bllamo.com/utils/backoff"
  23. "bllamo.com/utils/upload"
  24. "github.com/minio/minio-go/v7"
  25. "github.com/minio/minio-go/v7/pkg/credentials"
  26. "kr.dev/diff"
  27. )
  28. func testPush(t *testing.T, chunkSize int64) {
  29. t.Run(fmt.Sprintf("chunkSize=%d", chunkSize), func(t *testing.T) {
  30. mc := startMinio(t, true)
  31. const MB = 1024 * 1024
  32. // Upload two small layers and one large layer that will
  33. // trigger a multipart upload.
  34. manifest := []byte(`{
  35. "layers": [
  36. {"digest": "sha256-1", "size": 1},
  37. {"digest": "sha256-2", "size": 2},
  38. {"digest": "sha256-3", "size": 11000000}
  39. ]
  40. }`)
  41. const ref = "registry.ollama.ai/x/y:latest+Z"
  42. hs := httptest.NewServer(&Server{
  43. minioClient: mc,
  44. UploadChunkSize: 5 * MB,
  45. })
  46. t.Cleanup(hs.Close)
  47. c := &Client{BaseURL: hs.URL}
  48. requirements, err := c.Push(context.Background(), ref, manifest, nil)
  49. if err != nil {
  50. t.Fatal(err)
  51. }
  52. if len(requirements) < 3 {
  53. t.Fatalf("expected at least 3 requirements; got %d", len(requirements))
  54. t.Logf("requirements: %v", requirements)
  55. }
  56. var uploaded []apitype.CompletePart
  57. for i, r := range requirements {
  58. t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
  59. cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
  60. if err != nil {
  61. t.Fatal(err)
  62. }
  63. uploaded = append(uploaded, cp)
  64. }
  65. requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
  66. Uploaded: uploaded,
  67. })
  68. if err != nil {
  69. t.Fatal(err)
  70. }
  71. if len(requirements) != 0 {
  72. t.Fatalf("unexpected requirements: %v", requirements)
  73. }
  74. var paths []string
  75. keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
  76. Recursive: true,
  77. })
  78. for k := range keys {
  79. paths = append(paths, k.Key)
  80. }
  81. t.Logf("paths: %v", paths)
  82. diff.Test(t, t.Errorf, paths, []string{
  83. "blobs/sha256-1",
  84. "blobs/sha256-2",
  85. "blobs/sha256-3",
  86. "manifests/registry.ollama.ai/x/y/latest/Z",
  87. })
  88. obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
  89. if err != nil {
  90. t.Fatal(err)
  91. }
  92. defer obj.Close()
  93. var gotM apitype.Manifest
  94. if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
  95. t.Fatal(err)
  96. }
  97. diff.Test(t, t.Errorf, gotM, apitype.Manifest{
  98. Layers: []apitype.Layer{
  99. {Digest: "sha256-1", Size: 1},
  100. {Digest: "sha256-2", Size: 2},
  101. {Digest: "sha256-3", Size: 3},
  102. },
  103. })
  104. // checksum the blobs
  105. for i, l := range gotM.Layers {
  106. obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
  107. if err != nil {
  108. t.Fatal(err)
  109. }
  110. defer obj.Close()
  111. info, err := obj.Stat()
  112. if err != nil {
  113. t.Fatal(err)
  114. }
  115. t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
  116. if msg := checkABCs(obj, int(l.Size)); msg != "" {
  117. t.Errorf("[%d] %s", i, msg)
  118. }
  119. }
  120. })
  121. }
  122. func TestPush(t *testing.T) {
  123. testPush(t, 0)
  124. testPush(t, 1)
  125. }
// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
// presigning a multipart upload, uploading the parts, and completing the
// upload. It is for future reference and should not be deleted. This flow
// is tricky and if we get it wrong in our server, we can refer back to this
// as a "back to basics" test/reference.
//
// Steps, all against bucket "test" and object key "theKey":
//  1. start a multipart upload with minio.Core.NewMultipartUpload;
//  2. for each chunk yielded by upload.Chunks(size, chunkSize), presign a
//     PUT URL carrying the partNumber and uploadId query parameters, and
//     push that chunk (c.Offset, c.N) of an abcReader to it via PushLayer;
//  3. complete the upload with the collected part numbers and ETags;
//  4. read the object back and compare its SHA-256 with the SHA-256 of the
//     first size bytes of a fresh abcReader.
//
// The test is skipped unconditionally; remove the t.Skip call to run it.
func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
	t.Skip("skipping reference test; unskip when needed")
	mc := startMinio(t, true)
	mcc := &minio.Core{Client: mc}
	uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	var completed []minio.CompletePart
	const size int64 = 10 * 1024 * 1024
	const chunkSize = 5 * 1024 * 1024
	for partNumber, c := range upload.Chunks(size, chunkSize) {
		// Presign a PUT for this part. The part is identified only by the
		// partNumber and uploadId query parameters of the signed URL.
		u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
			"partNumber": {strconv.Itoa(partNumber)},
			"uploadId":   {uploadID},
		})
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("[partNumber=%d]: %v", partNumber, u)
		// A fresh reader per part; the part's position in the stream is
		// conveyed by c.Offset (presumably c.N is the part length —
		// confirm in package upload).
		var body abcReader
		cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("completed part: %v", cp)
		// behave like server here (don't cheat and use partNumber)
		// instead get partNumber from the URL
		retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		completed = append(completed, minio.CompletePart{
			PartNumber: retPartNumber,
			ETag:       cp.ETag,
		})
	}
	// Registered only after every part was uploaded. Deferred calls run at
	// function exit (including after a later t.Fatal), so this check runs
	// after CompleteMultipartUpload and the read-back below.
	defer func() {
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()
	info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("completed: %v", info)
	// Check key in bucket
	obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()
	// Hash the stored object, then reuse the hasher for the expected
	// bytes: the first size bytes of the abcReader stream.
	h := sha256.New()
	if _, err := io.Copy(h, obj); err != nil {
		t.Fatal(err)
	}
	gotSum := h.Sum(nil)
	h.Reset()
	var body abcReader
	if _, err := io.CopyN(h, &body, size); err != nil {
		t.Fatal(err)
	}
	wantSum := h.Sum(nil)
	if !bytes.Equal(gotSum, wantSum) {
		t.Errorf("got sum = %x; want %x", gotSum, wantSum)
	}
}
  200. func availableAddr() string {
  201. l, err := net.Listen("tcp", "localhost:0")
  202. if err != nil {
  203. panic(err)
  204. }
  205. defer l.Close()
  206. return l.Addr().String()
  207. }
  208. // tracing is "experimental" and may be removed in the future, I can't get it to
  209. // work consistently, but I'm leaving it in for now.
  210. func startMinio(t *testing.T, trace bool) *minio.Client {
  211. t.Helper()
  212. // Trace is enabled by setting the OLLAMA_MINIO_TRACE environment or
  213. // explicitly setting trace to true.
  214. trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")
  215. dir := t.TempDir()
  216. t.Cleanup(func() {
  217. // TODO(bmizerany): trim temp dir based on dates so that
  218. // future runs may be able to inspect results for some time.
  219. })
  220. waitAndMaybeLogError := func(cmd *exec.Cmd) {
  221. if err := cmd.Wait(); err != nil {
  222. var e *exec.ExitError
  223. if errors.As(err, &e) {
  224. if e.Exited() {
  225. return
  226. }
  227. t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
  228. t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
  229. t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
  230. t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
  231. } else {
  232. if errors.Is(err, context.Canceled) {
  233. return
  234. }
  235. t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
  236. }
  237. }
  238. }
  239. // Cancel must be called first so do wait to add to Cleanup
  240. // stack as last cleanup.
  241. ctx, cancel := context.WithCancel(context.Background())
  242. deadline, ok := t.Deadline()
  243. if ok {
  244. ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
  245. }
  246. t.Logf(">> minio: minio server %s", dir)
  247. addr := availableAddr()
  248. cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
  249. cmd.Env = os.Environ()
  250. cmd.WaitDelay = 3 * time.Second
  251. cmd.Cancel = func() error {
  252. return cmd.Process.Signal(syscall.SIGQUIT)
  253. }
  254. if err := cmd.Start(); err != nil {
  255. t.Fatalf("startMinio: %v", err)
  256. }
  257. t.Cleanup(func() {
  258. cancel()
  259. waitAndMaybeLogError(cmd)
  260. })
  261. mc, err := minio.New(addr, &minio.Options{
  262. Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
  263. Secure: false,
  264. })
  265. if err != nil {
  266. t.Fatalf("startMinio: %v", err)
  267. }
  268. // wait for server to start with exponential backoff
  269. for _, err := range backoff.Upto(ctx, 1*time.Second) {
  270. if err != nil {
  271. t.Fatalf("startMinio: %v", err)
  272. }
  273. if mc.IsOnline() {
  274. break
  275. }
  276. }
  277. if trace {
  278. cmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
  279. cmd.Env = append(os.Environ(),
  280. "MC_HOST_test=http://minioadmin:minioadmin@"+addr,
  281. )
  282. cmd.WaitDelay = 3 * time.Second
  283. cmd.Cancel = func() error {
  284. return cmd.Process.Signal(syscall.SIGQUIT)
  285. }
  286. stdout, err := cmd.StdoutPipe()
  287. if err != nil {
  288. t.Fatalf("startMinio: %v", err)
  289. }
  290. if err := cmd.Start(); err != nil {
  291. t.Fatalf("startMinio: %v", err)
  292. }
  293. doneLogging := make(chan struct{})
  294. sc := bufio.NewScanner(stdout)
  295. go func() {
  296. defer close(doneLogging)
  297. // Scan lines until the process exits.
  298. for sc.Scan() {
  299. t.Logf("startMinio: mc trace: %s", sc.Text())
  300. }
  301. _ = sc.Err() // ignore (not important)
  302. }()
  303. t.Cleanup(func() {
  304. cancel()
  305. waitAndMaybeLogError(cmd)
  306. // Make sure we do not log after test exists to
  307. // avoid panic.
  308. <-doneLogging
  309. })
  310. }
  311. if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
  312. t.Fatalf("startMinio: %v", err)
  313. }
  314. return mc
  315. }
  316. // contextForTest returns a context that is canceled when the test deadline,
  317. // if any, is reached. The returned doneLogging function should be called
  318. // after all Log/Error/Fatalf calls are done before the test returns.
  319. func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
  320. done := make(chan struct{})
  321. deadline, ok := t.Deadline()
  322. if !ok {
  323. return context.Background(), func() {}
  324. }
  325. ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
  326. t.Cleanup(func() {
  327. cancel()
  328. <-done
  329. })
  330. return ctx, func() { close(done) }
  331. }
  332. // abcReader repeats the string s infinitely.
  333. type abcReader struct {
  334. pos int
  335. }
  336. const theABCs = "abcdefghijklmnopqrstuvwxyz"
  337. func (r *abcReader) Read(p []byte) (n int, err error) {
  338. for i := range p {
  339. p[i] = theABCs[r.pos]
  340. r.pos++
  341. if r.pos == len(theABCs) {
  342. r.pos = 0
  343. }
  344. }
  345. return len(p), nil
  346. }
  347. func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
  348. for i := range p {
  349. p[i] = theABCs[(off+int64(i))%int64(len(theABCs))]
  350. }
  351. return len(p), nil
  352. }
  353. func checkABCs(r io.Reader, size int) (reason string) {
  354. h := sha256.New()
  355. n, err := io.CopyN(h, &abcReader{}, int64(size))
  356. if err != nil {
  357. return err.Error()
  358. }
  359. if n != int64(size) {
  360. panic("short read; should not happen")
  361. }
  362. want := h.Sum(nil)
  363. h = sha256.New()
  364. n, err = io.Copy(h, r)
  365. if err != nil {
  366. return err.Error()
  367. }
  368. if n != int64(size) {
  369. return fmt.Sprintf("got len(r) = %d; want %d", n, size)
  370. }
  371. got := h.Sum(nil)
  372. if !bytes.Equal(got, want) {
  373. return fmt.Sprintf("got sum = %x; want %x", got, want)
  374. }
  375. return ""
  376. }