llm_test.go

package server

import (
	"context"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
)
// TODO - this would ideally be in the llm package, but that would require some refactoring of
// interfaces in the server package to avoid circular dependencies
//
// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^
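
// req and resp are paired fixtures: resp[i] is the substring expected (case-insensitively)
// in the completion generated for req[i].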
var (
	req = [2]api.GenerateRequest{
		{
			Model:   "orca-mini",
			Prompt:  "tell me a short story about agi?",
			Options: map[string]interface{}{},
		}, {
			Model:   "orca-mini",
			Prompt:  "what is the origin of the us thanksgiving holiday?",
			Options: map[string]interface{}{},
		},
	}

	resp = [2]string{
		"once upon a time",
		"united states thanksgiving",
	}
)
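
// TestIntegrationSimpleOrcaMini loads orca-mini with a fixed seed and zero temperature and
// checks that a single completion contains the expected substring.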
func TestIntegrationSimpleOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)

	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()

	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0

	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()

	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
	assert.Contains(t, strings.ToLower(response), resp[0])
}
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution.
// At present this test case fails with concurrency problems, so it is skipped. Eventually we
// should get true concurrency working with n_parallel support in the backend.
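//
// TestIntegrationConcurrentPredictOrcaMini issues both prompts against one shared runner from
// separate goroutines and checks each response; see the TODO above for why it is skipped.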
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("concurrent prediction on a single runner is not currently supported")

	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()

	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0

	var wg sync.WaitGroup
	wg.Add(len(req))

	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()

	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in goroutine %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}
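
// TestIntegrationConcurrentRunnersOrcaMini loads a separate runner for each prompt from
// concurrent goroutines and verifies each response independently.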
func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)

	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()

	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0

	var wg sync.WaitGroup
	wg.Add(len(req))
	t.Logf("Running %d runners concurrently", len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			model, llmRunner := PrepareModelForPrompts(t, req[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in goroutine %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency