@@ -0,0 +1,183 @@
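+// Package benchmark measures Ollama text generation performance end to end,
+// reporting time to first token, model load time, and token throughput.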
+package benchmark
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+// Command line flags
+var modelFlag string
+
+func init() {
+	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
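+	// Set the default value displayed in -help output (the flag itself defaults to empty).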
+	flag.Lookup("m").DefValue = "model"
+}
+
+// modelName returns the model name from flags, failing the test if not set
+func modelName(b *testing.B) string {
+	if modelFlag == "" {
+		b.Fatal("Error: -m flag is required for benchmark tests")
+	}
+	return modelFlag
+}
+
+type TestCase struct {
+	name      string
+	prompt    string
+	maxTokens int
+}
+
+// runGenerateBenchmark contains the common generate and metrics logic
+func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
+	start := time.Now()
+	var ttft time.Duration
+	var metrics api.Metrics
+
+	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
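+		// Capture the time to first token when the first non-empty chunk arrives.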
+		if ttft == 0 && resp.Response != "" {
+			ttft = time.Since(start)
+		}
+		if resp.Done {
+			metrics = resp.Metrics
+		}
+		return nil
+	})
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	// Report custom metrics as part of the benchmark results
+	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
+	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
+
+	// Token throughput metrics
+	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
+	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
+	b.ReportMetric(promptThroughput, "prompt_tok/s")
+	b.ReportMetric(genThroughput, "gen_tok/s")
+
+	// Token counts
+	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
+	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
+}
+
+// BenchmarkColdStart runs benchmarks with model loading from cold state
+func BenchmarkColdStart(b *testing.B) {
+	client := setup(b)
+	tests := []TestCase{
+		{"short_prompt", "Write a long story", 100},
+		{"medium_prompt", "Write a detailed economic analysis", 500},
+		{"long_prompt", "Write a comprehensive AI research paper", 1000},
+	}
+	m := modelName(b)
+
+	for _, tt := range tests {
+		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
+			ctx := context.Background()
+
+			// Set number of tokens as our throughput metric
+			b.SetBytes(int64(tt.maxTokens))
+
+			for b.Loop() {
+				b.StopTimer()
+				// Ensure model is unloaded before each iteration
+				unload(client, m, b)
+				b.StartTimer()
+
+				req := &api.GenerateRequest{
+					Model:   m,
+					Prompt:  tt.prompt,
+					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
+				}
+
+				runGenerateBenchmark(b, ctx, client, req)
+			}
+		})
+	}
+}
+
+// BenchmarkWarmStart runs benchmarks with pre-loaded model
+func BenchmarkWarmStart(b *testing.B) {
+	client := setup(b)
+	tests := []TestCase{
+		{"short_prompt", "Write a long story", 100},
+		{"medium_prompt", "Write a detailed economic analysis", 500},
+		{"long_prompt", "Write a comprehensive AI research paper", 1000},
+	}
+	m := modelName(b)
+
+	for _, tt := range tests {
+		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
+			ctx := context.Background()
+
+			// Pre-warm the model
+			warmup(client, m, tt.prompt, b)
+
+			// Set number of tokens as our throughput metric
+			b.SetBytes(int64(tt.maxTokens))
+
+			for b.Loop() {
+				req := &api.GenerateRequest{
+					Model:   m,
+					Prompt:  tt.prompt,
+					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
+				}
+
+				runGenerateBenchmark(b, ctx, client, req)
+			}
+		})
+	}
+}
+
+// setup verifies server and model availability
+func setup(b *testing.B) *api.Client {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		b.Fatal(err)
+	}
+	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
+		b.Fatalf("Model unavailable: %v", err)
+	}
+
+	return client
+}
+
+// warmup ensures the model is loaded and warmed up
+func warmup(client *api.Client, model string, prompt string, b *testing.B) {
+	for range 3 {
+		err := client.Generate(
+			context.Background(),
+			&api.GenerateRequest{
+				Model:   model,
+				Prompt:  prompt,
+				Options: map[string]any{"num_predict": 50, "temperature": 0.1},
+			},
+			func(api.GenerateResponse) error { return nil },
+		)
+		if err != nil {
+			b.Logf("Error during model warm-up: %v", err)
+		}
+	}
+}
+
+// unload forces model unloading using KeepAlive: 0 parameter
+func unload(client *api.Client, model string, b *testing.B) {
+	req := &api.GenerateRequest{
+		Model:     model,
+		KeepAlive: &api.Duration{Duration: 0},
+	}
+	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
+		b.Logf("Unload error: %v", err)
+	}
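+	// Brief pause so the unload can settle before the benchmark continues.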
+	time.Sleep(1 * time.Second)
+}