llm_test.go

package server

import (
	"context"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/jmorganca/ollama/api"
	"github.com/stretchr/testify/assert"
)
// TODO - this would ideally be in the llm package, but that would require some refactoring
// of interfaces in the server package to avoid circular dependencies
//
// WARNING - these tests will fail on macOS unless you manually copy ggml-metal.metal into
// this directory (./server)
//
// TODO - fix this ^^
var (
	req = [2]api.GenerateRequest{
		{
			Model:   "orca-mini",
			Prompt:  "tell me a short story about agi?",
			Options: map[string]interface{}{},
		}, {
			Model:   "orca-mini",
			Prompt:  "what is the origin of the us thanksgiving holiday?",
			Options: map[string]interface{}{},
		},
	}
	resp = [2]string{
		"once upon a time",
		"fourth thursday",
	}
)
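
// The tests below lean on two helpers defined elsewhere in this package (not in this file):
// PrepareModelForPrompts, which, from its use here, loads the named model with the given
// options and returns it along with a runner that must be Close()d, and OneShotPromptResponse,
// which runs a single generate request against that runner and returns the response text.
// See the package's test utilities for the actual signatures.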
func TestIntegrationSimpleOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	// pin the seed and disable sampling randomness so the output is stable enough to substring-match
	opts.Seed = 42
	opts.Temperature = 0.0
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	response := OneShotPromptResponse(t, ctx, req[0], model, llmRunner)
	assert.Contains(t, strings.ToLower(response), resp[0])
}
// TODO
// The server always loads a new runner and closes the old one, which forces serial execution.
// At present this test case fails with concurrency problems. Eventually we should try to
// get true concurrency working with n_parallel support in the backend.
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("concurrent prediction on single runner not currently supported")
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	model, llmRunner := PrepareModelForPrompts(t, req[0].Model, opts)
	defer llmRunner.Close()
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}
func TestIntegrationConcurrentRunnersOrcaMini(t *testing.T) {
	SkipIFNoTestData(t)
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0
	var wg sync.WaitGroup
	wg.Add(len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			// each goroutine spins up its own runner, so the two predictions do not share state
			model, llmRunner := PrepareModelForPrompts(t, req[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, req[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", req[i].Prompt, response)
			assert.Contains(t, strings.ToLower(response), resp[i], "error in thread %d (%s)", i, req[i].Prompt)
		}(i)
	}
	wg.Wait()
}
// TODO - create a parallel test with 2 different models once we support concurrency
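
// A rough sketch of what that future two-model test could look like, kept skipped until the
// backend supports it. The second model name ("llama2") is only a placeholder here, and the
// content assertions are omitted because the expected substrings in resp are specific to
// orca-mini; a real version would carry per-model expected strings.
func TestIntegrationConcurrentRunnersTwoModels(t *testing.T) {
	SkipIFNoTestData(t)
	t.Skip("running two different models concurrently is not yet supported")

	// copy the shared requests and point the second one at a different (placeholder) model
	reqs := req
	reqs[1].Model = "llama2"

	ctx, cancel := context.WithTimeout(context.Background(), time.Second*60)
	defer cancel()
	opts := api.DefaultOptions()
	opts.Seed = 42
	opts.Temperature = 0.0

	var wg sync.WaitGroup
	wg.Add(len(reqs))
	for i := 0; i < len(reqs); i++ {
		go func(i int) {
			defer wg.Done()
			// each goroutine gets its own runner for its own model
			model, llmRunner := PrepareModelForPrompts(t, reqs[i].Model, opts)
			defer llmRunner.Close()
			response := OneShotPromptResponse(t, ctx, reqs[i], model, llmRunner)
			t.Logf("Prompt: %s\nResponse: %s", reqs[i].Prompt, response)
			assert.NotEmpty(t, response, "empty response in thread %d (%s)", i, reqs[i].Prompt)
		}(i)
	}
	wg.Wait()
}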