llm_test.go

//go:build integration

package server

import (
	"context"
	"os"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	"github.com/jmorganca/ollama/api"
	"github.com/jmorganca/ollama/llm"
)

// TODO - this would ideally live in the llm package, but that would require some refactoring
// of interfaces in the server package to avoid circular dependencies.
// WARNING - these tests will fail on macOS unless you manually copy ggml-metal.metal into this directory (./server)
//
// TODO - Fix this ^^
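
// req holds the prompts exercised by the tests below; resp holds a lowercase
// substring expected somewhere in each corresponding response.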
var (
	req = [2]api.GenerateRequest{
		{
			Model:   "orca-mini",
			Prompt:  "tell me a short story about agi?",
			Options: map[string]interface{}{},
		}, {
			Model:   "orca-mini",
			Prompt:  "what is the origin of the us thanksgiving holiday?",
			Options: map[string]interface{}{},
		},
	}
	resp = [2]string{
		"once upon a time",
		"united states thanksgiving",
	}
)
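
// TestIntegrationSimpleOrcaMini loads orca-mini into a temporary work dir and
// checks that a single deterministic (seed 42, temperature 0) generation
// contains the expected substring.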
func TestIntegrationSimpleOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
	assert.Contains(t, strings.ToLower(response), resp[0])
}

// TODO
// The server always loads a new runner and closes the old one, which forces serial execution.
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend.
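//
// TestIntegrationConcurrentPredictOrcaMini issues both prompts concurrently against a
// single shared runner; it is skipped until the backend supports parallel prediction.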
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("concurrent prediction on single runner not currently supported")
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}
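
// TestIntegrationConcurrentRunnersOrcaMini starts one goroutine per request; each
// goroutine loads its own runner, issues a single prompt, and checks the response.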
func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	workDir, err := os.MkdirTemp("", "ollama")
	require.NoError(t, err)
	defer os.RemoveAll(workDir)
	require.NoError(t, llm.Init(workDir))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	t.Logf("Running %d concurrently", len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			model, llmRunner := PrepareModelForPrompts(t, req[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}

// TODO - create a parallel test with 2 different models once we support concurrency