- package benchmark
- import (
- "bytes"
- "context"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "testing"
- "time"
- )
const (
	// runnerURL is the base URL of the local runner HTTP server under benchmark.
	runnerURL = "http://localhost:8080"

	warmupPrompts = 2  // Number of warm-up requests per test case
	warmupTokens  = 50 // Smaller token count for warm-up requests
)

// runnerMetrics accumulates one BenchmarkMetrics entry per completed request
// across all benchmark sub-tests.
// NOTE(review): the benchmark body below appends to a name `metrics`, not
// `runnerMetrics` — confirm which identifier is intended.
var runnerMetrics []BenchmarkMetrics
// CompletionRequest represents the request body for the completion endpoint.
type CompletionRequest struct {
	// Prompt is the full, already-templated prompt text sent to the model.
	Prompt string `json:"prompt"`
	// NumPredict caps the number of tokens the runner may generate.
	NumPredict int `json:"n_predict"`
	// Temperature controls sampling randomness; the benchmark uses a low
	// value (0.1) for near-deterministic output.
	Temperature float32 `json:"temperature"`
}
// CompletionResponse represents a single response chunk from the streaming API.
type CompletionResponse struct {
	// Content is the text generated in this chunk; empty chunks carry no tokens.
	Content string `json:"content"`
	// Stop is true on the final chunk of the stream.
	Stop bool `json:"stop"`
	// Timings reports the server-side token counts and durations
	// (milliseconds) for the prompt and prediction phases.
	Timings struct {
		PredictedN  int `json:"predicted_n"`
		PredictedMs int `json:"predicted_ms"`
		PromptN     int `json:"prompt_n"`
		PromptMs    int `json:"prompt_ms"`
	} `json:"timings"`
}
- // warmUp performs warm-up requests before the actual benchmark
- func warmUp(b *testing.B, tt TestCase) {
- b.Logf("Warming up for test case %s", tt.name)
- warmupTest := TestCase{
- name: tt.name + "_warmup",
- prompt: tt.prompt,
- maxTokens: warmupTokens,
- }
- for i := 0; i < warmupPrompts; i++ {
- runCompletion(context.Background(), warmupTest, b)
- time.Sleep(100 * time.Millisecond) // Brief pause between warm-up requests
- }
- b.Logf("Warm-up complete")
- }
- func BenchmarkRunnerInference(b *testing.B) {
- b.Logf("Starting benchmark suite")
- // Verify server availability
- if _, err := http.Get(runnerURL + "/health"); err != nil {
- b.Fatalf("Runner unavailable: %v", err)
- }
- b.Log("Runner available")
- tests := []TestCase{
- {
- name: "short_prompt",
- prompt: formatPrompt("Write a long story"),
- maxTokens: 100,
- },
- {
- name: "medium_prompt",
- prompt: formatPrompt("Write a detailed economic analysis"),
- maxTokens: 500,
- },
- {
- name: "long_prompt",
- prompt: formatPrompt("Write a comprehensive AI research paper"),
- maxTokens: 1000,
- },
- }
- // Register cleanup handler for results reporting
- b.Cleanup(func() { reportMetrics(metrics) })
- // Main benchmark loop
- for _, tt := range tests {
- b.Run(tt.name, func(b *testing.B) {
- // Perform warm-up requests
- warmUp(b, tt)
- // Wait a bit after warm-up before starting the actual benchmark
- time.Sleep(500 * time.Millisecond)
- m := make([]BenchmarkMetrics, b.N)
- for i := 0; i < b.N; i++ {
- b.ResetTimer()
- m[i] = runCompletion(context.Background(), tt, b)
- }
- metrics = append(metrics, m...)
- })
- }
- }
// formatPrompt wraps the given user text in the Llama 3 chat template: a
// fixed system preamble, the user turn, and an open assistant turn for the
// model to complete.
func formatPrompt(text string) string {
	const chatTemplate = "<|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\n\n<|eot_id|><|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
	return fmt.Sprintf(chatTemplate, text)
}
- func runCompletion(ctx context.Context, tt TestCase, b *testing.B) BenchmarkMetrics {
- start := time.Now()
- var ttft time.Duration
- var tokens int
- lastToken := start
- // Create request body
- reqBody := CompletionRequest{
- Prompt: tt.prompt,
- NumPredict: tt.maxTokens,
- Temperature: 0.1,
- }
- jsonData, err := json.Marshal(reqBody)
- if err != nil {
- b.Fatalf("Failed to marshal request: %v", err)
- }
- // Create HTTP request
- req, err := http.NewRequestWithContext(ctx, "POST", runnerURL+"/completion", bytes.NewBuffer(jsonData))
- if err != nil {
- b.Fatalf("Failed to create request: %v", err)
- }
- req.Header.Set("Content-Type", "application/json")
- // Execute request
- resp, err := http.DefaultClient.Do(req)
- if err != nil {
- b.Fatalf("Request failed: %v", err)
- }
- defer resp.Body.Close()
- // Process streaming response
- decoder := json.NewDecoder(resp.Body)
- for {
- var chunk CompletionResponse
- if err := decoder.Decode(&chunk); err != nil {
- if err == io.EOF {
- break
- }
- b.Fatalf("Failed to decode response: %v", err)
- }
- if ttft == 0 && chunk.Content != "" {
- ttft = time.Since(start)
- }
- if chunk.Content != "" {
- tokens++
- lastToken = time.Now()
- }
- if chunk.Stop {
- break
- }
- }
- totalTime := lastToken.Sub(start)
- return BenchmarkMetrics{
- testName: tt.name,
- ttft: ttft,
- totalTime: totalTime,
- totalTokens: tokens,
- tokensPerSecond: float64(tokens) / totalTime.Seconds(),
- }
- }