Browse Source

fix `FROM` instruction erroring when referring to a file

Jeffrey Morgan 1 year ago
parent
commit
a9f6c56652
5 changed files with 47 additions and 115 deletions
  1. 6 2
      cmd/cmd.go
  2. 16 42
      server/images.go
  3. 15 23
      server/modelpath.go
  4. 9 43
      server/modelpath_test.go
  5. 1 5
      server/routes.go

+ 6 - 2
cmd/cmd.go

@@ -102,11 +102,15 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 		return err
 	}
 
-	mp, err := server.ParseModelPath(args[0], insecure)
+	mp := server.ParseModelPath(args[0])
 	if err != nil {
 		return err
 	}
 
+	if mp.ProtocolScheme == "http" && !insecure {
+		return fmt.Errorf("insecure protocol http")
+	}
+
 	fp, err := mp.GetManifestPath(false)
 	if err != nil {
 		return err
@@ -515,7 +519,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		case strings.HasPrefix(line, "/show"):
 			args := strings.Fields(line)
 			if len(args) > 1 {
-				mp, err := server.ParseModelPath(model, false)
+				mp := server.ParseModelPath(model)
 				if err != nil {
 					return err
 				}

+ 16 - 42
server/images.go

@@ -153,11 +153,7 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
 }
 
 func GetModel(name string) (*Model, error) {
-	mp, err := ParseModelPath(name, false)
-	if err != nil {
-		return nil, err
-	}
-
+	mp := ParseModelPath(name)
 	manifest, err := GetManifest(mp)
 	if err != nil {
 		return nil, err
@@ -276,11 +272,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			fn(api.ProgressResponse{Status: "looking for model"})
 			embed.model = c.Args
 
-			mp, err := ParseModelPath(c.Args, false)
-			if err != nil {
-				return err
-			}
-
+			mp := ParseModelPath(c.Args)
 			mf, err := GetManifest(mp)
 			if err != nil {
 				modelFile, err := filenameWithPath(path, c.Args)
@@ -682,11 +674,7 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
 }
 
 func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
-	mp, err := ParseModelPath(name, false)
-	if err != nil {
-		return err
-	}
-
+	mp := ParseModelPath(name)
 	manifest := ManifestV2{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
@@ -817,21 +805,13 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
 }
 
 func CopyModel(src, dest string) error {
-	srcModelPath, err := ParseModelPath(src, false)
-	if err != nil {
-		return err
-	}
-
+	srcModelPath := ParseModelPath(src)
 	srcPath, err := srcModelPath.GetManifestPath(false)
 	if err != nil {
 		return err
 	}
 
-	destModelPath, err := ParseModelPath(dest, false)
-	if err != nil {
-		return err
-	}
-
+	destModelPath := ParseModelPath(dest)
 	destPath, err := destModelPath.GetManifestPath(true)
 	if err != nil {
 		return err
@@ -854,11 +834,7 @@ func CopyModel(src, dest string) error {
 }
 
 func DeleteModel(name string) error {
-	mp, err := ParseModelPath(name, false)
-	if err != nil {
-		return err
-	}
-
+	mp := ParseModelPath(name)
 	manifest, err := GetManifest(mp)
 	if err != nil {
 		return err
@@ -884,10 +860,7 @@ func DeleteModel(name string) error {
 				return nil
 			}
 			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
-			fmp, err := ParseModelPath(tag, false)
-			if err != nil {
-				return err
-			}
+			fmp := ParseModelPath(tag)
 
 			// skip the manifest we're trying to delete
 			if mp.GetFullTagname() == fmp.GetFullTagname() {
@@ -940,13 +913,13 @@ func DeleteModel(name string) error {
 }
 
 func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	mp, err := ParseModelPath(name, regOpts.Insecure)
-	if err != nil {
-		return err
-	}
-
+	mp := ParseModelPath(name)
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 
+	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
+		return fmt.Errorf("insecure protocol http")
+	}
+
 	manifest, err := GetManifest(mp)
 	if err != nil {
 		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
@@ -1026,9 +999,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 }
 
 func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	mp, err := ParseModelPath(name, regOpts.Insecure)
-	if err != nil {
-		return err
+	mp := ParseModelPath(name)
+
+	if mp.ProtocolScheme == "http" && !regOpts.Insecure {
+		return fmt.Errorf("insecure protocol http")
 	}
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})

+ 15 - 23
server/modelpath.go

@@ -30,7 +30,7 @@ var (
 	ErrInsecureProtocol   = errors.New("insecure protocol http")
 )
 
-func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) {
+func ParseModelPath(name string) ModelPath {
 	mp := ModelPath{
 		ProtocolScheme: DefaultProtocolScheme,
 		Registry:       DefaultRegistry,
@@ -39,39 +39,31 @@ func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) {
 		Tag:            DefaultTag,
 	}
 
-	protocol, rest, didSplit := strings.Cut(name, "://")
-	if didSplit {
-		if protocol == "https" || protocol == "http" && allowInsecure {
-			mp.ProtocolScheme = protocol
-			name = rest
-		} else if protocol == "http" && !allowInsecure {
-			return ModelPath{}, ErrInsecureProtocol
-		} else {
-			return ModelPath{}, ErrInvalidProtocol
-		}
+	parts := strings.Split(name, "://")
+	if len(parts) > 1 {
+		mp.ProtocolScheme = parts[0]
+		name = parts[1]
 	}
 
-	slashParts := strings.Split(name, "/")
-	switch len(slashParts) {
+	parts = strings.Split(name, "/")
+	switch len(parts) {
 	case 3:
-		mp.Registry = slashParts[0]
-		mp.Namespace = slashParts[1]
-		mp.Repository = slashParts[2]
+		mp.Registry = parts[0]
+		mp.Namespace = parts[1]
+		mp.Repository = parts[2]
 	case 2:
-		mp.Namespace = slashParts[0]
-		mp.Repository = slashParts[1]
+		mp.Namespace = parts[0]
+		mp.Repository = parts[1]
 	case 1:
-		mp.Repository = slashParts[0]
-	default:
-		return ModelPath{}, ErrInvalidImageFormat
+		mp.Repository = parts[0]
 	}
 
-	if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit {
+	if repo, tag, found := strings.Cut(mp.Repository, ":"); found {
 		mp.Repository = repo
 		mp.Tag = tag
 	}
 
-	return mp, nil
+	return mp
 }
 
 func (mp ModelPath) GetNamespaceRepository() string {

+ 9 - 43
server/modelpath_test.go

@@ -3,20 +3,14 @@ package server
 import "testing"
 
 func TestParseModelPath(t *testing.T) {
-	type input struct {
-		name          string
-		allowInsecure bool
-	}
-
 	tests := []struct {
 		name    string
-		args    input
+		arg    string
 		want    ModelPath
-		wantErr error
 	}{
 		{
 			"full path https",
-			input{"https://example.com/ns/repo:tag", false},
+			"https://example.com/ns/repo:tag",
 			ModelPath{
 				ProtocolScheme: "https",
 				Registry:       "example.com",
@@ -24,17 +18,10 @@ func TestParseModelPath(t *testing.T) {
 				Repository:     "repo",
 				Tag:            "tag",
 			},
-			nil,
 		},
 		{
-			"full path http without insecure",
-			input{"http://example.com/ns/repo:tag", false},
-			ModelPath{},
-			ErrInsecureProtocol,
-		},
-		{
-			"full path http with insecure",
-			input{"http://example.com/ns/repo:tag", true},
+			"full path http",
+			"http://example.com/ns/repo:tag",
 			ModelPath{
 				ProtocolScheme: "http",
 				Registry:       "example.com",
@@ -42,17 +29,10 @@ func TestParseModelPath(t *testing.T) {
 				Repository:     "repo",
 				Tag:            "tag",
 			},
-			nil,
-		},
-		{
-			"full path invalid protocol",
-			input{"file://example.com/ns/repo:tag", false},
-			ModelPath{},
-			ErrInvalidProtocol,
 		},
 		{
 			"no protocol",
-			input{"example.com/ns/repo:tag", false},
+			"example.com/ns/repo:tag",
 			ModelPath{
 				ProtocolScheme: "https",
 				Registry:       "example.com",
@@ -60,11 +40,10 @@ func TestParseModelPath(t *testing.T) {
 				Repository:     "repo",
 				Tag:            "tag",
 			},
-			nil,
 		},
 		{
 			"no registry",
-			input{"ns/repo:tag", false},
+			"ns/repo:tag",
 			ModelPath{
 				ProtocolScheme: "https",
 				Registry:       DefaultRegistry,
@@ -72,11 +51,10 @@ func TestParseModelPath(t *testing.T) {
 				Repository:     "repo",
 				Tag:            "tag",
 			},
-			nil,
 		},
 		{
 			"no namespace",
-			input{"repo:tag", false},
+			"repo:tag",
 			ModelPath{
 				ProtocolScheme: "https",
 				Registry:       DefaultRegistry,
@@ -84,11 +62,10 @@ func TestParseModelPath(t *testing.T) {
 				Repository:     "repo",
 				Tag:            "tag",
 			},
-			nil,
 		},
 		{
 			"no tag",
-			input{"repo", false},
+			"repo",
 			ModelPath{
 				ProtocolScheme: "https",
 				Registry:       DefaultRegistry,
@@ -96,23 +73,12 @@ func TestParseModelPath(t *testing.T) {
 				Repository:     "repo",
 				Tag:            DefaultTag,
 			},
-			nil,
-		},
-		{
-			"invalid image format",
-			input{"example.com/a/b/c", false},
-			ModelPath{},
-			ErrInvalidImageFormat,
 		},
 	}
 
 	for _, tc := range tests {
 		t.Run(tc.name, func(t *testing.T) {
-			got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure)
-
-			if err != tc.wantErr {
-				t.Errorf("got: %q want: %q", err, tc.wantErr)
-			}
+			got := ParseModelPath(tc.arg)
 
 			if got != tc.want {
 				t.Errorf("got: %q want: %q", got, tc.want)

+ 1 - 5
server/routes.go

@@ -358,11 +358,7 @@ func ListModelsHandler(c *gin.Context) {
 			}
 			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
 
-			mp, err := ParseModelPath(tag, false)
-			if err != nil {
-				return err
-			}
-
+			mp := ParseModelPath(tag)
 			manifest, err := GetManifest(mp)
 			if err != nil {
 				log.Printf("skipping file: %s", fp)