Browse Source

prompt to display and add local ollama keys to account (#3717)

- return descriptive error messages when unauthorized to create blob or push a model
- display the local public key associated with the request that was denied
Bruce MacDonald 1 year ago
parent
commit
0a7fdbe533
4 changed files with 155 additions and 7 deletions
  1. 33 3
      auth/auth.go
  2. 58 0
      cmd/cmd.go
  3. 46 4
      server/images.go
  4. 18 0
      types/errtypes/errtypes.go

+ 33 - 3
auth/auth.go

@@ -10,12 +10,44 @@ import (
 	"log/slog"
 	"os"
 	"path/filepath"
+	"strings"
 
 	"golang.org/x/crypto/ssh"
 )
 
 const defaultPrivateKey = "id_ed25519"
 
+func keyPath() (string, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return "", err
+	}
+
+	return filepath.Join(home, ".ollama", defaultPrivateKey), nil
+}
+
+func GetPublicKey() (string, error) {
+	keyPath, err := keyPath()
+	if err != nil {
+		return "", err
+	}
+
+	privateKeyFile, err := os.ReadFile(keyPath)
+	if err != nil {
+		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
+		return "", err
+	}
+
+	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
+	if err != nil {
+		return "", err
+	}
+
+	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())
+
+	return strings.TrimSpace(string(publicKey)), nil
+}
+
 func NewNonce(r io.Reader, length int) (string, error) {
 	nonce := make([]byte, length)
 	if _, err := io.ReadFull(r, nonce); err != nil {
@@ -26,13 +58,11 @@ func NewNonce(r io.Reader, length int) (string, error) {
 }
 
 func Sign(ctx context.Context, bts []byte) (string, error) {
-	home, err := os.UserHomeDir()
+	keyPath, err := keyPath()
 	if err != nil {
 		return "", err
 	}
 
-	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
-
 	privateKeyFile, err := os.ReadFile(keyPath)
 	if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

+ 58 - 0
cmd/cmd.go

@@ -32,10 +32,13 @@ import (
 	"golang.org/x/term"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/parser"
 	"github.com/ollama/ollama/progress"
 	"github.com/ollama/ollama/server"
+	"github.com/ollama/ollama/types/errtypes"
+	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
 
@@ -357,6 +360,47 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	return generateInteractive(cmd, opts)
 }
 
+func errFromUnknownKey(unknownKeyErr error) error {
+	// find SSH public key in the error message
+	sshKeyPattern := `ssh-\w+ [^\s"]+`
+	re := regexp.MustCompile(sshKeyPattern)
+	matches := re.FindStringSubmatch(unknownKeyErr.Error())
+
+	if len(matches) > 0 {
+		serverPubKey := matches[0]
+
+		localPubKey, err := auth.GetPublicKey()
+		if err != nil {
+			return unknownKeyErr
+		}
+
+		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
+			// try the ollama service public key
+			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
+			if err != nil {
+				return unknownKeyErr
+			}
+			localPubKey = strings.TrimSpace(string(svcPubKey))
+		}
+
+		// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
+		if serverPubKey != localPubKey {
+			return unknownKeyErr
+		}
+
+		var msg strings.Builder
+		msg.WriteString(unknownKeyErr.Error())
+		msg.WriteString("\n\nYour ollama key is:\n")
+		msg.WriteString(localPubKey)
+		msg.WriteString("\nAdd your key at:\n")
+		msg.WriteString("https://ollama.com/settings/keys")
+
+		return errors.New(msg.String())
+	}
+
+	return unknownKeyErr
+}
+
 func PushHandler(cmd *cobra.Command, args []string) error {
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
@@ -404,6 +448,20 @@ func PushHandler(cmd *cobra.Command, args []string) error {
 
 	request := api.PushRequest{Name: args[0], Insecure: insecure}
 	if err := client.Push(cmd.Context(), &request, fn); err != nil {
+		if spinner != nil {
+			spinner.Stop()
+		}
+		if strings.Contains(err.Error(), "access denied") {
+			return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
+		}
+		host := model.ParseName(args[0]).Host
+		isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
+		if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
+			// the user has not added their ollama key to ollama.com
+			// re-throw an error with a more user-friendly message
+			return errFromUnknownKey(err)
+		}
+
 		return err
 	}
 

+ 46 - 4
server/images.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"context"
 	"crypto/sha256"
+	"encoding/base64"
 	"encoding/hex"
 	"encoding/json"
 	"errors"
@@ -25,10 +26,12 @@ import (
 	"golang.org/x/exp/slices"
 
 	"github.com/ollama/ollama/api"
+	"github.com/ollama/ollama/auth"
 	"github.com/ollama/ollama/convert"
 	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/parser"
+	"github.com/ollama/ollama/types/errtypes"
 	"github.com/ollama/ollama/types/model"
 	"github.com/ollama/ollama/version"
 )
@@ -980,9 +983,6 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	for _, layer := range layers {
 		if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
 			slog.Info(fmt.Sprintf("error uploading blob: %v", err))
-			if errors.Is(err, errUnauthorized) {
-				return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
-			}
 			return err
 		}
 	}
@@ -1145,9 +1145,40 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
 	return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
 }
 
-var errUnauthorized = errors.New("unauthorized")
+var errUnauthorized = fmt.Errorf("unauthorized: access denied")
+
+// getTokenSubject returns the subject of a JWT token, it does not validate the token
+func getTokenSubject(token string) string {
+	parts := strings.Split(token, ".")
+	if len(parts) != 3 {
+		slog.Error("jwt token does not contain 3 parts")
+		return ""
+	}
+
+	payload := parts[1]
+	payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
+	if err != nil {
+		slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
+		return ""
+	}
+
+	var payloadMap map[string]interface{}
+	if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
+		slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
+		return ""
+	}
+
+	sub, ok := payloadMap["sub"]
+	if !ok {
+		slog.Error("jwt does not contain 'sub' field")
+		return ""
+	}
+
+	return fmt.Sprintf("%s", sub)
+}
 
 func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
+	anonymous := true // access will default to anonymous if no user is found associated with the public key
 	for i := 0; i < 2; i++ {
 		resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
 		if err != nil {
@@ -1166,6 +1197,7 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 			if err != nil {
 				return nil, err
 			}
+			anonymous = getTokenSubject(token) == "anonymous"
 			regOpts.Token = token
 			if body != nil {
 				_, err = body.Seek(0, io.SeekStart)
@@ -1186,6 +1218,16 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
 		}
 	}
 
+	if anonymous {
+		// no user is associated with the public key, and the request requires non-anonymous access
+		pubKey, nestedErr := auth.GetPublicKey()
+		if nestedErr != nil {
+			slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
+			return nil, errUnauthorized
+		}
+		return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
+	}
+	// user is associated with the public key, but is not authorized to make the request
 	return nil, errUnauthorized
 }
 

+ 18 - 0
types/errtypes/errtypes.go

@@ -0,0 +1,18 @@
+// Package errtypes contains custom error types
+package errtypes
+
+import (
+	"fmt"
+	"strings"
+)
+
+const UnknownOllamaKeyErrMsg = "unknown ollama key"
+
+// TODO: This should have a structured response from the API
+type UnknownOllamaKey struct {
+	Key string
+}
+
+func (e *UnknownOllamaKey) Error() string {
+	return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
+}