server_benchmark_test.go

// Package benchmark provides tools for performance testing of the Ollama inference server and supported models.
package benchmark

import (
	"context"
	"fmt"
	"net/http"
	"net/url"
	"os"
	"testing"
	"text/tabwriter"
	"time"

	"github.com/ollama/ollama/api"
)
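// Typical invocation, assuming an Ollama server is already listening on
// serverURL (the flags and package path below are illustrative, not fixed):
//
//	go test -bench=BenchmarkServerInference -benchtime=5x ./benchmark
//
// Each sub-benchmark runs b.N iterations, so -benchtime=Nx bounds the number
// of inference requests issued per test case.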
// serverURL is the default Ollama server URL used for benchmarking
const serverURL = "http://127.0.0.1:11434"
// metrics collects all benchmark results for final reporting
var metrics []BenchmarkMetrics

// models contains the list of model names to benchmark
var models = []string{
	"llama3.2:1b",
	// "qwen2.5:7b",
	// "llama3.3:70b",
}
// TestCase defines a benchmark test scenario with prompt characteristics
type TestCase struct {
	name      string // Human-readable test name
	prompt    string // Input prompt text
	maxTokens int    // Maximum tokens to generate
}

// BenchmarkMetrics contains performance measurements for a single test run
type BenchmarkMetrics struct {
	model           string        // Model being tested
	scenario        string        // cold_start or warm_start
	testName        string        // Name of the test case
	ttft            time.Duration // Time To First Token (TTFT)
	totalTime       time.Duration // Total time for the complete response
	totalTokens     int           // Total generated tokens
	tokensPerSecond float64       // Calculated throughput
}

// ScenarioType defines the initialization state for benchmarking
type ScenarioType int

const (
	ColdStart ScenarioType = iota // Model is loaded from cold state
	WarmStart                     // Model is already loaded in memory
)
// String implements fmt.Stringer for ScenarioType
func (s ScenarioType) String() string {
	return [...]string{"cold_start", "warm_start"}[s]
}
// BenchmarkServerInference is the main entry point for benchmarking Ollama inference performance.
// It tests all configured models with different prompt lengths and start scenarios.
func BenchmarkServerInference(b *testing.B) {
	b.Logf("Starting benchmark suite with %d models", len(models))

	// Verify server availability
	resp, err := http.Get(serverURL + "/api/version")
	if err != nil {
		b.Fatalf("Server unavailable: %v", err)
	}
	resp.Body.Close()
	b.Log("Server available")

	tests := []TestCase{
		{"short_prompt", "Write a long story", 100},
		{"medium_prompt", "Write a detailed economic analysis", 500},
		{"long_prompt", "Write a comprehensive AI research paper", 1000},
	}

	// Register cleanup handler for results reporting
	b.Cleanup(func() { reportMetrics(metrics) })

	// Main benchmark loop
	for _, model := range models {
		client := api.NewClient(mustParse(serverURL), http.DefaultClient)

		// Verify model availability
		if _, err := client.Show(context.Background(), &api.ShowRequest{Model: model}); err != nil {
			b.Fatalf("Model unavailable: %v", err)
		}

		for _, tt := range tests {
			testName := fmt.Sprintf("%s/%s/%s", model, ColdStart, tt.name)
			b.Run(testName, func(b *testing.B) {
				m := runBenchmark(b, tt, model, ColdStart, client)
				metrics = append(metrics, m...)
			})
		}
		for _, tt := range tests {
			testName := fmt.Sprintf("%s/%s/%s", model, WarmStart, tt.name)
			b.Run(testName, func(b *testing.B) {
				m := runBenchmark(b, tt, model, WarmStart, client)
				metrics = append(metrics, m...)
			})
		}
	}
}
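// The sub-benchmark names built above follow the pattern model/scenario/test,
// e.g. "llama3.2:1b/cold_start/short_prompt", so a single case can be selected
// with a -bench pattern (illustrative):
//
//	go test -bench='BenchmarkServerInference/llama3.2:1b/warm_start'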
// runBenchmark executes multiple iterations of a specific test case and scenario.
// Returns collected metrics for all iterations.
func runBenchmark(b *testing.B, tt TestCase, model string, scenario ScenarioType, client *api.Client) []BenchmarkMetrics {
	results := make([]BenchmarkMetrics, b.N)

	// Run benchmark iterations
	for i := 0; i < b.N; i++ {
		// Exclude per-iteration setup (warm-up or unload) from the timed section.
		b.StopTimer()
		switch scenario {
		case WarmStart:
			// Pre-warm the model by generating some tokens
			for j := 0; j < 2; j++ {
				err := client.Generate(
					context.Background(),
					&api.GenerateRequest{
						Model:   model,
						Prompt:  tt.prompt,
						Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
					},
					func(api.GenerateResponse) error { return nil },
				)
				if err != nil {
					b.Logf("Warm-up request failed: %v", err)
				}
			}
		case ColdStart:
			unloadModel(client, model, b)
		}
		b.StartTimer()

		results[i] = runSingleIteration(context.Background(), client, tt, model, b)
		results[i].scenario = scenario.String()
	}
	return results
}
// unloadModel forces model unloading by issuing a request with KeepAlive set to 0,
// which tells the server to evict the model immediately.
// Includes a short delay to ensure unloading completes before the next test.
func unloadModel(client *api.Client, model string, b *testing.B) {
	req := &api.GenerateRequest{
		Model:     model,
		KeepAlive: &api.Duration{Duration: 0},
	}
	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
		b.Logf("Unload error: %v", err)
	}
	time.Sleep(100 * time.Millisecond)
}
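// For reference, the unload request above corresponds to a raw API call along
// these lines (endpoint and payload shown as an assumption of the current
// Ollama HTTP API; model name is illustrative):
//
//	curl http://127.0.0.1:11434/api/generate -d '{"model": "llama3.2:1b", "keep_alive": 0}'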
// runSingleIteration measures performance metrics for a single inference request.
// Captures TTFT, total generation time, and calculates tokens/second.
func runSingleIteration(ctx context.Context, client *api.Client, tt TestCase, model string, b *testing.B) BenchmarkMetrics {
	start := time.Now()
	var ttft time.Duration
	var tokens int
	lastToken := start

	req := &api.GenerateRequest{
		Model:   model,
		Prompt:  tt.prompt,
		Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
	}
	if b != nil {
		b.Logf("Prompt length: %d chars", len(tt.prompt))
	}

	// Execute generation request with metrics collection. Each non-empty
	// streamed chunk is counted as one token, which approximates the true
	// token count since the server streams roughly one token per chunk.
	if err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
		if ttft == 0 {
			ttft = time.Since(start)
		}
		if resp.Response != "" {
			tokens++
			lastToken = time.Now()
		}
		return nil
	}); err != nil && b != nil {
		b.Logf("Generation error: %v", err)
	}

	totalTime := lastToken.Sub(start)
	var tokensPerSecond float64
	if totalTime > 0 {
		tokensPerSecond = float64(tokens) / totalTime.Seconds()
	}
	return BenchmarkMetrics{
		model:           model,
		testName:        tt.name,
		ttft:            ttft,
		totalTime:       totalTime,
		totalTokens:     tokens,
		tokensPerSecond: tokensPerSecond,
	}
}
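// Worked example of the metrics above: if the first chunk arrives 500ms after
// the request and the 100th (last) chunk arrives at t=2.5s, then ttft = 500ms,
// totalTime = 2.5s, and tokensPerSecond = 100 / 2.5 = 40. The denominator
// includes the TTFT, so this is end-to-end throughput rather than pure decode
// speed.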
// reportMetrics processes collected metrics and prints formatted results.
// Generates both human-readable tables and CSV output with averaged statistics.
func reportMetrics(results []BenchmarkMetrics) {
	if len(results) == 0 {
		return
	}

	// Aggregate results by model, scenario, and test case
	type statsKey struct {
		model    string
		scenario string
		testName string
	}
	type runningStats struct {
		ttftSum      time.Duration
		totalTimeSum time.Duration
		tokensSum    int
		iterations   int
	}
	stats := make(map[statsKey]*runningStats)
	for _, m := range results {
		key := statsKey{m.model, m.scenario, m.testName}
		if _, exists := stats[key]; !exists {
			stats[key] = &runningStats{}
		}
		stats[key].ttftSum += m.ttft
		stats[key].totalTimeSum += m.totalTime
		stats[key].tokensSum += m.totalTokens
		stats[key].iterations++
	}

	// Calculate averages. Throughput is computed from the summed totals, so
	// longer iterations weigh proportionally more than shorter ones.
	var averaged []BenchmarkMetrics
	for key, data := range stats {
		count := data.iterations
		averaged = append(averaged, BenchmarkMetrics{
			model:           key.model,
			scenario:        key.scenario,
			testName:        key.testName,
			ttft:            data.ttftSum / time.Duration(count),
			totalTime:       data.totalTimeSum / time.Duration(count),
			totalTokens:     data.tokensSum / count,
			tokensPerSecond: float64(data.tokensSum) / data.totalTimeSum.Seconds(),
		})
	}

	// Print formatted results
	printTableResults(averaged)
	printCSVResults(averaged)
}
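// Averaging example: two iterations of the same case with (100 tokens, 2s) and
// (100 tokens, 3s) aggregate to totalTokens = 200/2 = 100 and
// tokensPerSecond = 200 / 5 = 40, i.e. a duration-weighted mean rather than a
// mean of the per-iteration rates (50 and 33.3).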
// printTableResults displays averaged metrics in a formatted table
func printTableResults(averaged []BenchmarkMetrics) {
	w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
	fmt.Fprintln(w, "\nAVERAGED BENCHMARK RESULTS")
	fmt.Fprintln(w, "Model\tScenario\tTest Name\tTTFT (ms)\tTotal Time (ms)\tTokens\tTokens/sec")
	for _, m := range averaged {
		fmt.Fprintf(w, "%s\t%s\t%s\t%.2f\t%.2f\t%d\t%.2f\n",
			m.model,
			m.scenario,
			m.testName,
			float64(m.ttft.Milliseconds()),
			float64(m.totalTime.Milliseconds()),
			m.totalTokens,
			m.tokensPerSecond,
		)
	}
	w.Flush()
}
// printCSVResults outputs averaged metrics in CSV format
func printCSVResults(averaged []BenchmarkMetrics) {
	fmt.Println("\nCSV OUTPUT")
	fmt.Println("model,scenario,test_name,ttft_ms,total_ms,tokens,tokens_per_sec")
	for _, m := range averaged {
		fmt.Printf("%s,%s,%s,%.2f,%.2f,%d,%.2f\n",
			m.model,
			m.scenario,
			m.testName,
			float64(m.ttft.Milliseconds()),
			float64(m.totalTime.Milliseconds()),
			m.totalTokens,
			m.tokensPerSecond,
		)
	}
}
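// A sample CSV row in the format printed above (values are illustrative only):
//
//	llama3.2:1b,warm_start,short_prompt,210.00,3150.00,100,31.75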
// mustParse is a helper that parses a URL and panics on error
func mustParse(rawURL string) *url.URL {
	u, err := url.Parse(rawURL)
	if err != nil {
		panic(err)
	}
	return u
}