concurrency_test.go

//go:build integration

package integration

import (
	"context"
	"log/slog"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/stretchr/testify/require"
)
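
// TestMultiModelConcurrency loads two small models side by side and drives a
// concurrent generate request against each, checking every response for the
// expected keywords. The stream flag and helpers such as InitServerConnection,
// PullIfMissing, and DoGenerate are shared test utilities defined elsewhere in
// this integration package.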
func TestMultiModelConcurrency(t *testing.T) {
	var (
		req = [2]api.GenerateRequest{
			{
				Model:  "orca-mini",
				Prompt: "why is the ocean blue?",
				Stream: &stream,
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
				Model:  "tinydolphin",
				Prompt: "what is the origin of the us thanksgiving holiday?",
				Stream: &stream,
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			},
		}
		resp = [2][]string{
			[]string{"sunlight"},
			[]string{"england", "english", "massachusetts", "pilgrims", "british"},
		}
	)
	var wg sync.WaitGroup
	wg.Add(len(req))
	ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	for i := 0; i < len(req); i++ {
		require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
	}

	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			DoGenerate(ctx, t, client, req[i], resp[i], 30*time.Second, 10*time.Second)
		}(i)
	}
	wg.Wait()
}
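
// TestIntegrationConcurrentPredictOrcaMini warms the model with a single
// request, then runs the full set of generate requests concurrently,
// repeating each one five times to exercise scheduling of parallel
// predictions against one server.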
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	req, resp := GenerateRequests()
	// Get the server running (if applicable) and warm the model up with a single initial request
	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)

	var wg sync.WaitGroup
	wg.Add(len(req))
	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			for j := 0; j < 5; j++ {
				slog.Info("Starting", "req", i, "iter", j)
				// On slower GPUs it can take a while to process the 4 concurrent requests,
				// so we allow a much longer initial timeout
				DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
			}
		}(i)
	}
	wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
	vram := os.Getenv("OLLAMA_MAX_VRAM")
	if vram == "" {
		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
	}
	max, err := strconv.ParseUint(vram, 10, 64)
	require.NoError(t, err)
	const MB = uint64(1024 * 1024)

	type model struct {
		name string
		size uint64 // Approximate VRAM the model typically uses when fully loaded
	}
	smallModels := []model{
		{
			name: "orca-mini",
			size: 2992 * MB,
		},
		{
			name: "phi",
			size: 2616 * MB,
		},
		{
			name: "gemma:2b",
			size: 2364 * MB,
		},
		{
			name: "stable-code:3b",
			size: 2608 * MB,
		},
		{
			name: "starcoder2:3b",
			size: 2166 * MB,
		},
	}
	mediumModels := []model{
		{
			name: "llama2",
			size: 5118 * MB,
		},
		{
			name: "mistral",
			size: 4620 * MB,
		},
		{
			name: "orca-mini:7b",
			size: 5118 * MB,
		},
		{
			name: "dolphin-mistral",
			size: 4620 * MB,
		},
		{
			name: "gemma:7b",
			size: 5000 * MB,
		},
		// TODO - uncomment this once #3565 is merged and this is rebased on it
		// {
		// 	name: "codellama:7b",
		// 	size: 5118 * MB,
		// },
	}
	// These seem to be too slow to be useful...
	// largeModels := []model{
	// 	{
	// 		name: "llama2:13b",
	// 		size: 7400 * MB,
	// 	},
	// 	{
	// 		name: "codellama:13b",
	// 		size: 7400 * MB,
	// 	},
	// 	{
	// 		name: "orca-mini:13b",
	// 		size: 7400 * MB,
	// 	},
	// 	{
	// 		name: "gemma:7b",
	// 		size: 5000 * MB,
	// 	},
	// 	{
	// 		name: "starcoder2:15b",
	// 		size: 9100 * MB,
	// 	},
	// }
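
	// Pick a model set that roughly matches the reported VRAM so the test can
	// intentionally oversubscribe it.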
	var chosenModels []model
	switch {
	case max < 10000*MB:
		slog.Info("selecting small models")
		chosenModels = smallModels
	// case max < 30000*MB:
	default:
		slog.Info("selecting medium models")
		chosenModels = mediumModels
		// default:
		// 	slog.Info("selecting large models")
		// 	chosenModels = largeModels
	}
	req, resp := GenerateRequests()
	for i := range req {
		if i >= len(chosenModels) {
			break
		}
		req[i].Model = chosenModels[i].name
	}
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure all the models are pulled before we get started
	for _, r := range req {
		require.NoError(t, PullIfMissing(ctx, client, r.Model))
	}
	var wg sync.WaitGroup
	consumed := uint64(256 * MB) // Assume some baseline usage
	for i := 0; i < len(req); i++ {
		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
		if i > 1 && consumed > max {
			slog.Info("achieved target vram exhaustion", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)
			break
		}
		consumed += chosenModels[i].size
		slog.Info("target vram", "count", i, "vramMB", max/1024/1024, "modelsMB", consumed/1024/1024)

		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			for j := 0; j < 3; j++ {
				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
			}
		}(i)
	}
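
	// While the generate workers run, periodically log which models the server
	// reports as loaded.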
	go func() {
		for {
			time.Sleep(2 * time.Second)
			select {
			case <-ctx.Done():
				return
			default:
				models, err := client.ListRunning(ctx)
				if err != nil {
					slog.Warn("failed to list running models", "error", err)
					continue
				}
				for _, m := range models.Models {
					slog.Info("loaded model snapshot", "model", m)
				}
			}
		}
	}()
	wg.Wait()
}