Преглед изворни кода

Add MarshalJSON to Duration (#3284)

---------

Co-authored-by: Patrick Devine <patrick@infrahq.com>
Jackie Li пре 1 година
родитељ
комит
af47413dba
2 измењених фајлова са 67 додато и 1 уклоњено
  1. 10 1
      api/types.go
  2. 57 0
      api/types_test.go

+ 10 - 1
api/types.go

@@ -436,6 +436,13 @@ type Duration struct {
 	time.Duration
 }
 
+func (d Duration) MarshalJSON() ([]byte, error) {
+	if d.Duration < 0 {
+		return []byte("-1"), nil
+	}
+	return []byte("\"" + d.Duration.String() + "\""), nil
+}
+
 func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 	var v any
 	if err := json.Unmarshal(b, &v); err != nil {
@@ -449,7 +456,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if t < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		} else {
-			d.Duration = time.Duration(t * float64(time.Second))
+			d.Duration = time.Duration(int(t) * int(time.Second))
 		}
 	case string:
 		d.Duration, err = time.ParseDuration(t)
@@ -459,6 +466,8 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
 		if d.Duration < 0 {
 			d.Duration = time.Duration(math.MaxInt64)
 		}
+	default:
+		return fmt.Errorf("Unsupported type: '%s'", reflect.TypeOf(v))
 	}
 
 	return nil

+ 57 - 0
api/types_test.go

@@ -21,6 +21,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 			req:  `{ "keep_alive": 42 }`,
 			exp:  &Duration{42 * time.Second},
 		},
+		{
+			name: "Positive Float",
+			req:  `{ "keep_alive": 42.5 }`,
+			exp:  &Duration{42 * time.Second},
+		},
 		{
 			name: "Positive Integer String",
 			req:  `{ "keep_alive": "42m" }`,
@@ -31,6 +36,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 			req:  `{ "keep_alive": -1 }`,
 			exp:  &Duration{math.MaxInt64},
 		},
+		{
+			name: "Negative Float",
+			req:  `{ "keep_alive": -3.14 }`,
+			exp:  &Duration{math.MaxInt64},
+		},
 		{
 			name: "Negative Integer String",
 			req:  `{ "keep_alive": "-1m" }`,
@@ -48,3 +58,50 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestDurationMarshalUnmarshal(t *testing.T) {
+	tests := []struct {
+		name     string
+		input    time.Duration
+		expected time.Duration
+	}{
+		{
+			"negative duration",
+			time.Duration(-1),
+			time.Duration(math.MaxInt64),
+		},
+		{
+			"positive duration",
+			time.Duration(42 * time.Second),
+			time.Duration(42 * time.Second),
+		},
+		{
+			"another positive duration",
+			time.Duration(42 * time.Minute),
+			time.Duration(42 * time.Minute),
+		},
+		{
+			"zero duration",
+			time.Duration(0),
+			time.Duration(0),
+		},
+		{
+			"max duration",
+			time.Duration(math.MaxInt64),
+			time.Duration(math.MaxInt64),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			b, err := json.Marshal(Duration{test.input})
+			require.NoError(t, err)
+
+			var d Duration
+			err = json.Unmarshal(b, &d)
+			require.NoError(t, err)
+
+			assert.Equal(t, test.expected, d.Duration, "input %v, marshalled %v, got %v", test.input, string(b), d.Duration)
+		})
+	}
+}