瀏覽代碼

Strip protocol from model path (#377)

Ryan Baker 1 年之前
父節點
當前提交
0a892419ad
共有 5 個文件被更改,包括 231 次插入43 次删除
  1. 20 6
      cmd/cmd.go
  2. 44 10
      server/images.go
  3. 39 26
      server/modelpath.go
  4. 122 0
      server/modelpath_test.go
  5. 6 1
      server/routes.go

+ 20 - 6
cmd/cmd.go

@@ -97,7 +97,16 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 }
 
 func RunHandler(cmd *cobra.Command, args []string) error {
-	mp := server.ParseModelPath(args[0])
+	insecure, err := cmd.Flags().GetBool("insecure")
+	if err != nil {
+		return err
+	}
+
+	mp, err := server.ParseModelPath(args[0], insecure)
+	if err != nil {
+		return err
+	}
+
 	fp, err := mp.GetManifestPath(false)
 	if err != nil {
 		return err
@@ -106,7 +115,7 @@ func RunHandler(cmd *cobra.Command, args []string) error {
 	_, err = os.Stat(fp)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
-		if err := pull(args[0], false); err != nil {
+		if err := pull(args[0], insecure); err != nil {
 			var apiStatusError api.StatusError
 			if !errors.As(err, &apiStatusError) {
 				return err
@@ -506,7 +515,11 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 		case strings.HasPrefix(line, "/show"):
 			args := strings.Fields(line)
 			if len(args) > 1 {
-				mp := server.ParseModelPath(model)
+				mp, err := server.ParseModelPath(model, false)
+				if err != nil {
+					return err
+				}
+
 				manifest, err := server.GetManifest(mp)
 				if err != nil {
 					fmt.Println("error: couldn't get a manifest for this model")
@@ -569,7 +582,7 @@ func generateBatch(cmd *cobra.Command, model string) error {
 }
 
 func RunServer(cmd *cobra.Command, _ []string) error {
-	var host, port = "127.0.0.1", "11434"
+	host, port := "127.0.0.1", "11434"
 
 	parts := strings.Split(os.Getenv("OLLAMA_HOST"), ":")
 	if ip := net.ParseIP(parts[0]); ip != nil {
@@ -630,7 +643,7 @@ func initializeKeypair() error {
 			return fmt.Errorf("could not create directory %w", err)
 		}
 
-		err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0600)
+		err = os.WriteFile(privKeyPath, pem.EncodeToMemory(privKeyBytes), 0o600)
 		if err != nil {
 			return err
 		}
@@ -642,7 +655,7 @@ func initializeKeypair() error {
 
 		pubKeyData := ssh.MarshalAuthorizedKey(sshPrivateKey.PublicKey())
 
-		err = os.WriteFile(pubKeyPath, pubKeyData, 0644)
+		err = os.WriteFile(pubKeyPath, pubKeyData, 0o644)
 		if err != nil {
 			return err
 		}
@@ -737,6 +750,7 @@ func NewCLI() *cobra.Command {
 	}
 
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
+	runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
 
 	serveCmd := &cobra.Command{
 		Use:     "serve",

+ 44 - 10
server/images.go

@@ -153,7 +153,10 @@ func GetManifest(mp ModelPath) (*ManifestV2, error) {
 }
 
 func GetModel(name string) (*Model, error) {
-	mp := ParseModelPath(name)
+	mp, err := ParseModelPath(name, false)
+	if err != nil {
+		return nil, err
+	}
 
 	manifest, err := GetManifest(mp)
 	if err != nil {
@@ -272,7 +275,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		case "model":
 			fn(api.ProgressResponse{Status: "looking for model"})
 			embed.model = c.Args
-			mp := ParseModelPath(c.Args)
+
+			mp, err := ParseModelPath(c.Args, false)
+			if err != nil {
+				return err
+			}
+
 			mf, err := GetManifest(mp)
 			if err != nil {
 				modelFile, err := filenameWithPath(path, c.Args)
@@ -286,7 +294,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 						if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
 							return err
 						}
-						mf, err = GetManifest(ParseModelPath(c.Args))
+						mf, err = GetManifest(mp)
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
@@ -674,7 +682,10 @@ func SaveLayers(layers []*LayerReader, fn func(resp api.ProgressResponse), force
 }
 
 func CreateManifest(name string, cfg *LayerReader, layers []*Layer) error {
-	mp := ParseModelPath(name)
+	mp, err := ParseModelPath(name, false)
+	if err != nil {
+		return err
+	}
 
 	manifest := ManifestV2{
 		SchemaVersion: 2,
@@ -806,11 +817,22 @@ func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
 }
 
 func CopyModel(src, dest string) error {
-	srcPath, err := ParseModelPath(src).GetManifestPath(false)
+	srcModelPath, err := ParseModelPath(src, false)
+	if err != nil {
+		return err
+	}
+
+	srcPath, err := srcModelPath.GetManifestPath(false)
+	if err != nil {
+		return err
+	}
+
+	destModelPath, err := ParseModelPath(dest, false)
 	if err != nil {
 		return err
 	}
-	destPath, err := ParseModelPath(dest).GetManifestPath(true)
+
+	destPath, err := destModelPath.GetManifestPath(true)
 	if err != nil {
 		return err
 	}
@@ -832,7 +854,10 @@ func CopyModel(src, dest string) error {
 }
 
 func DeleteModel(name string) error {
-	mp := ParseModelPath(name)
+	mp, err := ParseModelPath(name, false)
+	if err != nil {
+		return err
+	}
 
 	manifest, err := GetManifest(mp)
 	if err != nil {
@@ -859,7 +884,10 @@ func DeleteModel(name string) error {
 				return nil
 			}
 			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
-			fmp := ParseModelPath(tag)
+			fmp, err := ParseModelPath(tag, false)
+			if err != nil {
+				return err
+			}
 
 			// skip the manifest we're trying to delete
 			if mp.GetFullTagname() == fmp.GetFullTagname() {
@@ -912,7 +940,10 @@ func DeleteModel(name string) error {
 }
 
 func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
+	mp, err := ParseModelPath(name, regOpts.Insecure)
+	if err != nil {
+		return err
+	}
 
 	fn(api.ProgressResponse{Status: "retrieving manifest"})
 
@@ -995,7 +1026,10 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 }
 
 func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn func(api.ProgressResponse)) error {
-	mp := ParseModelPath(name)
+	mp, err := ParseModelPath(name, regOpts.Insecure)
+	if err != nil {
+		return err
+	}
 
 	fn(api.ProgressResponse{Status: "pulling manifest"})
 

+ 39 - 26
server/modelpath.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"errors"
 	"fmt"
 	"os"
 	"path/filepath"
@@ -23,42 +24,54 @@ const (
 	DefaultProtocolScheme = "https"
 )
 
-func ParseModelPath(name string) ModelPath {
-	slashParts := strings.Split(name, "/")
-	var registry, namespace, repository, tag string
+var (
+	ErrInvalidImageFormat = errors.New("invalid image format")
+	ErrInvalidProtocol    = errors.New("invalid protocol scheme")
+	ErrInsecureProtocol   = errors.New("insecure protocol http")
+)
+
+func ParseModelPath(name string, allowInsecure bool) (ModelPath, error) {
+	mp := ModelPath{
+		ProtocolScheme: DefaultProtocolScheme,
+		Registry:       DefaultRegistry,
+		Namespace:      DefaultNamespace,
+		Repository:     "",
+		Tag:            DefaultTag,
+	}
 
+	protocol, rest, didSplit := strings.Cut(name, "://")
+	if didSplit {
+		if protocol == "https" || protocol == "http" && allowInsecure {
+			mp.ProtocolScheme = protocol
+			name = rest
+		} else if protocol == "http" && !allowInsecure {
+			return ModelPath{}, ErrInsecureProtocol
+		} else {
+			return ModelPath{}, ErrInvalidProtocol
+		}
+	}
+
+	slashParts := strings.Split(name, "/")
 	switch len(slashParts) {
 	case 3:
-		registry = slashParts[0]
-		namespace = slashParts[1]
-		repository = strings.Split(slashParts[2], ":")[0]
+		mp.Registry = slashParts[0]
+		mp.Namespace = slashParts[1]
+		mp.Repository = slashParts[2]
 	case 2:
-		registry = DefaultRegistry
-		namespace = slashParts[0]
-		repository = strings.Split(slashParts[1], ":")[0]
+		mp.Namespace = slashParts[0]
+		mp.Repository = slashParts[1]
 	case 1:
-		registry = DefaultRegistry
-		namespace = DefaultNamespace
-		repository = strings.Split(slashParts[0], ":")[0]
+		mp.Repository = slashParts[0]
 	default:
-		fmt.Println("Invalid image format.")
-		return ModelPath{}
+		return ModelPath{}, ErrInvalidImageFormat
 	}
 
-	colonParts := strings.Split(slashParts[len(slashParts)-1], ":")
-	if len(colonParts) == 2 {
-		tag = colonParts[1]
-	} else {
-		tag = DefaultTag
+	if repo, tag, didSplit := strings.Cut(mp.Repository, ":"); didSplit {
+		mp.Repository = repo
+		mp.Tag = tag
 	}
 
-	return ModelPath{
-		ProtocolScheme: DefaultProtocolScheme,
-		Registry:       registry,
-		Namespace:      namespace,
-		Repository:     repository,
-		Tag:            tag,
-	}
+	return mp, nil
 }
 
 func (mp ModelPath) GetNamespaceRepository() string {

+ 122 - 0
server/modelpath_test.go

@@ -0,0 +1,122 @@
+package server
+
+import "testing"
+
+func TestParseModelPath(t *testing.T) {
+	type input struct {
+		name          string
+		allowInsecure bool
+	}
+
+	tests := []struct {
+		name    string
+		args    input
+		want    ModelPath
+		wantErr error
+	}{
+		{
+			"full path https",
+			input{"https://example.com/ns/repo:tag", false},
+			ModelPath{
+				ProtocolScheme: "https",
+				Registry:       "example.com",
+				Namespace:      "ns",
+				Repository:     "repo",
+				Tag:            "tag",
+			},
+			nil,
+		},
+		{
+			"full path http without insecure",
+			input{"http://example.com/ns/repo:tag", false},
+			ModelPath{},
+			ErrInsecureProtocol,
+		},
+		{
+			"full path http with insecure",
+			input{"http://example.com/ns/repo:tag", true},
+			ModelPath{
+				ProtocolScheme: "http",
+				Registry:       "example.com",
+				Namespace:      "ns",
+				Repository:     "repo",
+				Tag:            "tag",
+			},
+			nil,
+		},
+		{
+			"full path invalid protocol",
+			input{"file://example.com/ns/repo:tag", false},
+			ModelPath{},
+			ErrInvalidProtocol,
+		},
+		{
+			"no protocol",
+			input{"example.com/ns/repo:tag", false},
+			ModelPath{
+				ProtocolScheme: "https",
+				Registry:       "example.com",
+				Namespace:      "ns",
+				Repository:     "repo",
+				Tag:            "tag",
+			},
+			nil,
+		},
+		{
+			"no registry",
+			input{"ns/repo:tag", false},
+			ModelPath{
+				ProtocolScheme: "https",
+				Registry:       DefaultRegistry,
+				Namespace:      "ns",
+				Repository:     "repo",
+				Tag:            "tag",
+			},
+			nil,
+		},
+		{
+			"no namespace",
+			input{"repo:tag", false},
+			ModelPath{
+				ProtocolScheme: "https",
+				Registry:       DefaultRegistry,
+				Namespace:      DefaultNamespace,
+				Repository:     "repo",
+				Tag:            "tag",
+			},
+			nil,
+		},
+		{
+			"no tag",
+			input{"repo", false},
+			ModelPath{
+				ProtocolScheme: "https",
+				Registry:       DefaultRegistry,
+				Namespace:      DefaultNamespace,
+				Repository:     "repo",
+				Tag:            DefaultTag,
+			},
+			nil,
+		},
+		{
+			"invalid image format",
+			input{"example.com/a/b/c", false},
+			ModelPath{},
+			ErrInvalidImageFormat,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			got, err := ParseModelPath(tc.args.name, tc.args.allowInsecure)
+
+			if err != tc.wantErr {
+				t.Errorf("got: %q want: %q", err, tc.wantErr)
+			}
+
+			if got != tc.want {
+				t.Errorf("got: %q want: %q", got, tc.want)
+			}
+		})
+	}
+}

+ 6 - 1
server/routes.go

@@ -357,7 +357,12 @@ func ListModelsHandler(c *gin.Context) {
 				return nil
 			}
 			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
-			mp := ParseModelPath(tag)
+
+			mp, err := ParseModelPath(tag, false)
+			if err != nil {
+				return err
+			}
+
 			manifest, err := GetManifest(mp)
 			if err != nil {
 				log.Printf("skipping file: %s", fp)