concurrency_test.go

//go:build integration

package integration

import (
	"context"
	"log/slog"
	"os"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/format"
	"github.com/stretchr/testify/require"
)
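// TestMultiModelConcurrency generates completions from two different models in
// parallel and checks each response against its expected keyword list.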
func TestMultiModelConcurrency(t *testing.T) {
	var (
		req = [2]api.GenerateRequest{
			{
				Model:     "orca-mini",
				Prompt:    "why is the ocean blue?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			}, {
				Model:     "tinydolphin",
				Prompt:    "what is the origin of the us thanksgiving holiday?",
				Stream:    &stream,
				KeepAlive: &api.Duration{Duration: 10 * time.Second},
				Options: map[string]interface{}{
					"seed":        42,
					"temperature": 0.0,
				},
			},
		}
		resp = [2][]string{
			{"sunlight"},
			{"england", "english", "massachusetts", "pilgrims", "british", "festival"},
		}
	)

	var wg sync.WaitGroup
	wg.Add(len(req))
	ctx, cancel := context.WithTimeout(context.Background(), 240*time.Second)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	for i := 0; i < len(req); i++ {
		require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
	}

	for i := 0; i < len(req); i++ {
		go func(i int) {
			defer wg.Done()
			// Note: CPU-based inference can crawl, so don't give up too quickly
			DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 30*time.Second)
		}(i)
	}
	wg.Wait()
}
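// TestIntegrationConcurrentPredictOrcaMini issues the standard generate requests
// concurrently, repeating each one several times, and scales the load down when
// OLLAMA_MAX_VRAM reports a small card.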
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
	req, resp := GenerateRequests()
	reqLimit := len(req)
	iterLimit := 5

	if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {
		maxVram, err := strconv.ParseUint(s, 10, 64)
		require.NoError(t, err)
		// Don't hammer on small VRAM cards...
		if maxVram < 4*format.GibiByte {
			reqLimit = min(reqLimit, 2)
			iterLimit = 2
		}
	}

	ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Get the server running (if applicable) and warm the model up with a single initial request
	DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)

	var wg sync.WaitGroup
	wg.Add(reqLimit)
	for i := 0; i < reqLimit; i++ {
		go func(i int) {
			defer wg.Done()
			for j := 0; j < iterLimit; j++ {
				slog.Info("Starting", "req", i, "iter", j)
				// On slower GPUs it can take a while to process the concurrent requests,
				// so we allow a much longer initial timeout
				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
			}
		}(i)
	}
	wg.Wait()
}
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) {
	s := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM
	if s == "" {
		t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
	}
	maxVram, err := strconv.ParseUint(s, 10, 64)
	if err != nil {
		t.Fatal(err)
	}

	type model struct {
		name string
		size uint64 // Approximate amount of VRAM the model typically uses when fully loaded
	}

	smallModels := []model{
		{
			name: "orca-mini",
			size: 2992 * format.MebiByte,
		},
		{
			name: "phi",
			size: 2616 * format.MebiByte,
		},
		{
			name: "gemma:2b",
			size: 2364 * format.MebiByte,
		},
		{
			name: "stable-code:3b",
			size: 2608 * format.MebiByte,
		},
		{
			name: "starcoder2:3b",
			size: 2166 * format.MebiByte,
		},
	}
	mediumModels := []model{
		{
			name: "llama2",
			size: 5118 * format.MebiByte,
		},
		{
			name: "mistral",
			size: 4620 * format.MebiByte,
		},
		{
			name: "orca-mini:7b",
			size: 5118 * format.MebiByte,
		},
		{
			name: "dolphin-mistral",
			size: 4620 * format.MebiByte,
		},
		{
			name: "gemma:7b",
			size: 5000 * format.MebiByte,
		},
		{
			name: "codellama:7b",
			size: 5118 * format.MebiByte,
		},
	}
	// These seem to be too slow to be useful...
	// largeModels := []model{
	// 	{
	// 		name: "llama2:13b",
	// 		size: 7400 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "codellama:13b",
	// 		size: 7400 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "orca-mini:13b",
	// 		size: 7400 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "gemma:7b",
	// 		size: 5000 * format.MebiByte,
	// 	},
	// 	{
	// 		name: "starcoder2:15b",
	// 		size: 9100 * format.MebiByte,
	// 	},
	// }
	var chosenModels []model
	switch {
	case maxVram < 10000*format.MebiByte:
		slog.Info("selecting small models")
		chosenModels = smallModels
	// case maxVram < 30000*format.MebiByte:
	default:
		slog.Info("selecting medium models")
		chosenModels = mediumModels
		// default:
		// 	slog.Info("selecting large models")
		// 	chosenModels = largeModels
	}
	req, resp := GenerateRequests()
	for i := range req {
		// Stop once we run out of chosen models to assign
		if i >= len(chosenModels) {
			break
		}
		req[i].Model = chosenModels[i].name
	}
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) // TODO baseline -- 10m too short
	defer cancel()
	client, _, cleanup := InitServerConnection(ctx, t)
	defer cleanup()

	// Make sure all the models are pulled before we get started
	for _, r := range req {
		require.NoError(t, PullIfMissing(ctx, client, r.Model))
	}
	var wg sync.WaitGroup
	consumed := uint64(256 * format.MebiByte) // Assume some baseline usage
	for i := 0; i < len(req); i++ {
		// Always get at least 2 models, but don't overshoot VRAM too much or we'll take too long
		if i > 1 && consumed > maxVram {
			slog.Info("achieved target vram exhaustion", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))
			break
		}
		consumed += chosenModels[i].size
		slog.Info("target vram", "count", i, "vram", format.HumanBytes2(maxVram), "models", format.HumanBytes2(consumed))

		wg.Add(1)
		go func(i int) {
			defer wg.Done()
			for j := 0; j < 3; j++ {
				slog.Info("Starting", "req", i, "iter", j, "model", req[i].Model)
				DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 5*time.Second)
			}
		}(i)
	}
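	// While the generate goroutines run, periodically log which models the server reports as loaded.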
	go func() {
		for {
			time.Sleep(2 * time.Second)
			select {
			case <-ctx.Done():
				return
			default:
				models, err := client.ListRunning(ctx)
				if err != nil {
					slog.Warn("failed to list running models", "error", err)
					continue
				}
				for _, m := range models.Models {
					slog.Info("loaded model snapshot", "model", m)
				}
			}
		}
	}()
	wg.Wait()
}