Просмотр исходного кода

gpu: Group GPU Library sets by variant (#6483)

The recent cuda variant changes uncovered a bug in ByLibrary
which failed to group by common variant for GPU types.
Daniel Hiltgen 8 месяцев назад
Родитель
Сommit
69be940bf6
2 измененных файлов с 26 добавлено и 1 удалено
  1. 25 0
      gpu/gpu_test.go
  2. 1 1
      gpu/types.go

+ 25 - 0
gpu/gpu_test.go

@@ -32,4 +32,29 @@ func TestCPUMemInfo(t *testing.T) {
 	}
 }
 
+func TestByLibrary(t *testing.T) {
+	type testCase struct {
+		input  []GpuInfo
+		expect int
+	}
+
+	testCases := map[string]*testCase{
+		"empty":                    {input: []GpuInfo{}, expect: 0},
+		"cpu":                      {input: []GpuInfo{{Library: "cpu"}}, expect: 1},
+		"cpu + GPU":                {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}}, expect: 2},
+		"cpu + 2 GPU no variant":   {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda"}, {Library: "cuda"}}, expect: 2},
+		"cpu + 2 GPU same variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v11"}}, expect: 2},
+		"cpu + 2 GPU diff variant": {input: []GpuInfo{{Library: "cpu"}, {Library: "cuda", Variant: "v11"}, {Library: "cuda", Variant: "v12"}}, expect: 3},
+	}
+
+	for k, v := range testCases {
+		t.Run(k, func(t *testing.T) {
+			resp := (GpuInfoList)(v.input).ByLibrary()
+			if len(resp) != v.expect {
+				t.Fatalf("expected length %d, got %d => %+v", v.expect, len(resp), resp)
+			}
+		})
+	}
+}
+
 // TODO - add some logic to figure out card type through other means and actually verify we got back what we expected

+ 1 - 1
gpu/types.go

@@ -94,7 +94,7 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
 			}
 		}
 		if !found {
-			libs = append(libs, info.Library)
+			libs = append(libs, requested)
 			resp = append(resp, []GpuInfo{info})
 		}
 	}