|
@@ -5,6 +5,7 @@ import (
|
|
|
"bytes"
|
|
|
"context"
|
|
|
"crypto/sha256"
|
|
|
+ "encoding/base64"
|
|
|
"encoding/hex"
|
|
|
"encoding/json"
|
|
|
"errors"
|
|
@@ -25,10 +26,12 @@ import (
|
|
|
"golang.org/x/exp/slices"
|
|
|
|
|
|
"github.com/ollama/ollama/api"
|
|
|
+ "github.com/ollama/ollama/auth"
|
|
|
"github.com/ollama/ollama/convert"
|
|
|
"github.com/ollama/ollama/format"
|
|
|
"github.com/ollama/ollama/llm"
|
|
|
"github.com/ollama/ollama/parser"
|
|
|
+ "github.com/ollama/ollama/types/errtypes"
|
|
|
"github.com/ollama/ollama/types/model"
|
|
|
"github.com/ollama/ollama/version"
|
|
|
)
|
|
@@ -980,9 +983,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
|
|
for _, layer := range layers {
|
|
|
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
|
|
|
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
|
|
|
- if errors.Is(err, errUnauthorized) {
|
|
|
- return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
|
|
|
- }
|
|
|
return err
|
|
|
}
|
|
|
}
|
|
@@ -1145,9 +1145,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
|
|
|
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
|
|
|
}
|
|
|
|
|
|
-var errUnauthorized = errors.New("unauthorized")
|
|
|
+var errUnauthorized = fmt.Errorf("unauthorized: access denied")
|
|
|
+
|
|
|
+// getTokenSubject returns the subject of a JWT token, it does not validate the token
|
|
|
+func getTokenSubject(token string) string {
|
|
|
+ parts := strings.Split(token, ".")
|
|
|
+ if len(parts) != 3 {
|
|
|
+ slog.Error("jwt token does not contain 3 parts")
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ payload := parts[1]
|
|
|
+ payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
|
|
|
+ if err != nil {
|
|
|
+ slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ var payloadMap map[string]interface{}
|
|
|
+ if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
|
|
|
+ slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ sub, ok := payloadMap["sub"]
|
|
|
+ if !ok {
|
|
|
+ slog.Error("jwt does not contain 'sub' field")
|
|
|
+ return ""
|
|
|
+ }
|
|
|
+
|
|
|
+ return fmt.Sprintf("%s", sub)
|
|
|
+}
|
|
|
|
|
|
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
|
|
+ anonymous := true // access will default to anonymous if no user is found associated with the public key
|
|
|
for i := 0; i < 2; i++ {
|
|
|
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
|
|
if err != nil {
|
|
@@ -1166,6 +1197,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|
|
if err != nil {
|
|
|
return nil, err
|
|
|
}
|
|
|
+ anonymous = getTokenSubject(token) == "anonymous"
|
|
|
regOpts.Token = token
|
|
|
if body != nil {
|
|
|
_, err = body.Seek(0, io.SeekStart)
|
|
@@ -1186,6 +1218,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ if anonymous {
|
|
|
+ // no user is associated with the public key, and the request requires non-anonymous access
|
|
|
+ pubKey, nestedErr := auth.GetPublicKey()
|
|
|
+ if nestedErr != nil {
|
|
|
+ slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
|
|
|
+ return nil, errUnauthorized
|
|
|
+ }
|
|
|
+ return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
|
|
|
+ }
|
|
|
+ // user is associated with the public key, but is not authorized to make the request
|
|
|
return nil, errUnauthorized
|
|
|
}
|
|
|
|