Prechádzať zdrojové kódy

replace `reflect` usage in option parsing

Jeffrey Morgan 1 rok pred
rodič
commit
a859f037da
2 zmenil súbory, kde vykonal 12 pridanie a 89 odobranie
  1. 9 74
      api/types.go
  2. 3 15
      server/routes.go

+ 9 - 74
api/types.go

@@ -279,85 +279,20 @@ func (m *Metrics) Summary() {
 var ErrInvalidOpts = fmt.Errorf("invalid options")
 
 func (opts *Options) FromMap(m map[string]interface{}) error {
-	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
-	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
-
-	// build map of json struct tags to their types
-	jsonOpts := make(map[string]reflect.StructField)
-	for _, field := range reflect.VisibleFields(typeOpts) {
-		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
-		if jsonTag != "" {
-			jsonOpts[jsonTag] = field
-		}
+	data, err := json.Marshal(m)
+	if err != nil {
+		return err
 	}
 
-	invalidOpts := []string{}
-	for key, val := range m {
-		if opt, ok := jsonOpts[key]; ok {
-			field := valueOpts.FieldByName(opt.Name)
-			if field.IsValid() && field.CanSet() {
-				if val == nil {
-					continue
-				}
-
-				switch field.Kind() {
-				case reflect.Int:
-					switch t := val.(type) {
-					case int64:
-						field.SetInt(t)
-					case float64:
-						// when JSON unmarshals numbers, it uses float64, not int
-						field.SetInt(int64(t))
-					default:
-						return fmt.Errorf("option %q must be of type integer", key)
-					}
-				case reflect.Bool:
-					val, ok := val.(bool)
-					if !ok {
-						return fmt.Errorf("option %q must be of type boolean", key)
-					}
-					field.SetBool(val)
-				case reflect.Float32:
-					// JSON unmarshals to float64
-					val, ok := val.(float64)
-					if !ok {
-						return fmt.Errorf("option %q must be of type float32", key)
-					}
-					field.SetFloat(val)
-				case reflect.String:
-					val, ok := val.(string)
-					if !ok {
-						return fmt.Errorf("option %q must be of type string", key)
-					}
-					field.SetString(val)
-				case reflect.Slice:
-					// JSON unmarshals to []interface{}, not []string
-					val, ok := val.([]interface{})
-					if !ok {
-						return fmt.Errorf("option %q must be of type array", key)
-					}
-					// convert []interface{} to []string
-					slice := make([]string, len(val))
-					for i, item := range val {
-						str, ok := item.(string)
-						if !ok {
-							return fmt.Errorf("option %q must be of an array of strings", key)
-						}
-						slice[i] = str
-					}
-					field.Set(reflect.ValueOf(slice))
-				default:
-					return fmt.Errorf("unknown type loading config params: %v", field.Kind())
-				}
-			}
-		} else {
-			invalidOpts = append(invalidOpts, key)
+	err = json.Unmarshal(data, opts)
+	if err != nil {
+		// Custom error handling
+		if jsonErr, ok := err.(*json.UnmarshalTypeError); ok {
+			return fmt.Errorf("invalid type for option '%v': expected %v, got %v", jsonErr.Field, jsonErr.Type, jsonErr.Value)
 		}
+		return err
 	}
 
-	if len(invalidOpts) > 0 {
-		return fmt.Errorf("%w: %v", ErrInvalidOpts, strings.Join(invalidOpts, ", "))
-	}
 	return nil
 }
 

+ 3 - 15
server/routes.go

@@ -178,11 +178,7 @@ func GenerateHandler(c *gin.Context) {
 
 	opts, err := modelOptions(model, req.Options)
 	if err != nil {
-		if errors.Is(err, api.ErrInvalidOpts) {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -396,11 +392,7 @@ func EmbeddingHandler(c *gin.Context) {
 
 	opts, err := modelOptions(model, req.Options)
 	if err != nil {
-		if errors.Is(err, api.ErrInvalidOpts) {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
@@ -1112,11 +1104,7 @@ func ChatHandler(c *gin.Context) {
 
 	opts, err := modelOptions(model, req.Options)
 	if err != nil {
-		if errors.Is(err, api.ErrInvalidOpts) {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}