Sfoglia il codice sorgente

validate api options fields from map (#711)

Bruce MacDonald 1 anno fa
parent
commit
7804b8fab9
2 ha cambiato i file con 13 aggiunte e 1 eliminazioni
  1. 9 0
      api/types.go
  2. 4 1
      server/routes.go

+ 9 - 0
api/types.go

@@ -205,6 +205,8 @@ type Options struct {
 	NumThread int `json:"num_thread,omitempty"`
 }
 
+var ErrInvalidOpts = fmt.Errorf("invalid options")
+
 func (opts *Options) FromMap(m map[string]interface{}) error {
 	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
 	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
@@ -218,6 +220,7 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 		}
 	}
 
+	invalidOpts := []string{}
 	for key, val := range m {
 		if opt, ok := jsonOpts[key]; ok {
 			field := valueOpts.FieldByName(opt.Name)
@@ -281,8 +284,14 @@ func (opts *Options) FromMap(m map[string]interface{}) error {
 					return fmt.Errorf("unknown type loading config params: %v", field.Kind())
 				}
 			}
+		} else {
+			invalidOpts = append(invalidOpts, key)
 		}
 	}
+
+	if len(invalidOpts) > 0 {
+		return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
+	}
 	return nil
 }
 

+ 4 - 1
server/routes.go

@@ -68,7 +68,6 @@ func load(ctx context.Context, workDir string, model *Model, reqOpts map[string]
 	}
 
 	if err := opts.FromMap(reqOpts); err != nil {
-		log.Printf("could not merge model options: %v", err)
 		return err
 	}
 
@@ -186,6 +185,10 @@ func GenerateHandler(c *gin.Context) {
 	// TODO: set this duration from the request if specified
 	sessionDuration := defaultSessionDuration
 	if err := load(c.Request.Context(), workDir, model, req.Options, sessionDuration); err != nil {
+		if errors.Is(err, api.ErrInvalidOpts) {
+			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+			return
+		}
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
 	}