Browse Source

validate model path

Michael Yang 8 months ago
parent
commit
d9d50c43cc
2 changed files with 13 additions and 13 deletions
  1. 5 13
      server/modelpath.go
  2. 8 0
      server/modelpath_test.go

+ 5 - 13
server/modelpath.go

@@ -73,18 +73,6 @@ func ParseModelPath(name string) ModelPath {
 
 
 var errModelPathInvalid = errors.New("invalid model path")
 var errModelPathInvalid = errors.New("invalid model path")
 
 
-func (mp ModelPath) Validate() error {
-	if mp.Repository == "" {
-		return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
-	}
-
-	if strings.Contains(mp.Tag, ":") {
-		return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
-	}
-
-	return nil
-}
-
 func (mp ModelPath) GetNamespaceRepository() string {
 func (mp ModelPath) GetNamespaceRepository() string {
 	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
 	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
 }
 }
@@ -105,7 +93,11 @@ func (mp ModelPath) GetShortTagname() string {
 
 
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 func (mp ModelPath) GetManifestPath() (string, error) {
 func (mp ModelPath) GetManifestPath() (string, error) {
-	return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
+	if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
+		return filepath.Join(envconfig.Models(), "manifests", p), nil
+	}
+
+	return "", errModelPathInvalid
 }
 }
 
 
 func (mp ModelPath) BaseURL() *url.URL {
 func (mp ModelPath) BaseURL() *url.URL {

+ 8 - 0
server/modelpath_test.go

@@ -1,6 +1,7 @@
 package server
 package server
 
 
 import (
 import (
+	"errors"
 	"os"
 	"os"
 	"path/filepath"
 	"path/filepath"
 	"testing"
 	"testing"
@@ -154,3 +155,10 @@ func TestParseModelPath(t *testing.T) {
 		})
 		})
 	}
 	}
 }
 }
+
+func TestInsecureModelpath(t *testing.T) {
+	mp := ParseModelPath("../../..:something")
+	if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
+		t.Errorf("expected error: %v", err)
+	}
+}