浏览代码

Fix use_mmap parsing for modelfiles

Add the new tristate parsing logic for the code path for modelfiles,
as well as a unit test.
Daniel Hiltgen 10 月之前
父节点
当前提交
7e7749224c
共有 2 个文件被更改,包括 76 次插入0 次删除
  1. 13 0
      api/types.go
  2. 63 0
      api/types_test.go

+ 13 - 0
api/types.go

@@ -608,6 +608,19 @@ func FormatParams(params map[string][]string) (map[string]interface{}, error) {
 		} else {
 			field := valueOpts.FieldByName(opt.Name)
 			if field.IsValid() && field.CanSet() {
+				if reflect.PointerTo(field.Type()) == reflect.TypeOf((*TriState)(nil)) {
+					boolVal, err := strconv.ParseBool(vals[0])
+					if err != nil {
+						return nil, fmt.Errorf("invalid bool value %s", vals)
+					}
+					if boolVal {
+						out[key] = TriStateTrue
+					} else {
+						out[key] = TriStateFalse
+					}
+					continue
+				}
+
 				switch field.Kind() {
 				case reflect.Float32:
 					floatVal, err := strconv.ParseFloat(vals[0], 32)

+ 63 - 0
api/types_test.go

@@ -2,6 +2,7 @@ package api
 
 import (
 	"encoding/json"
+	"fmt"
 	"math"
 	"testing"
 	"time"
@@ -141,3 +142,65 @@ func TestUseMmapParsingFromJSON(t *testing.T) {
 		})
 	}
 }
+
+func TestUseMmapFormatParams(t *testing.T) {
+	tests := []struct {
+		name string
+		req  map[string][]string
+		exp  TriState
+		err  error
+	}{
+		{
+			name: "True",
+			req: map[string][]string{
+				"use_mmap": []string{"true"},
+			},
+			exp: TriStateTrue,
+			err: nil,
+		},
+		{
+			name: "False",
+			req: map[string][]string{
+				"use_mmap": []string{"false"},
+			},
+			exp: TriStateFalse,
+			err: nil,
+		},
+		{
+			name: "Numeric True",
+			req: map[string][]string{
+				"use_mmap": []string{"1"},
+			},
+			exp: TriStateTrue,
+			err: nil,
+		},
+		{
+			name: "Numeric False",
+			req: map[string][]string{
+				"use_mmap": []string{"0"},
+			},
+			exp: TriStateFalse,
+			err: nil,
+		},
+		{
+			name: "invalid string",
+			req: map[string][]string{
+				"use_mmap": []string{"foo"},
+			},
+			exp: TriStateUndefined,
+			err: fmt.Errorf("invalid bool value [foo]"),
+		},
+	}
+
+	for _, test := range tests {
+		t.Run(test.name, func(t *testing.T) {
+			resp, err := FormatParams(test.req)
+			require.Equal(t, err, test.err)
+			respVal, ok := resp["use_mmap"]
+			if test.exp != TriStateUndefined {
+				assert.True(t, ok, "resp: %v", resp)
+				assert.Equal(t, test.exp, respVal)
+			}
+		})
+	}
+}