123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354 |
- package llm
- import (
- "testing"
- "github.com/jmorganca/ollama/gpu"
- "github.com/stretchr/testify/assert"
- )
- func TestGetShims(t *testing.T) {
- availableShims = map[string]string{
- "cpu": "X_cpu",
- }
- assert.Equal(t, false, rocmShimPresent())
- res := getShims(gpu.GpuInfo{Library: "cpu"})
- assert.Len(t, res, 1)
- assert.Equal(t, availableShims["cpu"], res[0])
- availableShims = map[string]string{
- "rocm_v5": "X_rocm_v5",
- "rocm_v6": "X_rocm_v6",
- "cpu": "X_cpu",
- }
- assert.Equal(t, true, rocmShimPresent())
- res = getShims(gpu.GpuInfo{Library: "rocm"})
- assert.Len(t, res, 3)
- assert.Equal(t, availableShims["rocm_v5"], res[0])
- assert.Equal(t, availableShims["rocm_v6"], res[1])
- assert.Equal(t, availableShims["cpu"], res[2])
- res = getShims(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
- assert.Len(t, res, 3)
- assert.Equal(t, availableShims["rocm_v6"], res[0])
- assert.Equal(t, availableShims["rocm_v5"], res[1])
- assert.Equal(t, availableShims["cpu"], res[2])
- res = getShims(gpu.GpuInfo{Library: "cuda"})
- assert.Len(t, res, 1)
- assert.Equal(t, availableShims["cpu"], res[0])
- res = getShims(gpu.GpuInfo{Library: "default"})
- assert.Len(t, res, 1)
- assert.Equal(t, "default", res[0])
- availableShims = map[string]string{
- "rocm": "X_rocm_v5",
- "cpu": "X_cpu",
- }
- assert.Equal(t, true, rocmShimPresent())
- res = getShims(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
- assert.Len(t, res, 2)
- assert.Equal(t, availableShims["rocm"], res[0])
- assert.Equal(t, availableShims["cpu"], res[1])
- }
|