|
@@ -24,6 +24,8 @@ import (
|
|
"github.com/jmorganca/ollama/vector"
|
|
"github.com/jmorganca/ollama/vector"
|
|
)
|
|
)
|
|
|
|
|
|
|
|
+const MaxRetries = 3
|
|
|
|
+
|
|
type RegistryOptions struct {
|
|
type RegistryOptions struct {
|
|
Insecure bool
|
|
Insecure bool
|
|
Username string
|
|
Username string
|
|
@@ -856,7 +858,7 @@ func DeleteModel(name string) error {
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
|
|
|
|
+func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
mp := ParseModelPath(name)
|
|
mp := ParseModelPath(name)
|
|
|
|
|
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
|
@@ -872,7 +874,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
|
|
layers = append(layers, &manifest.Config)
|
|
layers = append(layers, &manifest.Config)
|
|
|
|
|
|
for _, layer := range layers {
|
|
for _, layer := range layers {
|
|
- exists, err := checkBlobExistence(mp, layer.Digest, regOpts)
|
|
|
|
|
|
+ exists, err := checkBlobExistence(ctx, mp, layer.Digest, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
@@ -894,13 +896,13 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
|
|
Total: layer.Size,
|
|
Total: layer.Size,
|
|
})
|
|
})
|
|
|
|
|
|
- location, err := startUpload(mp, regOpts)
|
|
|
|
|
|
+ location, err := startUpload(ctx, mp, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Printf("couldn't start upload: %v", err)
|
|
log.Printf("couldn't start upload: %v", err)
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
|
|
- err = uploadBlobChunked(mp, location, layer, regOpts, fn)
|
|
|
|
|
|
+ err = uploadBlobChunked(ctx, mp, location, layer, regOpts, fn)
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Printf("error uploading blob: %v", err)
|
|
log.Printf("error uploading blob: %v", err)
|
|
return err
|
|
return err
|
|
@@ -918,7 +920,7 @@ func PushModel(name string, regOpts *RegistryOptions, fn func(api.ProgressRespon
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
|
|
|
|
- resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
|
|
|
|
|
|
+ resp, err := makeRequest(ctx, "PUT", url, headers, bytes.NewReader(manifestJSON), regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
return err
|
|
return err
|
|
}
|
|
}
|
|
@@ -940,7 +942,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
|
|
|
|
|
fn(api.ProgressResponse{Status: "pulling manifest"})
|
|
fn(api.ProgressResponse{Status: "pulling manifest"})
|
|
|
|
|
|
- manifest, err := pullModelManifest(mp, regOpts)
|
|
|
|
|
|
+ manifest, err := pullModelManifest(ctx, mp, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
return fmt.Errorf("pull model manifest: %s", err)
|
|
return fmt.Errorf("pull model manifest: %s", err)
|
|
}
|
|
}
|
|
@@ -996,13 +998,13 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
|
|
|
|
|
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
|
|
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
|
|
url := fmt.Sprintf("%s/v2/%s/manifests/%s", mp.Registry, mp.GetNamespaceRepository(), mp.Tag)
|
|
headers := map[string]string{
|
|
headers := map[string]string{
|
|
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
|
|
"Accept": "application/vnd.docker.distribution.manifest.v2+json",
|
|
}
|
|
}
|
|
|
|
|
|
- resp, err := makeRequest("GET", url, headers, nil, regOpts)
|
|
|
|
|
|
+ resp, err := makeRequest(ctx, "GET", url, headers, nil, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Printf("couldn't get manifest: %v", err)
|
|
log.Printf("couldn't get manifest: %v", err)
|
|
return nil, err
|
|
return nil, err
|
|
@@ -1061,10 +1063,10 @@ func GetSHA256Digest(r io.Reader) (string, int) {
|
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
|
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), int(n)
|
|
}
|
|
}
|
|
|
|
|
|
-func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
|
|
|
|
|
|
+func startUpload(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (string, error) {
|
|
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
|
|
url := fmt.Sprintf("%s/v2/%s/blobs/uploads/", mp.Registry, mp.GetNamespaceRepository())
|
|
|
|
|
|
- resp, err := makeRequest("POST", url, nil, nil, regOpts)
|
|
|
|
|
|
+ resp, err := makeRequest(ctx, "POST", url, nil, nil, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Printf("couldn't start upload: %v", err)
|
|
log.Printf("couldn't start upload: %v", err)
|
|
return "", err
|
|
return "", err
|
|
@@ -1087,10 +1089,10 @@ func startUpload(mp ModelPath, regOpts *RegistryOptions) (string, error) {
|
|
}
|
|
}
|
|
|
|
|
|
// Function to check if a blob already exists in the Docker registry
|
|
// Function to check if a blob already exists in the Docker registry
|
|
-func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
|
|
|
|
|
|
+func checkBlobExistence(ctx context.Context, mp ModelPath, digest string, regOpts *RegistryOptions) (bool, error) {
|
|
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
|
|
url := fmt.Sprintf("%s/v2/%s/blobs/%s", mp.Registry, mp.GetNamespaceRepository(), digest)
|
|
|
|
|
|
- resp, err := makeRequest("HEAD", url, nil, nil, regOpts)
|
|
|
|
|
|
+ resp, err := makeRequest(ctx, "HEAD", url, nil, nil, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Printf("couldn't check for blob: %v", err)
|
|
log.Printf("couldn't check for blob: %v", err)
|
|
return false, err
|
|
return false, err
|
|
@@ -1101,7 +1103,7 @@ func checkBlobExistence(mp ModelPath, digest string, regOpts *RegistryOptions) (
|
|
return resp.StatusCode == http.StatusOK, nil
|
|
return resp.StatusCode == http.StatusOK, nil
|
|
}
|
|
}
|
|
|
|
|
|
-func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
|
|
|
|
+func uploadBlobChunked(ctx context.Context, mp ModelPath, url string, layer *Layer, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
// TODO allow resumability
|
|
// TODO allow resumability
|
|
// TODO allow canceling uploads via DELETE
|
|
// TODO allow canceling uploads via DELETE
|
|
// TODO allow cross repo blob mount
|
|
// TODO allow cross repo blob mount
|
|
@@ -1158,7 +1160,7 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
|
|
headers["Content-Length"] = strconv.Itoa(int(layer.Size))
|
|
headers["Content-Length"] = strconv.Itoa(int(layer.Size))
|
|
|
|
|
|
// finish the upload
|
|
// finish the upload
|
|
- resp, err := makeRequest("PUT", url, headers, r, regOpts)
|
|
|
|
|
|
+ resp, err := makeRequest(ctx, "PUT", url, headers, r, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
log.Printf("couldn't finish upload: %v", err)
|
|
log.Printf("couldn't finish upload: %v", err)
|
|
return err
|
|
return err
|
|
@@ -1172,7 +1174,16 @@ func uploadBlobChunked(mp ModelPath, url string, layer *Layer, regOpts *Registry
|
|
return nil
|
|
return nil
|
|
}
|
|
}
|
|
|
|
|
|
-func makeRequest(method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
|
|
|
|
|
+func makeRequest(ctx context.Context, method, url string, headers map[string]string, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
|
|
|
|
+ retryCtx := ctx.Value("retries")
|
|
|
|
+ var retries int
|
|
|
|
+ var ok bool
|
|
|
|
+ if retries, ok = retryCtx.(int); ok {
|
|
|
|
+ if retries > MaxRetries {
|
|
|
|
+ return nil, fmt.Errorf("Maximum retries hit; are you sure you have access to this resource?")
|
|
|
|
+ }
|
|
|
|
+ }
|
|
|
|
+
|
|
if !strings.HasPrefix(url, "http") {
|
|
if !strings.HasPrefix(url, "http") {
|
|
if regOpts.Insecure {
|
|
if regOpts.Insecure {
|
|
url = "http://" + url
|
|
url = "http://" + url
|
|
@@ -1225,13 +1236,14 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
if resp.StatusCode == http.StatusUnauthorized {
|
|
auth := resp.Header.Get("Www-Authenticate")
|
|
auth := resp.Header.Get("Www-Authenticate")
|
|
authRedir := ParseAuthRedirectString(string(auth))
|
|
authRedir := ParseAuthRedirectString(string(auth))
|
|
- token, err := getAuthToken(authRedir, regOpts)
|
|
|
|
|
|
+ token, err := getAuthToken(ctx, authRedir, regOpts)
|
|
if err != nil {
|
|
if err != nil {
|
|
return nil, err
|
|
return nil, err
|
|
}
|
|
}
|
|
regOpts.Token = token
|
|
regOpts.Token = token
|
|
bodyCopy = bytes.NewReader(buf.Bytes())
|
|
bodyCopy = bytes.NewReader(buf.Bytes())
|
|
- return makeRequest(method, url, headers, bodyCopy, regOpts)
|
|
|
|
|
|
+ ctx = context.WithValue(ctx, "retries", retries+1)
|
|
|
|
+ return makeRequest(ctx, method, url, headers, bodyCopy, regOpts)
|
|
}
|
|
}
|
|
|
|
|
|
return resp, nil
|
|
return resp, nil
|