server_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473
  1. package registry
  2. import (
  3. "bufio"
  4. "bytes"
  5. "cmp"
  6. "context"
  7. "crypto/sha256"
  8. "encoding/json"
  9. "errors"
  10. "fmt"
  11. "io"
  12. "net"
  13. "net/http/httptest"
  14. "net/url"
  15. "os"
  16. "os/exec"
  17. "strconv"
  18. "syscall"
  19. "testing"
  20. "time"
  21. "github.com/minio/minio-go/v7"
  22. "github.com/minio/minio-go/v7/pkg/credentials"
  23. "github.com/ollama/ollama/x/registry/apitype"
  24. "github.com/ollama/ollama/x/utils/backoff"
  25. "github.com/ollama/ollama/x/utils/upload"
  26. "kr.dev/diff"
  27. )
  28. // const ref = "registry.ollama.ai/x/y:latest+Z"
  29. // const manifest = `{
  30. // "layers": [
  31. // {"digest": "sha256-1", "size": 1},
  32. // {"digest": "sha256-2", "size": 2},
  33. // {"digest": "sha256-3", "size": 3}
  34. // ]
  35. // }`
  36. // ts := newTestServer(t)
  37. // ts.pushNotOK(ref, `{}`, &ollama.Error{
  38. // Status: 400,
  39. // Code: "invalid",
  40. // Message: "name must be fully qualified",
  41. // })
  42. // ts.push(ref, `{
  43. // "layers": [
  44. // {"digest": "sha256-1", "size": 1},
  45. // {"digest": "sha256-2", "size": 2},
  46. // {"digest": "sha256-3", "size": 3}
  47. // ]
  48. // }`)
  49. type tWriter struct {
  50. t *testing.T
  51. }
  52. func (w tWriter) Write(p []byte) (n int, err error) {
  53. w.t.Logf("%s", p)
  54. return len(p), nil
  55. }
  56. func TestPushBasic(t *testing.T) {
  57. const MB = 1024 * 1024
  58. mc := startMinio(t, true)
  59. defer func() {
  60. mcc := &minio.Core{Client: mc}
  61. // fail if there are any incomplete uploads
  62. for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
  63. t.Errorf("incomplete: %v", x)
  64. }
  65. }()
  66. const ref = "registry.ollama.ai/x/y:latest+Z"
  67. // Upload two small layers and one large layer that will
  68. // trigger a multipart upload.
  69. manifest := []byte(`{
  70. "layers": [
  71. {"digest": "sha256-1", "size": 1},
  72. {"digest": "sha256-2", "size": 2},
  73. {"digest": "sha256-3", "size": 11000000}
  74. ]
  75. }`)
  76. hs := httptest.NewServer(&Server{
  77. S3Client: mc,
  78. UploadChunkSize: 5 * MB,
  79. })
  80. t.Cleanup(hs.Close)
  81. c := &Client{BaseURL: hs.URL}
  82. requirements, err := c.Push(context.Background(), ref, manifest, nil)
  83. if err != nil {
  84. t.Fatal(err)
  85. }
  86. if len(requirements) < 3 {
  87. t.Errorf("expected at least 3 requirements; got %d", len(requirements))
  88. t.Logf("requirements: %v", requirements)
  89. }
  90. var uploaded []apitype.CompletePart
  91. for i, r := range requirements {
  92. t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
  93. cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
  94. if err != nil {
  95. t.Fatal(err)
  96. }
  97. uploaded = append(uploaded, cp)
  98. }
  99. requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
  100. CompleteParts: uploaded,
  101. })
  102. if err != nil {
  103. t.Fatal(err)
  104. }
  105. if len(requirements) != 0 {
  106. t.Errorf("unexpected requirements: %v", requirements)
  107. }
  108. var paths []string
  109. keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
  110. Recursive: true,
  111. })
  112. for k := range keys {
  113. paths = append(paths, k.Key)
  114. }
  115. t.Logf("paths: %v", paths)
  116. diff.Test(t, t.Errorf, paths, []string{
  117. "blobs/sha256-1",
  118. "blobs/sha256-2",
  119. "blobs/sha256-3",
  120. "manifests/registry.ollama.ai/x/y/latest/Z",
  121. })
  122. obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
  123. if err != nil {
  124. t.Fatal(err)
  125. }
  126. defer obj.Close()
  127. var gotM apitype.Manifest
  128. if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
  129. t.Fatal(err)
  130. }
  131. diff.Test(t, t.Errorf, gotM, apitype.Manifest{
  132. Layers: []apitype.Layer{
  133. {Digest: "sha256-1", Size: 1},
  134. {Digest: "sha256-2", Size: 2},
  135. {Digest: "sha256-3", Size: 11000000},
  136. },
  137. })
  138. // checksum the blobs
  139. for i, l := range gotM.Layers {
  140. obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
  141. if err != nil {
  142. t.Fatal(err)
  143. }
  144. defer obj.Close()
  145. info, err := obj.Stat()
  146. if err != nil {
  147. t.Fatal(err)
  148. }
  149. t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
  150. if msg := checkABCs(obj, int(l.Size)); msg != "" {
  151. t.Errorf("[%d] %s", i, msg)
  152. }
  153. }
  154. }
  155. // TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
  156. // presigning a multipart upload, uploading the parts, and completing the
  157. // upload. It is for future reference and should not be deleted. This flow
  158. // is tricky and if we get it wrong in our server, we can refer back to this
  159. // as a "back to basics" test/reference.
  160. func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
  161. t.Skip("skipping reference test; unskip when needed")
  162. mc := startMinio(t, true)
  163. mcc := &minio.Core{Client: mc}
  164. uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
  165. if err != nil {
  166. t.Fatal(err)
  167. }
  168. var completed []minio.CompletePart
  169. const size int64 = 10 * 1024 * 1024
  170. const chunkSize = 5 * 1024 * 1024
  171. for partNumber, c := range upload.Chunks(size, chunkSize) {
  172. u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
  173. "partNumber": {strconv.Itoa(partNumber)},
  174. "uploadId": {uploadID},
  175. })
  176. if err != nil {
  177. t.Fatalf("[partNumber=%d]: %v", partNumber, err)
  178. }
  179. t.Logf("[partNumber=%d]: %v", partNumber, u)
  180. var body abcReader
  181. cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
  182. if err != nil {
  183. t.Fatalf("[partNumber=%d]: %v", partNumber, err)
  184. }
  185. t.Logf("completed part: %v", cp)
  186. // behave like server here (don't cheat and use partNumber)
  187. // instead get partNumber from the URL
  188. retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
  189. if err != nil {
  190. t.Fatalf("[partNumber=%d]: %v", partNumber, err)
  191. }
  192. completed = append(completed, minio.CompletePart{
  193. PartNumber: retPartNumber,
  194. ETag: cp.ETag,
  195. })
  196. }
  197. defer func() {
  198. // fail if there are any incomplete uploads
  199. for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
  200. t.Errorf("incomplete: %v", x)
  201. }
  202. }()
  203. info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
  204. if err != nil {
  205. t.Fatal(err)
  206. }
  207. t.Logf("completed: %v", info)
  208. // Check key in bucket
  209. obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
  210. if err != nil {
  211. t.Fatal(err)
  212. }
  213. defer obj.Close()
  214. h := sha256.New()
  215. if _, err := io.Copy(h, obj); err != nil {
  216. t.Fatal(err)
  217. }
  218. gotSum := h.Sum(nil)
  219. h.Reset()
  220. var body abcReader
  221. if _, err := io.CopyN(h, &body, size); err != nil {
  222. t.Fatal(err)
  223. }
  224. wantSum := h.Sum(nil)
  225. if !bytes.Equal(gotSum, wantSum) {
  226. t.Errorf("got sum = %x; want %x", gotSum, wantSum)
  227. }
  228. }
  229. func availableAddr() string {
  230. l, err := net.Listen("tcp", "localhost:0")
  231. if err != nil {
  232. panic(err)
  233. }
  234. defer l.Close()
  235. return l.Addr().String()
  236. }
  237. // tracing is "experimental" and may be removed in the future, I can't get it to
  238. // work consistently, but I'm leaving it in for now.
  239. func startMinio(t *testing.T, trace bool) *minio.Client {
  240. t.Helper()
  241. // Trace is enabled by setting the OLLAMA_MINIO_TRACE environment or
  242. // explicitly setting trace to true.
  243. trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")
  244. dir := t.TempDir()
  245. t.Cleanup(func() {
  246. // TODO(bmizerany): trim temp dir based on dates so that
  247. // future runs may be able to inspect results for some time.
  248. })
  249. waitAndMaybeLogError := func(cmd *exec.Cmd) {
  250. if err := cmd.Wait(); err != nil {
  251. var e *exec.ExitError
  252. if errors.As(err, &e) {
  253. if e.Exited() {
  254. return
  255. }
  256. t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
  257. t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
  258. t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
  259. t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
  260. } else {
  261. if errors.Is(err, context.Canceled) {
  262. return
  263. }
  264. t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
  265. }
  266. }
  267. }
  268. // Cancel must be called first so do wait to add to Cleanup
  269. // stack as last cleanup.
  270. ctx, cancel := context.WithCancel(context.Background())
  271. deadline, ok := t.Deadline()
  272. if ok {
  273. ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
  274. }
  275. t.Logf(">> minio: minio server %s", dir)
  276. addr := availableAddr()
  277. cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
  278. cmd.Env = os.Environ()
  279. cmd.WaitDelay = 3 * time.Second
  280. cmd.Cancel = func() error {
  281. return cmd.Process.Signal(syscall.SIGQUIT)
  282. }
  283. if err := cmd.Start(); err != nil {
  284. t.Fatalf("startMinio: %v", err)
  285. }
  286. t.Cleanup(func() {
  287. cancel()
  288. waitAndMaybeLogError(cmd)
  289. })
  290. mc, err := minio.New(addr, &minio.Options{
  291. Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
  292. Secure: false,
  293. })
  294. if err != nil {
  295. t.Fatalf("startMinio: %v", err)
  296. }
  297. // wait for server to start with exponential backoff
  298. for _, err := range backoff.Upto(ctx, 1*time.Second) {
  299. if err != nil {
  300. t.Fatalf("startMinio: %v", err)
  301. }
  302. // try list buckets to see if server is up
  303. if _, err := mc.ListBuckets(ctx); err == nil {
  304. break
  305. }
  306. t.Logf("startMinio: server is offline; retrying")
  307. }
  308. if trace {
  309. cmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
  310. cmd.Env = append(os.Environ(),
  311. "MC_HOST_test=http://minioadmin:minioadmin@"+addr,
  312. )
  313. cmd.WaitDelay = 3 * time.Second
  314. cmd.Cancel = func() error {
  315. return cmd.Process.Signal(syscall.SIGQUIT)
  316. }
  317. stdout, err := cmd.StdoutPipe()
  318. if err != nil {
  319. t.Fatalf("startMinio: %v", err)
  320. }
  321. if err := cmd.Start(); err != nil {
  322. t.Fatalf("startMinio: %v", err)
  323. }
  324. doneLogging := make(chan struct{})
  325. sc := bufio.NewScanner(stdout)
  326. go func() {
  327. defer close(doneLogging)
  328. // Scan lines until the process exits.
  329. for sc.Scan() {
  330. t.Logf("startMinio: mc trace: %s", sc.Text())
  331. }
  332. _ = sc.Err() // ignore (not important)
  333. }()
  334. t.Cleanup(func() {
  335. cancel()
  336. waitAndMaybeLogError(cmd)
  337. // Make sure we do not log after test exists to
  338. // avoid panic.
  339. <-doneLogging
  340. })
  341. }
  342. if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
  343. t.Fatalf("startMinio: %v", err)
  344. }
  345. return mc
  346. }
  347. // contextForTest returns a context that is canceled when the test deadline,
  348. // if any, is reached. The returned doneLogging function should be called
  349. // after all Log/Error/Fatalf calls are done before the test returns.
  350. func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
  351. done := make(chan struct{})
  352. deadline, ok := t.Deadline()
  353. if !ok {
  354. return context.Background(), func() {}
  355. }
  356. ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
  357. t.Cleanup(func() {
  358. cancel()
  359. <-done
  360. })
  361. return ctx, func() { close(done) }
  362. }
  363. // abcReader repeats the string s infinitely.
  364. type abcReader struct {
  365. pos int
  366. }
  367. const theABCs = "abcdefghijklmnopqrstuvwxyz"
  368. func (r *abcReader) Read(p []byte) (n int, err error) {
  369. for i := range p {
  370. p[i] = theABCs[r.pos]
  371. r.pos++
  372. if r.pos == len(theABCs) {
  373. r.pos = 0
  374. }
  375. }
  376. return len(p), nil
  377. }
  378. func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
  379. for i := range p {
  380. p[i] = theABCs[(off+int64(i))%int64(len(theABCs))]
  381. }
  382. return len(p), nil
  383. }
  384. func checkABCs(r io.Reader, size int) (reason string) {
  385. h := sha256.New()
  386. n, err := io.CopyN(h, &abcReader{}, int64(size))
  387. if err != nil {
  388. return err.Error()
  389. }
  390. if n != int64(size) {
  391. panic("short read; should not happen")
  392. }
  393. want := h.Sum(nil)
  394. h = sha256.New()
  395. n, err = io.Copy(h, r)
  396. if err != nil {
  397. return err.Error()
  398. }
  399. if n != int64(size) {
  400. return fmt.Sprintf("got len(r) = %d; want %d", n, size)
  401. }
  402. got := h.Sum(nil)
  403. if !bytes.Equal(got, want) {
  404. return fmt.Sprintf("got sum = %x; want %x", got, want)
  405. }
  406. return ""
  407. }