Michael Yang 1 год назад
Родитель
Сommit
e43648afe5
9 измененных файлов с 224 добавлено и 251 удалено
  1. 30 8
      app/lifecycle/updater.go
  2. 1 1
      app/ollama.iss
  3. 13 140
      auth/auth.go
  4. 0 72
      auth/request.go
  5. 95 0
      server/auth.go
  6. 5 6
      server/download.go
  7. 71 12
      server/images.go
  8. 2 3
      server/routes.go
  9. 7 9
      server/upload.go

+ 30 - 8
app/lifecycle/updater.go

@@ -2,6 +2,7 @@ package lifecycle
 
 import (
 	"context"
+	"crypto/rand"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -9,6 +10,7 @@ import (
 	"log/slog"
 	"mime"
 	"net/http"
+	"net/url"
 	"os"
 	"path"
 	"path/filepath"
@@ -21,7 +23,7 @@ import (
 )
 
 var (
-	UpdateCheckURLBase = "https://ollama.ai/api/update"
+	UpdateCheckURLBase = "https://ollama.com/api/update"
 	UpdateDownloaded   = false
 )
 
@@ -47,22 +49,42 @@ func getClient(req *http.Request) http.Client {
 
 func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
 	var updateResp UpdateResponse
-	updateCheckURL := UpdateCheckURLBase + "?os=" + runtime.GOOS + "&arch=" + runtime.GOARCH + "&version=" + version.Version
-	headers := make(http.Header)
-	err := auth.SignRequest(http.MethodGet, updateCheckURL, nil, headers)
+
+	requestURL, err := url.Parse(UpdateCheckURLBase)
 	if err != nil {
-		slog.Info(fmt.Sprintf("failed to sign update request %s", err))
+		return false, updateResp
+	}
+
+	query := requestURL.Query()
+	query.Add("os", runtime.GOOS)
+	query.Add("arch", runtime.GOARCH)
+	query.Add("version", version.Version)
+	query.Add("ts", fmt.Sprintf("%d", time.Now().Unix()))
+
+	nonce, err := auth.NewNonce(rand.Reader, 16)
+	if err != nil {
+		return false, updateResp
 	}
-	req, err := http.NewRequestWithContext(ctx, http.MethodGet, updateCheckURL, nil)
+
+	query.Add("nonce", nonce)
+	requestURL.RawQuery = query.Encode()
+
+	data := []byte(fmt.Sprintf("%s,%s", http.MethodGet, requestURL.RequestURI()))
+	signature, err := auth.Sign(ctx, data)
+	if err != nil {
+		return false, updateResp
+	}
+
+	req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil)
 	if err != nil {
 		slog.Warn(fmt.Sprintf("failed to check for update: %s", err))
 		return false, updateResp
 	}
-	req.Header = headers
+	req.Header.Set("Authorization", signature)
 	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
 	client := getClient(req)
 
-	slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", updateCheckURL, headers))
+	slog.Debug(fmt.Sprintf("checking for available update at %s with headers %v", requestURL, req.Header))
 	resp, err := client.Do(req)
 	if err != nil {
 		slog.Warn(fmt.Sprintf("failed to check for update: %s", err))

+ 1 - 1
app/ollama.iss

@@ -12,7 +12,7 @@
   #define MyAppVersion "0.0.0"
 #endif
 #define MyAppPublisher "Ollama, Inc."
-#define MyAppURL "https://ollama.ai/"
+#define MyAppURL "https://ollama.com/"
 #define MyAppExeName "ollama app.exe"
 #define MyIcon ".\assets\app.ico"
 

+ 13 - 140
auth/auth.go

@@ -4,185 +4,58 @@ import (
 	"bytes"
 	"context"
 	"crypto/rand"
-	"crypto/sha256"
 	"encoding/base64"
-	"encoding/hex"
-	"encoding/json"
 	"fmt"
 	"io"
 	"log/slog"
-	"net/http"
-	"net/url"
 	"os"
 	"path/filepath"
-	"strconv"
-	"strings"
-	"time"
 
 	"golang.org/x/crypto/ssh"
-
-	"github.com/jmorganca/ollama/api"
-)
-
-const (
-	KeyType = "id_ed25519"
 )
 
-type AuthRedirect struct {
-	Realm   string
-	Service string
-	Scope   string
-}
+const defaultPrivateKey = "id_ed25519"
 
-type SignatureData struct {
-	Method string
-	Path   string
-	Data   []byte
-}
-
-func generateNonce(length int) (string, error) {
+func NewNonce(r io.Reader, length int) (string, error) {
 	nonce := make([]byte, length)
-	_, err := rand.Read(nonce)
-	if err != nil {
+	if _, err := io.ReadFull(r, nonce); err != nil {
 		return "", err
 	}
-	return base64.RawURLEncoding.EncodeToString(nonce), nil
-}
-
-func (r AuthRedirect) URL() (*url.URL, error) {
-	redirectURL, err := url.Parse(r.Realm)
-	if err != nil {
-		return nil, err
-	}
-
-	values := redirectURL.Query()
-
-	values.Add("service", r.Service)
-
-	for _, s := range strings.Split(r.Scope, " ") {
-		values.Add("scope", s)
-	}
 
-	values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
-
-	nonce, err := generateNonce(16)
-	if err != nil {
-		return nil, err
-	}
-	values.Add("nonce", nonce)
-
-	redirectURL.RawQuery = values.Encode()
-	return redirectURL, nil
+	return base64.RawURLEncoding.EncodeToString(nonce), nil
 }
 
-func SignRequest(method, url string, data []byte, headers http.Header) error {
+func Sign(ctx context.Context, bts []byte) (string, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {
-		return err
-	}
-
-	keyPath := filepath.Join(home, ".ollama", KeyType)
-
-	rawKey, err := os.ReadFile(keyPath)
-	if err != nil {
-		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
-		return err
-	}
-
-	s := SignatureData{
-		Method: method,
-		Path:   url,
-		Data:   data,
-	}
-
-	sig, err := s.Sign(rawKey)
-	if err != nil {
-		return err
-	}
-
-	headers.Set("Authorization", sig)
-	return nil
-}
-
-func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
-	redirectURL, err := redirData.URL()
-	if err != nil {
-		return "", err
-	}
-
-	headers := make(http.Header)
-	err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
-	if err != nil {
-		return "", err
-	}
-	resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
-	if err != nil {
-		slog.Info(fmt.Sprintf("couldn't get token: %q", err))
 		return "", err
 	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= http.StatusBadRequest {
-		responseBody, err := io.ReadAll(resp.Body)
-		if err != nil {
-			return "", fmt.Errorf("%d: %v", resp.StatusCode, err)
-		} else if len(responseBody) > 0 {
-			return "", fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
-		}
 
-		return "", fmt.Errorf("%s", resp.Status)
-	}
+	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 
-	respBody, err := io.ReadAll(resp.Body)
+	privateKeyFile, err := os.ReadFile(keyPath)
 	if err != nil {
+		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
 		return "", err
 	}
 
-	var tok api.TokenResponse
-	if err := json.Unmarshal(respBody, &tok); err != nil {
-		return "", err
-	}
-
-	return tok.Token, nil
-}
-
-// Bytes returns a byte slice of the data to sign for the request
-func (s SignatureData) Bytes() []byte {
-	// We first derive the content hash of the request body using:
-	//     base64(hex(sha256(request body)))
-
-	hash := sha256.Sum256(s.Data)
-	hashHex := make([]byte, hex.EncodedLen(len(hash)))
-	hex.Encode(hashHex, hash[:])
-	contentHash := base64.StdEncoding.EncodeToString(hashHex)
-
-	// We then put the entire request together in a serialize string using:
-	//       "<method>,<uri>,<content hash>"
-	// e.g.  "GET,http://localhost,OTdkZjM1O..."
-
-	return []byte(strings.Join([]string{s.Method, s.Path, contentHash}, ","))
-}
-
-// SignData takes a SignatureData object and signs it with a raw private key
-func (s SignatureData) Sign(rawKey []byte) (string, error) {
-	signer, err := ssh.ParsePrivateKey(rawKey)
+	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
 	if err != nil {
 		return "", err
 	}
 
 	// get the pubkey, but remove the type
-	pubKey := ssh.MarshalAuthorizedKey(signer.PublicKey())
-	parts := bytes.Split(pubKey, []byte(" "))
+	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
+	parts := bytes.Split(publicKey, []byte(" "))
 	if len(parts) < 2 {
 		return "", fmt.Errorf("malformed public key")
 	}
 
-	signedData, err := signer.Sign(nil, s.Bytes())
+	signedData, err := privateKey.Sign(rand.Reader, bts)
 	if err != nil {
 		return "", err
 	}
 
 	// signature is <pubkey>:<signature>
-	sig := fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob))
-	return sig, nil
+	return fmt.Sprintf("%s:%s", bytes.TrimSpace(parts[1]), base64.StdEncoding.EncodeToString(signedData.Blob)), nil
 }

+ 0 - 72
auth/request.go

@@ -1,72 +0,0 @@
-package auth
-
-import (
-	"context"
-	"fmt"
-	"io"
-	"net/http"
-	"net/url"
-	"runtime"
-	"strconv"
-
-	"github.com/jmorganca/ollama/version"
-)
-
-type RegistryOptions struct {
-	Insecure bool
-	Username string
-	Password string
-	Token    string
-}
-
-func MakeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
-	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
-		requestURL.Scheme = "http"
-	}
-
-	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
-	if err != nil {
-		return nil, err
-	}
-
-	if headers != nil {
-		req.Header = headers
-	}
-
-	if regOpts != nil {
-		if regOpts.Token != "" {
-			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
-		} else if regOpts.Username != "" && regOpts.Password != "" {
-			req.SetBasicAuth(regOpts.Username, regOpts.Password)
-		}
-	}
-
-	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
-
-	if s := req.Header.Get("Content-Length"); s != "" {
-		contentLength, err := strconv.ParseInt(s, 10, 64)
-		if err != nil {
-			return nil, err
-		}
-
-		req.ContentLength = contentLength
-	}
-
-	proxyURL, err := http.ProxyFromEnvironment(req)
-	if err != nil {
-		return nil, err
-	}
-
-	client := http.Client{
-		Transport: &http.Transport{
-			Proxy: http.ProxyURL(proxyURL),
-		},
-	}
-
-	resp, err := client.Do(req)
-	if err != nil {
-		return nil, err
-	}
-
-	return resp, nil
-}

+ 95 - 0
server/auth.go

@@ -0,0 +1,95 @@
+package server
+
+import (
+	"context"
+	"crypto/rand"
+	"crypto/sha256"
+	"encoding/base64"
+	"encoding/hex"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"strconv"
+	"strings"
+	"time"
+
+	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/auth"
+)
+
+type registryChallenge struct {
+	Realm   string
+	Service string
+	Scope   string
+}
+
+func (r registryChallenge) URL() (*url.URL, error) {
+	redirectURL, err := url.Parse(r.Realm)
+	if err != nil {
+		return nil, err
+	}
+
+	values := redirectURL.Query()
+	values.Add("service", r.Service)
+	for _, s := range strings.Split(r.Scope, " ") {
+		values.Add("scope", s)
+	}
+
+	values.Add("ts", strconv.FormatInt(time.Now().Unix(), 10))
+
+	nonce, err := auth.NewNonce(rand.Reader, 16)
+	if err != nil {
+		return nil, err
+	}
+
+	values.Add("nonce", nonce)
+
+	redirectURL.RawQuery = values.Encode()
+	return redirectURL, nil
+}
+
+func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (string, error) {
+	redirectURL, err := challenge.URL()
+	if err != nil {
+		return "", err
+	}
+
+	sha256sum := sha256.Sum256(nil)
+	data := []byte(fmt.Sprintf("%s,%s,%s", http.MethodGet, redirectURL.String(), base64.StdEncoding.EncodeToString([]byte(hex.EncodeToString(sha256sum[:])))))
+
+	headers := make(http.Header)
+	signature, err := auth.Sign(ctx, data)
+	if err != nil {
+		return "", err
+	}
+
+	headers.Add("Authorization", signature)
+
+	response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
+	if err != nil {
+		return "", err
+	}
+	defer response.Body.Close()
+
+	body, err := io.ReadAll(response.Body)
+	if err != nil {
+		return "", fmt.Errorf("%d: %v", response.StatusCode, err)
+	}
+
+	if response.StatusCode >= http.StatusBadRequest {
+		if len(body) > 0 {
+			return "", fmt.Errorf("%d: %s", response.StatusCode, body)
+		} else {
+			return "", fmt.Errorf("%d", response.StatusCode)
+		}
+	}
+
+	var token api.TokenResponse
+	if err := json.Unmarshal(body, &token); err != nil {
+		return "", err
+	}
+
+	return token.Token, nil
+}

+ 5 - 6
server/download.go

@@ -22,7 +22,6 @@ import (
 	"golang.org/x/sync/errgroup"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/format"
 )
 
@@ -86,7 +85,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
 	return n, nil
 }
 
-func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
+func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
 		return err
@@ -138,11 +137,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *a
 	return nil
 }
 
-func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
+func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
 	b.err = b.run(ctx, requestURL, opts)
 }
 
-func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
+func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
@@ -211,7 +210,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.
 	return nil
 }
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
+func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
 	g, ctx := errgroup.WithContext(ctx)
 	g.Go(func() error {
 		headers := make(http.Header)
@@ -335,7 +334,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 type downloadOpts struct {
 	mp      ModelPath
 	digest  string
-	regOpts *auth.RegistryOptions
+	regOpts *registryOptions
 	fn      func(api.ProgressResponse)
 }
 

+ 71 - 12
server/images.go

@@ -16,17 +16,25 @@ import (
 	"os"
 	"path/filepath"
 	"runtime"
+	"strconv"
 	"strings"
 	"text/template"
 
 	"golang.org/x/exp/slices"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
+	"github.com/jmorganca/ollama/version"
 )
 
+type registryOptions struct {
+	Insecure bool
+	Username string
+	Password string
+	Token    string
+}
+
 type Model struct {
 	Name           string `json:"name"`
 	Config         ConfigV2
@@ -312,7 +320,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 				switch {
 				case errors.Is(err, os.ErrNotExist):
 					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
+					if err := PullModel(ctx, c.Args, &registryOptions{}, fn); err != nil {
 						return err
 					}
 
@@ -832,7 +840,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
 	return buf.String(), nil
 }
 
-func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
+func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 
@@ -882,7 +890,7 @@ func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
 	return nil
 }
 
-func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
+func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 
 	var manifest *ManifestV2
@@ -988,7 +996,7 @@ func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
 	return nil
 }
 
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 	headers := make(http.Header)
@@ -1020,9 +1028,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
 
 var errUnauthorized = fmt.Errorf("unauthorized")
 
-func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
+func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
 	for i := 0; i < 2; i++ {
-		resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
+		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
 			if !errors.Is(err, context.Canceled) {
 				slog.Info(fmt.Sprintf("request failed: %v", err))
@@ -1034,9 +1042,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 		switch {
 		case resp.StatusCode == http.StatusUnauthorized:
 			// Handle authentication error with one retry
-			authenticate := resp.Header.Get("www-authenticate")
-			authRedir := ParseAuthRedirectString(authenticate)
-			token, err := auth.GetAuthToken(ctx, authRedir)
+			challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
+			token, err := getAuthorizationToken(ctx, challenge)
 			if err != nil {
 				return nil, err
 			}
@@ -1063,6 +1070,58 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 	return nil, errUnauthorized
 }
 
+func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
+	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
+		requestURL.Scheme = "http"
+	}
+
+	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
+	if err != nil {
+		return nil, err
+	}
+
+	if headers != nil {
+		req.Header = headers
+	}
+
+	if regOpts != nil {
+		if regOpts.Token != "" {
+			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
+		} else if regOpts.Username != "" && regOpts.Password != "" {
+			req.SetBasicAuth(regOpts.Username, regOpts.Password)
+		}
+	}
+
+	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
+
+	if s := req.Header.Get("Content-Length"); s != "" {
+		contentLength, err := strconv.ParseInt(s, 10, 64)
+		if err != nil {
+			return nil, err
+		}
+
+		req.ContentLength = contentLength
+	}
+
+	proxyURL, err := http.ProxyFromEnvironment(req)
+	if err != nil {
+		return nil, err
+	}
+
+	client := http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyURL),
+		},
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp, nil
+}
+
 func getValue(header, key string) string {
 	startIdx := strings.Index(header, key+"=")
 	if startIdx == -1 {
@@ -1086,10 +1145,10 @@ func getValue(header, key string) string {
 	return header[startIdx:endIdx]
 }
 
-func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
+func parseRegistryChallenge(authStr string) registryChallenge {
 	authStr = strings.TrimPrefix(authStr, "Bearer ")
 
-	return auth.AuthRedirect{
+	return registryChallenge{
 		Realm:   getValue(authStr, "realm"),
 		Service: getValue(authStr, "service"),
 		Scope:   getValue(authStr, "scope"),

+ 2 - 3
server/routes.go

@@ -25,7 +25,6 @@ import (
 	"golang.org/x/exp/slices"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/gpu"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/openai"
@@ -480,7 +479,7 @@ func PullModelHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		regOpts := &auth.RegistryOptions{
+		regOpts := &registryOptions{
 			Insecure: req.Insecure,
 		}
 
@@ -529,7 +528,7 @@ func PushModelHandler(c *gin.Context) {
 			ch <- r
 		}
 
-		regOpts := &auth.RegistryOptions{
+		regOpts := &registryOptions{
 			Insecure: req.Insecure,
 		}
 

+ 7 - 9
server/upload.go

@@ -18,7 +18,6 @@ import (
 	"time"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/format"
 	"golang.org/x/sync/errgroup"
 )
@@ -50,7 +49,7 @@ const (
 	maxUploadPartSize int64 = 1000 * format.MegaByte
 )
 
-func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
+func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
 	p, err := GetBlobsPath(b.Digest)
 	if err != nil {
 		return err
@@ -122,7 +121,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *aut
 
 // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
 // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
-func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
+func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
 	defer blobUploadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
@@ -213,7 +212,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
 	b.done = true
 }
 
-func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
+func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *registryOptions) error {
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -228,7 +227,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 	md5sum := md5.New()
 	w := &progressWriter{blobUpload: b}
 
-	resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
+	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
 	if err != nil {
 		w.Rollback()
 		return err
@@ -278,9 +277,8 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 
 	case resp.StatusCode == http.StatusUnauthorized:
 		w.Rollback()
-		authenticate := resp.Header.Get("www-authenticate")
-		authRedir := ParseAuthRedirectString(authenticate)
-		token, err := auth.GetAuthToken(ctx, authRedir)
+		challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
+		token, err := getAuthorizationToken(ctx, challenge)
 		if err != nil {
 			return err
 		}
@@ -365,7 +363,7 @@ func (p *progressWriter) Rollback() {
 	p.written = 0
 }
 
-func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
 	requestURL := mp.BaseURL()
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)