浏览代码

validate model path

Michael Yang 8 月之前
父节点
当前提交
d9d50c43cc
共有 2 个文件被更改,包括 13 次插入13 次删除
  1. 5 13
      server/modelpath.go
  2. 8 0
      server/modelpath_test.go

+ 5 - 13
server/modelpath.go

@@ -73,18 +73,6 @@ func ParseModelPath(name string) ModelPath {
 
 var errModelPathInvalid = errors.New("invalid model path")
 
-func (mp ModelPath) Validate() error {
-	if mp.Repository == "" {
-		return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
-	}
-
-	if strings.Contains(mp.Tag, ":") {
-		return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
-	}
-
-	return nil
-}
-
 func (mp ModelPath) GetNamespaceRepository() string {
 	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
 }
@@ -105,7 +93,11 @@ func (mp ModelPath) GetShortTagname() string {
 
 // GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
 func (mp ModelPath) GetManifestPath() (string, error) {
-	return filepath.Join(envconfig.Models(), "manifests", mp.Registry, mp.Namespace, mp.Repository, mp.Tag), nil
+	if p := filepath.Join(mp.Registry, mp.Namespace, mp.Repository, mp.Tag); filepath.IsLocal(p) {
+		return filepath.Join(envconfig.Models(), "manifests", p), nil
+	}
+
+	return "", errModelPathInvalid
 }
 
 func (mp ModelPath) BaseURL() *url.URL {

+ 8 - 0
server/modelpath_test.go

@@ -1,6 +1,7 @@
 package server
 
 import (
+	"errors"
 	"os"
 	"path/filepath"
 	"testing"
@@ -154,3 +155,10 @@ func TestParseModelPath(t *testing.T) {
 		})
 	}
 }
+
+func TestInsecureModelpath(t *testing.T) {
+	mp := ParseModelPath("../../..:something")
+	if _, err := mp.GetManifestPath(); !errors.Is(err, errModelPathInvalid) {
+		t.Errorf("expected error: %v", err)
+	}
+}