浏览代码

hide initialize keypair

Josh Yan 9 月之前
父节点
当前提交
6b1b85ba3d
共有 1 个文件被更改,包括 18 次插入26 次删除
  1. 18 26
      auth/auth.go

+ 18 - 26
auth/auth.go

@@ -26,7 +26,14 @@ func privateKey() (ssh.Signer, error) {
 
 	keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
 	privateKeyFile, err := os.ReadFile(keyPath)
-	if err != nil {
+	if os.IsNotExist(err) {
+		err := initializeKeypair()
+		if err != nil {
+			return nil, err
+		}
+
+		return privateKey()
+	} else if err != nil {
 		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
 		return nil, err
 	}
@@ -36,42 +43,27 @@ func privateKey() (ssh.Signer, error) {
 
 func GetPublicKey() (ssh.PublicKey, error) {
 	// try to read pubkey first
-	pubkey, err := readPubkey()
-	if err == nil {
-		return pubkey, nil
-	}
-
-	privateKey, err := privateKey()
-	if err == nil {
-		return privateKey.PublicKey(), nil
-	}
-
-	err = initializeKeypair()
-	if err != nil {
-		return nil, err
-	}
-
-	return readPubkey()
-}
-
-func readPubkey() (ssh.PublicKey, error) {
 	home, err := os.UserHomeDir()
 	if err != nil {
 		return nil, err
 	}
 
 	pubkeyPath := filepath.Join(home, ".ollama", defaultPrivateKey+".pub")
-	_, err = os.Stat(pubkeyPath)
+	pubKeyFile, err := os.ReadFile(pubkeyPath)
 	if os.IsNotExist(err) {
-		return nil, fmt.Errorf("public key not found")
-	}
+		// try from privateKey
+		privateKey, err := privateKey()
+		if err != nil {
+			return nil, fmt.Errorf("failed to read public key: %w", err)
+		}
 
-	pubKeyFile, err := os.ReadFile(pubkeyPath)
-	if err != nil {
+		return privateKey.PublicKey(), nil
+	} else if err != nil {
 		return nil, fmt.Errorf("failed to read public key: %w", err)
 	}
 
-	return ssh.ParsePublicKey(pubKeyFile)
+	pubKey, _, _, _, err := ssh.ParseAuthorizedKey(pubKeyFile)
+	return pubKey, err
 }
 
 func NewNonce(r io.Reader, length int) (string, error) {