Browse Source

validate model tags on copy (#1323)

Bruce MacDonald 1 year ago
parent
commit
96122b7271
2 changed files with 21 additions and 2 deletions
  1. 14 0
      server/modelpath.go
  2. 7 2
      server/routes.go

+ 14 - 0
server/modelpath.go

@@ -67,6 +67,20 @@ func ParseModelPath(name string) ModelPath {
 	return mp
 }
 
+var errModelPathInvalid = errors.New("invalid model path")
+
+func (mp ModelPath) Validate() error {
+	if mp.Repository == "" {
+		return fmt.Errorf("%w: model repository name is required", errModelPathInvalid)
+	}
+
+	if strings.Contains(mp.Tag, ":") {
+		return fmt.Errorf("%w: ':' (colon) is not allowed in tag names", errModelPathInvalid)
+	}
+
+	return nil
+}
+
 func (mp ModelPath) GetNamespaceRepository() string {
 	return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
 }

+ 7 - 2
server/routes.go

@@ -416,8 +416,8 @@ func CreateModelHandler(c *gin.Context) {
 		return
 	}
 
-	if strings.Count(req.Name, ":") > 1 {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "':' (colon) is not allowed in tag names"})
+	if err := ParseModelPath(req.Name).Validate(); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -645,6 +645,11 @@ func CopyModelHandler(c *gin.Context) {
 		return
 	}
 
+	if err := ParseModelPath(req.Destination).Validate(); err != nil {
+		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
+	}
+
 	if err := CopyModel(req.Source, req.Destination); err != nil {
 		if os.IsNotExist(err) {
 			c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Source)})