client.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. package registry
  2. import (
  3. "cmp"
  4. "context"
  5. "encoding/xml"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "net/http"
  10. "strings"
  11. "github.com/ollama/ollama/x/client/ollama"
  12. "github.com/ollama/ollama/x/registry/apitype"
  13. )
  14. type Client struct {
  15. BaseURL string
  16. HTTPClient *http.Client
  17. }
  18. func (c *Client) oclient() *ollama.Client {
  19. return (*ollama.Client)(c)
  20. }
  21. type PushParams struct {
  22. CompleteParts []apitype.CompletePart
  23. }
  24. // Push pushes a manifest to the server.
  25. func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
  26. p = cmp.Or(p, &PushParams{})
  27. // TODO(bmizerany): backoff
  28. v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
  29. Name: ref,
  30. Manifest: manifest,
  31. CompleteParts: p.CompleteParts,
  32. })
  33. if err != nil {
  34. return nil, err
  35. }
  36. return v.Requirements, nil
  37. }
  38. func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
  39. var zero apitype.CompletePart
  40. if off < 0 {
  41. return zero, errors.New("off must be >0")
  42. }
  43. file := io.NewSectionReader(body, off, n)
  44. req, err := http.NewRequest("PUT", url, file)
  45. if err != nil {
  46. return zero, err
  47. }
  48. req.ContentLength = n
  49. // TODO(bmizerany): take content type param
  50. req.Header.Set("Content-Type", "text/plain")
  51. if n >= 0 {
  52. req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
  53. }
  54. res, err := http.DefaultClient.Do(req)
  55. if err != nil {
  56. return zero, err
  57. }
  58. defer res.Body.Close()
  59. if res.StatusCode != 200 {
  60. e := parseS3Error(res)
  61. return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
  62. }
  63. etag := strings.Trim(res.Header.Get("ETag"), `"`)
  64. cp := apitype.CompletePart{
  65. URL: url,
  66. ETag: etag,
  67. // TODO(bmizerany): checksum
  68. }
  69. return cp, nil
  70. }
  // s3Error is the decoded form of an XML <Error> document, the error body
  // format parseS3Error expects from S3.
  type s3Error struct {
  	XMLName   xml.Name `xml:"Error"`     // root element must be <Error>
  	Code      string   `xml:"Code"`      // text of the <Code> child
  	Message   string   `xml:"Message"`   // text of the <Message> child
  	Resource  string   `xml:"Resource"`  // text of the <Resource> child
  	RequestId string   `xml:"RequestId"` // text of the <RequestId> child
  }

  // Error renders e as "S3 (<RequestId>): <Resource>: <Code>: <Message>".
  func (e *s3Error) Error() string {
  	const layout = "S3 (%s): %s: %s: %s"
  	return fmt.Sprintf(layout, e.RequestId, e.Resource, e.Code, e.Message)
  }
  81. // parseS3Error parses an XML error response from S3.
  82. func parseS3Error(res *http.Response) error {
  83. var se *s3Error
  84. if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
  85. return err
  86. }
  87. return se
  88. }