server_benchmark_test.go

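// Package benchmark exercises the Ollama server's generate endpoint and
// reports time-to-first-token, load time, and token throughput metrics.
//
// The -m flag names the model to benchmark and is required. A typical
// invocation (run from this package's directory against an already running
// Ollama server; the model name below is only a placeholder) looks like:
//
//	go test -bench=. -m llama3.2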
package benchmark

import (
	"context"
	"flag"
	"fmt"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

// Command line flags
var modelFlag string

func init() {
	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
	// DefValue only changes the default shown in usage output; the flag still
	// defaults to the empty string, so modelName can enforce that -m is set.
	flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
	if modelFlag == "" {
		b.Fatal("Error: -m flag is required for benchmark tests")
	}
	return modelFlag
}

type TestCase struct {
	name      string
	prompt    string
	maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
	start := time.Now()
	var ttft time.Duration
	var metrics api.Metrics
	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 && resp.Response != "" {
			ttft = time.Since(start)
		}
		if resp.Done {
			metrics = resp.Metrics
		}
		return nil
	})

	// Report custom metrics as part of the benchmark results
	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

	// Token throughput metrics
	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
	b.ReportMetric(promptThroughput, "prompt_tok/s")
	b.ReportMetric(genThroughput, "gen_tok/s")

	// Token counts
	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")

	if err != nil {
		b.Fatal(err)
	}
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				b.StopTimer()
				// Ensure model is unloaded before each iteration
				unload(client, m, b)
				b.StartTimer()

				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
	client := setup(b)
	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}
	m := modelName(b)

	for _, tt := range tests {
		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
			ctx := context.Background()

			// Pre-warm the model
			warmup(client, m, tt.prompt, b)

			// Set number of tokens as our throughput metric
			b.SetBytes(int64(tt.maxTokens))

			for b.Loop() {
				req := &api.GenerateRequest{
					Model:   m,
					Prompt:  tt.prompt,
					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
				}

				runGenerateBenchmark(b, ctx, client, req)
			}
		})
	}
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		b.Fatal(err)
	}
	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
		b.Fatalf("Model unavailable: %v", err)
	}
	return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
	for range 3 {
		err := client.Generate(
			context.Background(),
			&api.GenerateRequest{
				Model:   model,
				Prompt:  prompt,
				Options: map[string]any{"num_predict": 50, "temperature": 0.1},
			},
			func(api.GenerateResponse) error { return nil },
		)
		if err != nil {
			b.Logf("Error during model warm-up: %v", err)
		}
	}
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(1 * time.Second)
}