shim_test.go 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. package llm
  2. import (
  3. "testing"
  4. "github.com/jmorganca/ollama/gpu"
  5. "github.com/stretchr/testify/assert"
  6. )
  7. func TestGetShims(t *testing.T) {
  8. availableShims = map[string]string{
  9. "cpu": "X_cpu",
  10. }
  11. assert.Equal(t, false, rocmShimPresent())
  12. res := getShims(gpu.GpuInfo{Library: "cpu"})
  13. assert.Len(t, res, 2)
  14. assert.Equal(t, availableShims["cpu"], res[0])
  15. assert.Equal(t, "default", res[1])
  16. availableShims = map[string]string{
  17. "rocm_v5": "X_rocm_v5",
  18. "rocm_v6": "X_rocm_v6",
  19. "cpu": "X_cpu",
  20. }
  21. assert.Equal(t, true, rocmShimPresent())
  22. res = getShims(gpu.GpuInfo{Library: "rocm"})
  23. assert.Len(t, res, 4)
  24. assert.Equal(t, availableShims["rocm_v5"], res[0])
  25. assert.Equal(t, availableShims["rocm_v6"], res[1])
  26. assert.Equal(t, availableShims["cpu"], res[2])
  27. assert.Equal(t, "default", res[3])
  28. res = getShims(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
  29. assert.Len(t, res, 4)
  30. assert.Equal(t, availableShims["rocm_v6"], res[0])
  31. assert.Equal(t, availableShims["rocm_v5"], res[1])
  32. assert.Equal(t, availableShims["cpu"], res[2])
  33. assert.Equal(t, "default", res[3])
  34. res = getShims(gpu.GpuInfo{Library: "cuda"})
  35. assert.Len(t, res, 2)
  36. assert.Equal(t, availableShims["cpu"], res[0])
  37. assert.Equal(t, "default", res[1])
  38. res = getShims(gpu.GpuInfo{Library: "default"})
  39. assert.Len(t, res, 2)
  40. assert.Equal(t, availableShims["cpu"], res[0])
  41. assert.Equal(t, "default", res[1])
  42. availableShims = map[string]string{
  43. "rocm": "X_rocm_v5",
  44. "cpu": "X_cpu",
  45. }
  46. assert.Equal(t, true, rocmShimPresent())
  47. res = getShims(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
  48. assert.Len(t, res, 3)
  49. assert.Equal(t, availableShims["rocm"], res[0])
  50. assert.Equal(t, availableShims["cpu"], res[1])
  51. assert.Equal(t, "default", res[2])
  52. }