Browse Source

benchmark: performance of running ollama server (#8643)

Bruce MacDonald 1 month ago
parent
commit
fb6252d786
2 changed files with 237 additions and 0 deletions
  1. 178 0
      benchmark/server_benchmark_test.go
  2. 59 0
      docs/benchmark.md

+ 178 - 0
benchmark/server_benchmark_test.go

@@ -0,0 +1,178 @@
+package benchmark
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/api"
+)
+
+// Command line flags
+var modelFlag string
+
+func init() {
+	flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
+	flag.Lookup("m").DefValue = "model"
+}
+
+// modelName returns the model name from flags, failing the test if not set
+func modelName(b *testing.B) string {
+	if modelFlag == "" {
+		b.Fatal("Error: -m flag is required for benchmark tests")
+	}
+	return modelFlag
+}
+
+type TestCase struct {
+	name      string
+	prompt    string
+	maxTokens int
+}
+
+// runGenerateBenchmark contains the common generate and metrics logic
+func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
+	start := time.Now()
+	var ttft time.Duration
+	var metrics api.Metrics
+
+	err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
+		if ttft == 0 && resp.Response != "" {
+			ttft = time.Since(start)
+		}
+		if resp.Done {
+			metrics = resp.Metrics
+		}
+		return nil
+	})
+
+	// Report custom metrics as part of the benchmark results
+	b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
+	b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
+
+	// Token throughput metrics
+	promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
+	genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
+	b.ReportMetric(promptThroughput, "prompt_tok/s")
+	b.ReportMetric(genThroughput, "gen_tok/s")
+
+	// Token counts
+	b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
+	b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
+	if err != nil {
+		b.Fatal(err)
+	}
+}
+
+// BenchmarkColdStart runs benchmarks with model loading from cold state
+func BenchmarkColdStart(b *testing.B) {
+	client := setup(b)
+	tests := []TestCase{
+		{"short_prompt", "Write a long story", 100},
+		{"medium_prompt", "Write a detailed economic analysis", 500},
+		{"long_prompt", "Write a comprehensive AI research paper", 1000},
+	}
+	m := modelName(b)
+
+	for _, tt := range tests {
+		b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
+			ctx := context.Background()
+
+			// Set number of tokens as our throughput metric
+			b.SetBytes(int64(tt.maxTokens))
+
+			for b.Loop() {
+				b.StopTimer()
+				// Ensure model is unloaded before each iteration
+				unload(client, m, b)
+				b.StartTimer()
+
+				req := &api.GenerateRequest{
+					Model:   m,
+					Prompt:  tt.prompt,
+					Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
+				}
+
+				runGenerateBenchmark(b, ctx, client, req)
+			}
+		})
+	}
+}
+
+// BenchmarkWarmStart runs benchmarks with pre-loaded model
+func BenchmarkWarmStart(b *testing.B) {
+	client := setup(b)
+	tests := []TestCase{
+		{"short_prompt", "Write a long story", 100},
+		{"medium_prompt", "Write a detailed economic analysis", 500},
+		{"long_prompt", "Write a comprehensive AI research paper", 1000},
+	}
+	m := modelName(b)
+
+	for _, tt := range tests {
+		b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
+			ctx := context.Background()
+
+			// Pre-warm the model
+			warmup(client, m, tt.prompt, b)
+
+			// Set number of tokens as our throughput metric
+			b.SetBytes(int64(tt.maxTokens))
+
+			for b.Loop() {
+				req := &api.GenerateRequest{
+					Model:   m,
+					Prompt:  tt.prompt,
+					Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
+				}
+
+				runGenerateBenchmark(b, ctx, client, req)
+			}
+		})
+	}
+}
+
+// setup verifies server and model availability
+func setup(b *testing.B) *api.Client {
+	client, err := api.ClientFromEnvironment()
+	if err != nil {
+		b.Fatal(err)
+	}
+	if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
+		b.Fatalf("Model unavailable: %v", err)
+	}
+
+	return client
+}
+
+// warmup ensures the model is loaded and warmed up
+func warmup(client *api.Client, model string, prompt string, b *testing.B) {
+	for range 3 {
+		err := client.Generate(
+			context.Background(),
+			&api.GenerateRequest{
+				Model:   model,
+				Prompt:  prompt,
+				Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
+			},
+			func(api.GenerateResponse) error { return nil },
+		)
+		if err != nil {
+			b.Logf("Error during model warm-up: %v", err)
+		}
+	}
+}
+
+// unload forces model unloading using KeepAlive: 0 parameter
+func unload(client *api.Client, model string, b *testing.B) {
+	req := &api.GenerateRequest{
+		Model:     model,
+		KeepAlive: &api.Duration{Duration: 0},
+	}
+	if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
+		b.Logf("Unload error: %v", err)
+	}
+	time.Sleep(1 * time.Second)
+}

+ 59 - 0
docs/benchmark.md

@@ -0,0 +1,59 @@
+# Benchmark
+
+Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
+
+## When to use
+
+Run these benchmarks when:
+- Making changes to the model inference engine
+- Modifying model loading/unloading logic
+- Changing prompt processing or token generation code
+- Implementing a new model architecture
+- Testing performance across different hardware setups
+
+## Prerequisites
+- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
+## Usage and Examples
+
+>[!NOTE]
+>All commands must be run from the root directory of the Ollama project.
+
+Basic syntax:
+```bash
+go test -bench=. ./benchmark/... -m $MODEL_NAME
+```
+
+Required flags:
+- `-bench=.`: Run all benchmarks
+- `-m`: Model name to benchmark
+
+Optional flags:
+- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
+- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
+
+Common usage patterns:
+
+Single benchmark run with a model specified:
+```bash
+go test -bench=. ./benchmark/... -m llama3.3
+```
+
+## Output metrics
+
+The benchmark reports several key metrics:
+
+- `gen_tok/s`: Generated tokens per second
+- `prompt_tok/s`: Prompt processing tokens per second
+- `ttft_ms`: Time to first token in milliseconds
+- `load_ms`: Model load time in milliseconds
+- `gen_tokens`: Total tokens generated
+- `prompt_tokens`: Total prompt tokens processed
+
+Each benchmark runs two scenarios:
+- Cold start: Model is loaded from disk for each test
+- Warm start: Model is pre-loaded in memory
+
+Three prompt lengths are tested for each scenario:
+- Short prompt (100 tokens)
+- Medium prompt (500 tokens)
+- Long prompt (1000 tokens)