new_runner_benchmark_test.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
  1. package benchmark
  2. import (
  3. "bytes"
  4. "context"
  5. "encoding/json"
  6. "fmt"
  7. "io"
  8. "net/http"
  9. "testing"
  10. "time"
  11. )
// Benchmark configuration for the locally running completion server.
const (
	runnerURL     = "http://localhost:8080" // base URL of the runner under test
	warmupPrompts = 2                       // Number of warm-up requests per test case
	warmupTokens  = 50                      // Smaller token count for warm-up requests
)
// runnerMetrics accumulates per-iteration results across all sub-benchmarks.
// NOTE(review): the benchmark loop appends to `metrics`, which is not declared
// in this file — it very likely should append to this variable instead; verify
// against the rest of the package.
var runnerMetrics []BenchmarkMetrics
// CompletionRequest represents the request body for the completion endpoint
type CompletionRequest struct {
	Prompt      string  `json:"prompt"`      // fully formatted prompt text sent to the model
	NumPredict  int     `json:"n_predict"`   // maximum number of tokens to generate
	Temperature float32 `json:"temperature"` // sampling temperature
}
// CompletionResponse represents a single response chunk from the streaming API
type CompletionResponse struct {
	Content string `json:"content"` // generated text in this chunk (may be empty)
	Stop    bool   `json:"stop"`    // true on the final chunk of the stream
	// Timings carries server-side measurements, populated on the final chunk
	// (presumably — confirm against the runner's API).
	Timings struct {
		PredictedN  int `json:"predicted_n"`  // number of tokens generated
		PredictedMs int `json:"predicted_ms"` // generation time in milliseconds
		PromptN     int `json:"prompt_n"`     // number of prompt tokens processed
		PromptMs    int `json:"prompt_ms"`    // prompt processing time in milliseconds
	} `json:"timings"`
}
  35. // warmUp performs warm-up requests before the actual benchmark
  36. func warmUp(b *testing.B, tt TestCase) {
  37. b.Logf("Warming up for test case %s", tt.name)
  38. warmupTest := TestCase{
  39. name: tt.name + "_warmup",
  40. prompt: tt.prompt,
  41. maxTokens: warmupTokens,
  42. }
  43. for i := 0; i < warmupPrompts; i++ {
  44. runCompletion(context.Background(), warmupTest, b)
  45. time.Sleep(100 * time.Millisecond) // Brief pause between warm-up requests
  46. }
  47. b.Logf("Warm-up complete")
  48. }
  49. func BenchmarkRunnerInference(b *testing.B) {
  50. b.Logf("Starting benchmark suite")
  51. // Verify server availability
  52. if _, err := http.Get(runnerURL + "/health"); err != nil {
  53. b.Fatalf("Runner unavailable: %v", err)
  54. }
  55. b.Log("Runner available")
  56. tests := []TestCase{
  57. {
  58. name: "short_prompt",
  59. prompt: formatPrompt("Write a long story"),
  60. maxTokens: 100,
  61. },
  62. {
  63. name: "medium_prompt",
  64. prompt: formatPrompt("Write a detailed economic analysis"),
  65. maxTokens: 500,
  66. },
  67. {
  68. name: "long_prompt",
  69. prompt: formatPrompt("Write a comprehensive AI research paper"),
  70. maxTokens: 1000,
  71. },
  72. }
  73. // Register cleanup handler for results reporting
  74. b.Cleanup(func() { reportMetrics(metrics) })
  75. // Main benchmark loop
  76. for _, tt := range tests {
  77. b.Run(tt.name, func(b *testing.B) {
  78. // Perform warm-up requests
  79. warmUp(b, tt)
  80. // Wait a bit after warm-up before starting the actual benchmark
  81. time.Sleep(500 * time.Millisecond)
  82. m := make([]BenchmarkMetrics, b.N)
  83. for i := 0; i < b.N; i++ {
  84. b.ResetTimer()
  85. m[i] = runCompletion(context.Background(), tt, b)
  86. }
  87. metrics = append(metrics, m...)
  88. })
  89. }
  90. }
  91. func formatPrompt(text string) string {
  92. return fmt.Sprintf("<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n", text)
  93. }
  94. func runCompletion(ctx context.Context, tt TestCase, b *testing.B) BenchmarkMetrics {
  95. start := time.Now()
  96. var ttft time.Duration
  97. var tokens int
  98. lastToken := start
  99. // Create request body
  100. reqBody := CompletionRequest{
  101. Prompt: tt.prompt,
  102. NumPredict: tt.maxTokens,
  103. Temperature: 0.1,
  104. }
  105. jsonData, err := json.Marshal(reqBody)
  106. if err != nil {
  107. b.Fatalf("Failed to marshal request: %v", err)
  108. }
  109. // Create HTTP request
  110. req, err := http.NewRequestWithContext(ctx, "POST", runnerURL+"/completion", bytes.NewBuffer(jsonData))
  111. if err != nil {
  112. b.Fatalf("Failed to create request: %v", err)
  113. }
  114. req.Header.Set("Content-Type", "application/json")
  115. // Execute request
  116. resp, err := http.DefaultClient.Do(req)
  117. if err != nil {
  118. b.Fatalf("Request failed: %v", err)
  119. }
  120. defer resp.Body.Close()
  121. // Process streaming response
  122. decoder := json.NewDecoder(resp.Body)
  123. for {
  124. var chunk CompletionResponse
  125. if err := decoder.Decode(&chunk); err != nil {
  126. if err == io.EOF {
  127. break
  128. }
  129. b.Fatalf("Failed to decode response: %v", err)
  130. }
  131. if ttft == 0 && chunk.Content != "" {
  132. ttft = time.Since(start)
  133. }
  134. if chunk.Content != "" {
  135. tokens++
  136. lastToken = time.Now()
  137. }
  138. if chunk.Stop {
  139. break
  140. }
  141. }
  142. totalTime := lastToken.Sub(start)
  143. return BenchmarkMetrics{
  144. testName: tt.name,
  145. ttft: ttft,
  146. totalTime: totalTime,
  147. totalTokens: tokens,
  148. tokensPerSecond: float64(tokens) / totalTime.Seconds(),
  149. }
  150. }