Browse Source

Move hub auth out to new package

Daniel Hiltgen 1 năm trước cách đây
mục cha
commit
f397e0e988
6 tập tin đã thay đổi với 129 bổ sung102 xóa
  1. 27 15
      auth/auth.go
  2. 72 0
      auth/request.go
  3. 6 5
      server/download.go
  4. 12 72
      server/images.go
  5. 3 2
      server/routes.go
  6. 9 8
      server/upload.go

+ 27 - 15
server/auth.go → auth/auth.go

@@ -1,4 +1,4 @@
-package server
+package auth
 
 
 import (
 import (
 	"bytes"
 	"bytes"
@@ -24,6 +24,10 @@ import (
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
 )
 )
 
 
+const (
+	KeyType = "id_ed25519"
+)
+
 type AuthRedirect struct {
 type AuthRedirect struct {
 	Realm   string
 	Realm   string
 	Service string
 	Service string
@@ -71,39 +75,47 @@ func (r AuthRedirect) URL() (*url.URL, error) {
 	return redirectURL, nil
 	return redirectURL, nil
 }
 }
 
 
-func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
-	redirectURL, err := redirData.URL()
-	if err != nil {
-		return "", err
-	}
-
+func SignRequest(method, url string, data []byte, headers http.Header) error {
 	home, err := os.UserHomeDir()
 	home, err := os.UserHomeDir()
 	if err != nil {
 	if err != nil {
-		return "", err
+		return err
 	}
 	}
 
 
-	keyPath := filepath.Join(home, ".ollama", "id_ed25519")
+	keyPath := filepath.Join(home, ".ollama", KeyType)
 
 
 	rawKey, err := os.ReadFile(keyPath)
 	rawKey, err := os.ReadFile(keyPath)
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
-		return "", err
+		return err
 	}
 	}
 
 
 	s := SignatureData{
 	s := SignatureData{
-		Method: http.MethodGet,
-		Path:   redirectURL.String(),
-		Data:   nil,
+		Method: method,
+		Path:   url,
+		Data:   data,
 	}
 	}
 
 
 	sig, err := s.Sign(rawKey)
 	sig, err := s.Sign(rawKey)
+	if err != nil {
+		return err
+	}
+
+	headers.Set("Authorization", sig)
+	return nil
+}
+
+func GetAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
+	redirectURL, err := redirData.URL()
 	if err != nil {
 	if err != nil {
 		return "", err
 		return "", err
 	}
 	}
 
 
 	headers := make(http.Header)
 	headers := make(http.Header)
-	headers.Set("Authorization", sig)
-	resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
+	err = SignRequest(http.MethodGet, redirectURL.String(), nil, headers)
+	if err != nil {
+		return "", err
+	}
+	resp, err := MakeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
 	if err != nil {
 	if err != nil {
 		slog.Info(fmt.Sprintf("couldn't get token: %q", err))
 		slog.Info(fmt.Sprintf("couldn't get token: %q", err))
 		return "", err
 		return "", err

+ 72 - 0
auth/request.go

@@ -0,0 +1,72 @@
+package auth
+
+import (
+	"context"
+	"fmt"
+	"io"
+	"net/http"
+	"net/url"
+	"runtime"
+	"strconv"
+
+	"github.com/jmorganca/ollama/version"
+)
+
+type RegistryOptions struct {
+	Insecure bool
+	Username string
+	Password string
+	Token    string
+}
+
+func MakeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
+	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
+		requestURL.Scheme = "http"
+	}
+
+	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
+	if err != nil {
+		return nil, err
+	}
+
+	if headers != nil {
+		req.Header = headers
+	}
+
+	if regOpts != nil {
+		if regOpts.Token != "" {
+			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
+		} else if regOpts.Username != "" && regOpts.Password != "" {
+			req.SetBasicAuth(regOpts.Username, regOpts.Password)
+		}
+	}
+
+	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
+
+	if s := req.Header.Get("Content-Length"); s != "" {
+		contentLength, err := strconv.ParseInt(s, 10, 64)
+		if err != nil {
+			return nil, err
+		}
+
+		req.ContentLength = contentLength
+	}
+
+	proxyURL, err := http.ProxyFromEnvironment(req)
+	if err != nil {
+		return nil, err
+	}
+
+	client := http.Client{
+		Transport: &http.Transport{
+			Proxy: http.ProxyURL(proxyURL),
+		},
+	}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+
+	return resp, nil
+}

+ 6 - 5
server/download.go

@@ -22,6 +22,7 @@ import (
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/format"
 )
 )
 
 
@@ -85,7 +86,7 @@ func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
 	return n, nil
 	return n, nil
 }
 }
 
 
-func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -137,11 +138,11 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
 	return nil
 	return nil
 }
 }
 
 
-func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) {
+func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) {
 	b.err = b.run(ctx, requestURL, opts)
 	b.err = b.run(ctx, requestURL, opts)
 }
 }
 
 
-func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
 	defer blobDownloadManager.Delete(b.Digest)
 	defer blobDownloadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
 
@@ -210,7 +211,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
 	return nil
 	return nil
 }
 }
 
 
-func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
+func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *auth.RegistryOptions) error {
 	g, ctx := errgroup.WithContext(ctx)
 	g, ctx := errgroup.WithContext(ctx)
 	g.Go(func() error {
 	g.Go(func() error {
 		headers := make(http.Header)
 		headers := make(http.Header)
@@ -334,7 +335,7 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
 type downloadOpts struct {
 type downloadOpts struct {
 	mp      ModelPath
 	mp      ModelPath
 	digest  string
 	digest  string
-	regOpts *RegistryOptions
+	regOpts *auth.RegistryOptions
 	fn      func(api.ProgressResponse)
 	fn      func(api.ProgressResponse)
 }
 }
 
 

+ 12 - 72
server/images.go

@@ -16,25 +16,17 @@ import (
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"runtime"
 	"runtime"
-	"strconv"
 	"strings"
 	"strings"
 	"text/template"
 	"text/template"
 
 
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/parser"
-	"github.com/jmorganca/ollama/version"
 )
 )
 
 
-type RegistryOptions struct {
-	Insecure bool
-	Username string
-	Password string
-	Token    string
-}
-
 type Model struct {
 type Model struct {
 	Name           string `json:"name"`
 	Name           string `json:"name"`
 	Config         ConfigV2
 	Config         ConfigV2
@@ -320,7 +312,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 				switch {
 				switch {
 				case errors.Is(err, os.ErrNotExist):
 				case errors.Is(err, os.ErrNotExist):
 					fn(api.ProgressResponse{Status: "pulling model"})
 					fn(api.ProgressResponse{Status: "pulling model"})
-					if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
+					if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
 						return err
 						return err
 					}
 					}
 
 
@@ -840,7 +832,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
 	return buf.String(), nil
 	return buf.String(), nil
 }
 }
 
 
-func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	mp := ParseModelPath(name)
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 
 
@@ -890,7 +882,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	return nil
 	return nil
 }
 }
 
 
-func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
 	mp := ParseModelPath(name)
 	mp := ParseModelPath(name)
 
 
 	var manifest *ManifestV2
 	var manifest *ManifestV2
@@ -996,7 +988,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 	return nil
 	return nil
 }
 }
 
 
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, error) {
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 	requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
 
 
 	headers := make(http.Header)
 	headers := make(http.Header)
@@ -1028,9 +1020,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
 
 
 var errUnauthorized = fmt.Errorf("unauthorized")
 var errUnauthorized = fmt.Errorf("unauthorized")
 
 
-func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
+func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
 	for i := 0; i < 2; i++ {
 	for i := 0; i < 2; i++ {
-		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
+		resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
 		if err != nil {
 			if !errors.Is(err, context.Canceled) {
 			if !errors.Is(err, context.Canceled) {
 				slog.Info(fmt.Sprintf("request failed: %v", err))
 				slog.Info(fmt.Sprintf("request failed: %v", err))
@@ -1042,9 +1034,9 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 		switch {
 		switch {
 		case resp.StatusCode == http.StatusUnauthorized:
 		case resp.StatusCode == http.StatusUnauthorized:
 			// Handle authentication error with one retry
 			// Handle authentication error with one retry
-			auth := resp.Header.Get("www-authenticate")
-			authRedir := ParseAuthRedirectString(auth)
-			token, err := getAuthToken(ctx, authRedir)
+			authenticate := resp.Header.Get("www-authenticate")
+			authRedir := ParseAuthRedirectString(authenticate)
+			token, err := auth.GetAuthToken(ctx, authRedir)
 			if err != nil {
 			if err != nil {
 				return nil, err
 				return nil, err
 			}
 			}
@@ -1071,58 +1063,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 	return nil, errUnauthorized
 	return nil, errUnauthorized
 }
 }
 
 
-func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {
-	if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
-		requestURL.Scheme = "http"
-	}
-
-	req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
-	if err != nil {
-		return nil, err
-	}
-
-	if headers != nil {
-		req.Header = headers
-	}
-
-	if regOpts != nil {
-		if regOpts.Token != "" {
-			req.Header.Set("Authorization", "Bearer "+regOpts.Token)
-		} else if regOpts.Username != "" && regOpts.Password != "" {
-			req.SetBasicAuth(regOpts.Username, regOpts.Password)
-		}
-	}
-
-	req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
-
-	if s := req.Header.Get("Content-Length"); s != "" {
-		contentLength, err := strconv.ParseInt(s, 10, 64)
-		if err != nil {
-			return nil, err
-		}
-
-		req.ContentLength = contentLength
-	}
-
-	proxyURL, err := http.ProxyFromEnvironment(req)
-	if err != nil {
-		return nil, err
-	}
-
-	client := http.Client{
-		Transport: &http.Transport{
-			Proxy: http.ProxyURL(proxyURL),
-		},
-	}
-
-	resp, err := client.Do(req)
-	if err != nil {
-		return nil, err
-	}
-
-	return resp, nil
-}
-
 func getValue(header, key string) string {
 func getValue(header, key string) string {
 	startIdx := strings.Index(header, key+"=")
 	startIdx := strings.Index(header, key+"=")
 	if startIdx == -1 {
 	if startIdx == -1 {
@@ -1146,10 +1086,10 @@ func getValue(header, key string) string {
 	return header[startIdx:endIdx]
 	return header[startIdx:endIdx]
 }
 }
 
 
-func ParseAuthRedirectString(authStr string) AuthRedirect {
+func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
 	authStr = strings.TrimPrefix(authStr, "Bearer ")
 	authStr = strings.TrimPrefix(authStr, "Bearer ")
 
 
-	return AuthRedirect{
+	return auth.AuthRedirect{
 		Realm:   getValue(authStr, "realm"),
 		Realm:   getValue(authStr, "realm"),
 		Service: getValue(authStr, "service"),
 		Service: getValue(authStr, "service"),
 		Scope:   getValue(authStr, "scope"),
 		Scope:   getValue(authStr, "scope"),

+ 3 - 2
server/routes.go

@@ -25,6 +25,7 @@ import (
 	"golang.org/x/exp/slices"
 	"golang.org/x/exp/slices"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/gpu"
 	"github.com/jmorganca/ollama/gpu"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/openai"
 	"github.com/jmorganca/ollama/openai"
@@ -479,7 +480,7 @@ func PullModelHandler(c *gin.Context) {
 			ch <- r
 			ch <- r
 		}
 		}
 
 
-		regOpts := &RegistryOptions{
+		regOpts := &auth.RegistryOptions{
 			Insecure: req.Insecure,
 			Insecure: req.Insecure,
 		}
 		}
 
 
@@ -528,7 +529,7 @@ func PushModelHandler(c *gin.Context) {
 			ch <- r
 			ch <- r
 		}
 		}
 
 
-		regOpts := &RegistryOptions{
+		regOpts := &auth.RegistryOptions{
 			Insecure: req.Insecure,
 			Insecure: req.Insecure,
 		}
 		}
 
 

+ 9 - 8
server/upload.go

@@ -18,6 +18,7 @@ import (
 	"time"
 	"time"
 
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/auth"
 	"github.com/jmorganca/ollama/format"
 	"github.com/jmorganca/ollama/format"
 	"golang.org/x/sync/errgroup"
 	"golang.org/x/sync/errgroup"
 )
 )
@@ -49,7 +50,7 @@ const (
 	maxUploadPartSize int64 = 1000 * format.MegaByte
 	maxUploadPartSize int64 = 1000 * format.MegaByte
 )
 )
 
 
-func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
+func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *auth.RegistryOptions) error {
 	p, err := GetBlobsPath(b.Digest)
 	p, err := GetBlobsPath(b.Digest)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
@@ -121,7 +122,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
 
 
 // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
 // Run uploads blob parts to the upstream. If the upstream supports redirection, parts will be uploaded
 // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
 // in parallel as defined by Prepare. Otherwise, parts will be uploaded serially. Run sets b.err on error.
-func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
+func (b *blobUpload) Run(ctx context.Context, opts *auth.RegistryOptions) {
 	defer blobUploadManager.Delete(b.Digest)
 	defer blobUploadManager.Delete(b.Digest)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 	ctx, b.CancelFunc = context.WithCancel(ctx)
 
 
@@ -212,7 +213,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
 	b.done = true
 	b.done = true
 }
 }
 
 
-func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *RegistryOptions) error {
+func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *url.URL, part *blobUploadPart, opts *auth.RegistryOptions) error {
 	headers := make(http.Header)
 	headers := make(http.Header)
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Type", "application/octet-stream")
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
 	headers.Set("Content-Length", fmt.Sprintf("%d", part.Size))
@@ -227,7 +228,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 	md5sum := md5.New()
 	md5sum := md5.New()
 	w := &progressWriter{blobUpload: b}
 	w := &progressWriter{blobUpload: b}
 
 
-	resp, err := makeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
+	resp, err := auth.MakeRequest(ctx, method, requestURL, headers, io.TeeReader(sr, io.MultiWriter(w, md5sum)), opts)
 	if err != nil {
 	if err != nil {
 		w.Rollback()
 		w.Rollback()
 		return err
 		return err
@@ -277,9 +278,9 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
 
 
 	case resp.StatusCode == http.StatusUnauthorized:
 	case resp.StatusCode == http.StatusUnauthorized:
 		w.Rollback()
 		w.Rollback()
-		auth := resp.Header.Get("www-authenticate")
-		authRedir := ParseAuthRedirectString(auth)
-		token, err := getAuthToken(ctx, authRedir)
+		authenticate := resp.Header.Get("www-authenticate")
+		authRedir := ParseAuthRedirectString(authenticate)
+		token, err := auth.GetAuthToken(ctx, authRedir)
 		if err != nil {
 		if err != nil {
 			return err
 			return err
 		}
 		}
@@ -364,7 +365,7 @@ func (p *progressWriter) Rollback() {
 	p.written = 0
 	p.written = 0
 }
 }
 
 
-func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryOptions, fn func(api.ProgressResponse)) error {
+func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
 	requestURL := mp.BaseURL()
 	requestURL := mp.BaseURL()
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
 	requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)