Browse Source

local copy

Josh Yan 10 months ago
parent
commit
1a85cb904c
5 changed files with 112 additions and 14 deletions
  1. 7 2
      cmd/cmd.go
  2. 23 0
      cmd/copy_darwin.go
  3. 5 0
      cmd/copy_linux.go
  4. 53 0
      cmd/copy_windows.go
  5. 24 12
      server/routes.go

+ 7 - 2
cmd/cmd.go

@@ -319,7 +319,12 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
 		}
 
 		if err == nil {
-			err = createBlobLocal(path, dest)
+			err = localCopy(path, dest)
+			if err == nil {
+				return digest, nil
+			}
+
+			err = defaultCopy(path, dest)
 			if err == nil {
 				return digest, nil
 			}
@@ -377,7 +382,7 @@ func getLocalPath(ctx context.Context, digest string) (string, error) {
 	return "", ErrBlobExists
 }
 
-func createBlobLocal(path string, dest string) error {
+func defaultCopy(path string, dest string) error {
 	// This function should be called if the server is local
 	// It should find the model directory, copy the blob over, and return the digest
 	dirPath := filepath.Dir(dest)

+ 23 - 0
cmd/copy_darwin.go

@@ -0,0 +1,23 @@
+package cmd
+
+import (
+	"os"
+	"path/filepath"
+
+	"golang.org/x/sys/unix"
+)
+
+func localCopy(src, target string) error {
+	dirPath := filepath.Dir(target)
+
+	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		return err
+	}
+
+	err := unix.Clonefile(src, target, 0)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}

+ 5 - 0
cmd/copy_linux.go

@@ -0,0 +1,5 @@
+package cmd
+
+func localCopy(src, target string) error {
+	return defaultCopy(src, target)
+}

+ 53 - 0
cmd/copy_windows.go

@@ -0,0 +1,53 @@
+package cmd
+
+import (
+	"os"
+	"path/filepath"
+	"syscall"
+)
+
+func localCopy(src, target string) error {
+	dirPath := filepath.Dir(target)
+
+	if err := os.MkdirAll(dirPath, 0o755); err != nil {
+		return err
+	}
+
+	sourceFile, err := os.Open(src)
+	if err != nil {
+		return err
+	}
+	defer sourceFile.Close()
+
+	targetFile, err := os.Create(target)
+	if err != nil {
+		return err
+	}
+	defer targetFile.Close()
+
+	sourceHandle := syscall.Handle(sourceFile.Fd())
+	targetHandle := syscall.Handle(targetFile.Fd())
+
+	err = copyFileEx(sourceHandle, targetHandle)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func copyFileEx(srcHandle, dstHandle syscall.Handle) error {
+	kernel32 := syscall.NewLazyDLL("kernel32.dll")
+	copyFileEx := kernel32.NewProc("CopyFileExW")
+
+	r1, _, err := copyFileEx.Call(
+		uintptr(srcHandle),
+		uintptr(dstHandle),
+		0, 0, 0, 0)
+
+	if r1 == 0 {
+		return err
+	}
+
+	return nil
+}

+ 24 - 12
server/routes.go

@@ -769,12 +769,13 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		}
 	}
 
+	fmt.Println("path2", c.Param("digest"))
 	path, err := GetBlobsPath(c.Param("digest"))
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
-
+	fmt.Println("path1", path)
 	_, err = os.Stat(path)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
@@ -786,8 +787,10 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 		c.Status(http.StatusOK)
 		return
 	}
-
+	fmt.Println("hello")
+	fmt.Println(s.IsLocal(c))
 	if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
+		fmt.Println("entered redirect")
 		c.Header("LocalLocation", path)
 		c.Status(http.StatusTemporaryRedirect)
 		return
@@ -808,25 +811,32 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
 }
 
 func (s *Server) IsLocal(c *gin.Context) bool {
+	fmt.Println("entered islocal")
+	fmt.Println(c.GetHeader("Authorization"), " is authorization")
 	if authz := c.GetHeader("Authorization"); authz != "" {
+
 		parts := strings.Split(authz, ":")
 		if len(parts) != 3 {
+			fmt.Println("failed at lenParts")
 			return false
 		}
 
 		clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
 		if err != nil {
+			fmt.Println("failed at parseAuthorizedKey")
 			return false
 		}
 
 		// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
-		partialRequestData, err := base64.StdEncoding.DecodeString(parts[1])
+		requestData, err := base64.StdEncoding.DecodeString(parts[1])
 		if err != nil {
+			fmt.Println("failed at decodeString")
 			return false
 		}
 
-		partialRequestDataParts := strings.Split(string(partialRequestData), ",")
-		if len(partialRequestDataParts) != 4 {
+		partialRequestDataParts := strings.Split(string(requestData), ",")
+		if len(partialRequestDataParts) != 3 {
+			fmt.Println("failed at lenPartialRequestDataParts")
 			return false
 		}
 
@@ -849,22 +859,24 @@ func (s *Server) IsLocal(c *gin.Context) bool {
 
 		signature, err := base64.StdEncoding.DecodeString(parts[2])
 		if err != nil {
+			fmt.Println("failed at decodeString stdEncoding")
+			return false
+		}
+
+		if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
+			fmt.Println("failed at verify")
+			fmt.Println(err)
 			return false
 		}
 
 		serverPublicKey, err := auth.GetPublicKey()
 		if err != nil {
+			fmt.Println("failed at getPublicKey")
 			log.Fatal(err)
 		}
 
-		_, key, _ := bytes.Cut(bytes.TrimSpace(ssh.MarshalAuthorizedKey(serverPublicKey)), []byte(" "))
-		requestData := fmt.Sprintf("%s,%s", key, partialRequestData)
-
-		if err := clientPublicKey.Verify([]byte(requestData), &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
-			return false
-		}
-
 		if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
+			fmt.Println("true")
 			return true
 		}