Преглед на файлове

server: validate template (#5734)

add template validation to modelfile
Josh преди 9 месеца
родител
ревизия
e8b954c646
променени са 3 файла, в които са добавени 53 реда и са изтрити 3 реда
  1. 6 0
      server/images.go
  2. 11 3
      server/routes.go
  3. 36 0
      server/routes_create_test.go

+ 6 - 0
server/images.go

@@ -492,6 +492,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				layers = append(layers, baseLayer.Layer)
 			}
 		case "license", "template", "system":
+			if c.Name == "template" {
+				if _, err := template.Parse(c.Args); err != nil {
+					return fmt.Errorf("%w: %s", errBadTemplate, err)
+				}
+			}
+
 			if c.Name != "license" {
 				// replace
 				layers = slices.DeleteFunc(layers, func(layer *Layer) bool {

+ 11 - 3
server/routes.go

@@ -56,6 +56,7 @@ func init() {
 }
 
 var errRequired = errors.New("is required")
+var errBadTemplate = errors.New("template error")
 
 func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
 	opts := api.DefaultOptions()
@@ -609,8 +610,11 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
 
 		quantization := cmp.Or(r.Quantize, r.Quantization)
 		if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
+			if errors.Is(err, errBadTemplate) {
+			  ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
+			}
 			ch <- gin.H{"error": err.Error()}
-		}
+		  }
 	}()
 
 	if r.Stream != nil && !*r.Stream {
@@ -1196,11 +1200,15 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
 				return
 			}
 		case gin.H:
+			status, ok := r["status"].(int)
+			if !ok {
+				status = http.StatusInternalServerError
+			}
 			if errorMsg, ok := r["error"].(string); ok {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
+				c.JSON(status, gin.H{"error": errorMsg})
 				return
 			} else {
-				c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
+				c.JSON(status, gin.H{"error": "unexpected error format in progress response"})
 				return
 			}
 		default:

+ 36 - 0
server/routes_create_test.go

@@ -491,6 +491,42 @@ func TestCreateTemplateSystem(t *testing.T) {
 	if string(system) != "Say bye!" {
 		t.Errorf("expected \"Say bye!\", actual %s", system)
 	}
+
+	t.Run("incomplete template", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+	
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with unclosed if", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+	
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+
+	t.Run("template with undefined function", func(t *testing.T) {
+		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{  Prompt }}", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+	
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
 }
 
 func TestCreateLicenses(t *testing.T) {