Browse Source

Merge pull request #443 from jmorganca/mxyng/fix-list-models

windows: fix filepath bugs
Michael Yang 1 year ago
parent
commit
9304f0e7a8
4 changed files with 49 additions and 73 deletions
  1. 2 2
      api/types.go
  2. 32 43
      server/images.go
  3. 1 1
      server/modelpath.go
  4. 14 27
      server/routes.go

+ 2 - 2
api/types.go

@@ -88,10 +88,10 @@ type PushRequest struct {
 }
 
 type ListResponse struct {
-	Models []ListResponseModel `json:"models"`
+	Models []ModelResponse `json:"models"`
 }
 
-type ListResponseModel struct {
+type ModelResponse struct {
 	Name       string    `json:"name"`
 	ModifiedAt time.Time `json:"modified_at"`
 	Size       int       `json:"size"`

+ 32 - 43
server/images.go

@@ -235,8 +235,8 @@ func GetModel(name string) (*Model, error) {
 
 func filenameWithPath(path, f string) (string, error) {
 	// if filePath starts with ~/, replace it with the user's home directory.
-	if strings.HasPrefix(f, "~/") {
-		parts := strings.Split(f, "/")
+	if strings.HasPrefix(f, fmt.Sprintf("~%s", string(os.PathSeparator))) {
+		parts := strings.Split(f, string(os.PathSeparator))
 		home, err := os.UserHomeDir()
 		if err != nil {
 			return "", fmt.Errorf("failed to open file: %v", err)
@@ -374,20 +374,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		case "adapter":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
 
-			fp := c.Args
-			if strings.HasPrefix(fp, "~/") {
-				parts := strings.Split(fp, "/")
-				home, err := os.UserHomeDir()
-				if err != nil {
-					return fmt.Errorf("failed to open file: %v", err)
-				}
-
-				fp = filepath.Join(home, filepath.Join(parts[1:]...))
-			}
-
-			// If filePath is not an absolute path, make it relative to the modelfile path
-			if !filepath.IsAbs(fp) {
-				fp = filepath.Join(filepath.Dir(path), fp)
+			fp, err := filenameWithPath(path, c.Args)
+			if err != nil {
+				return err
 			}
 
 			// create a model from this specified file
@@ -859,38 +848,38 @@ func DeleteModel(name string) error {
 	if err != nil {
 		return err
 	}
-	err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
-		if err != nil {
-			return err
+
+	walkFunc := func(path string, info os.FileInfo, _ error) error {
+		if info.IsDir() {
+			return nil
 		}
-		if !info.IsDir() {
-			path := path[len(fp)+1:]
-			slashIndex := strings.LastIndex(path, "/")
-			if slashIndex == -1 {
-				return nil
-			}
-			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
-			fmp := ParseModelPath(tag)
 
-			// skip the manifest we're trying to delete
-			if mp.GetFullTagname() == fmp.GetFullTagname() {
-				return nil
-			}
+		dir, file := filepath.Split(path)
+		dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
+		tag := strings.Join([]string{dir, file}, ":")
+		fmp := ParseModelPath(tag)
 
-			// save (i.e. delete from the deleteMap) any files used in other manifests
-			manifest, _, err := GetManifest(fmp)
-			if err != nil {
-				log.Printf("skipping file: %s", fp)
-				return nil
-			}
-			for _, layer := range manifest.Layers {
-				delete(deleteMap, layer.Digest)
-			}
-			delete(deleteMap, manifest.Config.Digest)
+		// skip the manifest we're trying to delete
+		if mp.GetFullTagname() == fmp.GetFullTagname() {
+			return nil
+		}
+
+		// save (i.e. delete from the deleteMap) any files used in other manifests
+		manifest, _, err := GetManifest(fmp)
+		if err != nil {
+			log.Printf("skipping file: %s", fp)
+			return nil
 		}
+
+		for _, layer := range manifest.Layers {
+			delete(deleteMap, layer.Digest)
+		}
+
+		delete(deleteMap, manifest.Config.Digest)
 		return nil
-	})
-	if err != nil {
+	}
+
+	if err := filepath.Walk(fp, walkFunc); err != nil {
 		return err
 	}
 

+ 1 - 1
server/modelpath.go

@@ -46,7 +46,7 @@ func ParseModelPath(name string) ModelPath {
 		name = after
 	}
 
-	parts := strings.Split(name, "/")
+	parts := strings.Split(name, string(os.PathSeparator))
 	switch len(parts) {
 	case 3:
 		mp.Registry = parts[0]

+ 14 - 27
server/routes.go

@@ -3,7 +3,6 @@ package server
 import (
 	"context"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"io"
 	"log"
@@ -365,32 +364,18 @@ func DeleteModelHandler(c *gin.Context) {
 }
 
 func ListModelsHandler(c *gin.Context) {
-	var models []api.ListResponseModel
+	var models []api.ModelResponse
 	fp, err := GetManifestPath()
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}
-	err = filepath.Walk(fp, func(path string, info os.FileInfo, err error) error {
-		if err != nil {
-			if errors.Is(err, os.ErrNotExist) {
-				log.Printf("manifest file does not exist: %s", fp)
-				return nil
-			}
-			return err
-		}
+
+	walkFunc := func(path string, info os.FileInfo, _ error) error {
 		if !info.IsDir() {
-			fi, err := os.Stat(path)
-			if err != nil {
-				log.Printf("skipping file: %s", fp)
-				return nil
-			}
-			path := path[len(fp)+1:]
-			slashIndex := strings.LastIndex(path, "/")
-			if slashIndex == -1 {
-				return nil
-			}
-			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
+			dir, file := filepath.Split(path)
+			dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
+			tag := strings.Join([]string{dir, file}, ":")
 
 			mp := ParseModelPath(tag)
 			manifest, digest, err := GetManifest(mp)
@@ -398,17 +383,19 @@ func ListModelsHandler(c *gin.Context) {
 				log.Printf("skipping file: %s", fp)
 				return nil
 			}
-			model := api.ListResponseModel{
+
+			models = append(models, api.ModelResponse{
 				Name:       mp.GetShortTagname(),
 				Size:       manifest.GetTotalSize(),
 				Digest:     digest,
-				ModifiedAt: fi.ModTime(),
-			}
-			models = append(models, model)
+				ModifiedAt: info.ModTime(),
+			})
 		}
+
 		return nil
-	})
-	if err != nil {
+	}
+
+	if err := filepath.Walk(fp, walkFunc); err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}