|
@@ -16,17 +16,25 @@ import (
|
|
|
"os"
|
|
|
"path/filepath"
|
|
|
"runtime"
|
|
|
+ "strconv"
|
|
|
"strings"
|
|
|
"text/template"
|
|
|
|
|
|
"golang.org/x/exp/slices"
|
|
|
|
|
|
"github.com/jmorganca/ollama/api"
|
|
|
- "github.com/jmorganca/ollama/auth"
|
|
|
"github.com/jmorganca/ollama/llm"
|
|
|
"github.com/jmorganca/ollama/parser"
|
|
|
+ "github.com/jmorganca/ollama/version"
|
|
|
)
|
|
|
|
|
|
+type registryOptions struct {
|
|
|
+ Insecure bool
|
|
|
+ Username string
|
|
|
+ Password string
|
|
|
+ Token string
|
|
|
+}
|
|
|
+
|
|
|
type Model struct {
|
|
|
Name string `json:"name"`
|
|
|
Config ConfigV2
|
|
@@ -312,7 +320,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
|
|
switch {
|
|
|
case errors.Is(err, os.ErrNotExist):
|
|
|
fn(api.ProgressResponse{Status: "pulling model"})
|
|
|
- if err := PullModel(ctx, c.Args, &auth.RegistryOptions{}, fn); err != nil {
|
|
|
+ if err := PullModel(ctx, c.Args, ®istryOptions{}, fn); err != nil {
|
|
|
return err
|
|
|
}
|
|
|
|
|
@@ -832,7 +840,7 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
|
|
|
return buf.String(), nil
|
|
|
}
|
|
|
|
|
|
-func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
|
+func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
|
|
mp := ParseModelPath(name)
|
|
|
fn(api.ProgressResponse{Status: "retrieving manifest"})
|
|
|
|
|
@@ -882,7 +890,7 @@ func PushModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions, fn func(api.ProgressResponse)) error {
|
|
|
+func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
|
|
|
mp := ParseModelPath(name)
|
|
|
|
|
|
var manifest *ManifestV2
|
|
@@ -988,7 +996,7 @@ func PullModel(ctx context.Context, name string, regOpts *auth.RegistryOptions,
|
|
|
return nil
|
|
|
}
|
|
|
|
|
|
-func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *auth.RegistryOptions) (*ManifestV2, error) {
|
|
|
+func pullModelManifest(ctx context.Context, mp ModelPath, regOpts *registryOptions) (*ManifestV2, error) {
|
|
|
requestURL := mp.BaseURL().JoinPath("v2", mp.GetNamespaceRepository(), "manifests", mp.Tag)
|
|
|
|
|
|
headers := make(http.Header)
|
|
@@ -1020,9 +1028,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
|
|
|
|
|
var errUnauthorized = fmt.Errorf("unauthorized")
|
|
|
|
|
|
-func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *auth.RegistryOptions) (*http.Response, error) {
|
|
|
+func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
|
|
for i := 0; i < 2; i++ {
|
|
|
- resp, err := auth.MakeRequest(ctx, method, requestURL, headers, body, regOpts)
|
|
|
+ resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
|
|
if err != nil {
|
|
|
if !errors.Is(err, context.Canceled) {
|
|
|
slog.Info(fmt.Sprintf("request failed: %v", err))
|
|
@@ -1034,9 +1042,8 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|
|
switch {
|
|
|
case resp.StatusCode == http.StatusUnauthorized:
|
|
|
// Handle authentication error with one retry
|
|
|
- authenticate := resp.Header.Get("www-authenticate")
|
|
|
- authRedir := ParseAuthRedirectString(authenticate)
|
|
|
- token, err := auth.GetAuthToken(ctx, authRedir)
|
|
|
+ challenge := parseRegistryChallenge(resp.Header.Get("www-authenticate"))
|
|
|
+ token, err := getAuthorizationToken(ctx, challenge)
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
@@ -1063,6 +1070,58 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|
|
return nil, errUnauthorized
|
|
|
}
|
|
|
|
|
|
+func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *registryOptions) (*http.Response, error) {
|
|
|
+ if requestURL.Scheme != "http" && regOpts != nil && regOpts.Insecure {
|
|
|
+ requestURL.Scheme = "http"
|
|
|
+ }
|
|
|
+
|
|
|
+ req, err := http.NewRequestWithContext(ctx, method, requestURL.String(), body)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ if headers != nil {
|
|
|
+ req.Header = headers
|
|
|
+ }
|
|
|
+
|
|
|
+ if regOpts != nil {
|
|
|
+ if regOpts.Token != "" {
|
|
|
+ req.Header.Set("Authorization", "Bearer "+regOpts.Token)
|
|
|
+ } else if regOpts.Username != "" && regOpts.Password != "" {
|
|
|
+ req.SetBasicAuth(regOpts.Username, regOpts.Password)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ req.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
|
|
|
+
|
|
|
+ if s := req.Header.Get("Content-Length"); s != "" {
|
|
|
+ contentLength, err := strconv.ParseInt(s, 10, 64)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ req.ContentLength = contentLength
|
|
|
+ }
|
|
|
+
|
|
|
+ proxyURL, err := http.ProxyFromEnvironment(req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ client := http.Client{
|
|
|
+ Transport: &http.Transport{
|
|
|
+ Proxy: http.ProxyURL(proxyURL),
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ resp, err := client.Do(req)
|
|
|
+ if err != nil {
|
|
|
+ return nil, err
|
|
|
+ }
|
|
|
+
|
|
|
+ return resp, nil
|
|
|
+}
|
|
|
+
|
|
|
func getValue(header, key string) string {
|
|
|
startIdx := strings.Index(header, key+"=")
|
|
|
if startIdx == -1 {
|
|
@@ -1086,10 +1145,10 @@ func getValue(header, key string) string {
|
|
|
return header[startIdx:endIdx]
|
|
|
}
|
|
|
|
|
|
-func ParseAuthRedirectString(authStr string) auth.AuthRedirect {
|
|
|
+func parseRegistryChallenge(authStr string) registryChallenge {
|
|
|
authStr = strings.TrimPrefix(authStr, "Bearer ")
|
|
|
|
|
|
- return auth.AuthRedirect{
|
|
|
+ return registryChallenge{
|
|
|
Realm: getValue(authStr, "realm"),
|
|
|
Service: getValue(authStr, "service"),
|
|
|
Scope: getValue(authStr, "scope"),
|