package benchmark

import (
    "context"
    "flag"
    "fmt"
    "testing"
    "time"

    "github.com/ollama/ollama/api"
)
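
// Usage note (package path assumed; adjust to where this file lives in your
// tree): run the benchmarks with the standard Go tooling and pass the model
// name via the -m flag defined below, for example:
//
//	go test -bench=. ./benchmark/... -m <model-name>
//
// The b.Loop iteration style used below requires Go 1.24 or newer.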

// Command line flags
var modelFlag string

func init() {
    flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
    flag.Lookup("m").DefValue = "model"
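    // Note: DefValue only changes the default shown in the flag's usage text;
    // the effective default remains the empty string registered above.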
}

// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
    if modelFlag == "" {
        b.Fatal("Error: -m flag is required for benchmark tests")
    }
    return modelFlag
}
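
// TestCase describes a single benchmark scenario: a short label for the
// sub-benchmark name, the prompt sent to the model, and the maximum number
// of tokens to generate (passed through as num_predict).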
type TestCase struct {
    name      string
    prompt    string
    maxTokens int
}

// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
    start := time.Now()
    var ttft time.Duration
    var metrics api.Metrics

    err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
        if ttft == 0 && resp.Response != "" {
            ttft = time.Since(start)
        }
        if resp.Done {
            metrics = resp.Metrics
        }
        return nil
    })

    // Report custom metrics as part of the benchmark results
    b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
    b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")

    // Token throughput metrics
    promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
    genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
    b.ReportMetric(promptThroughput, "prompt_tok/s")
    b.ReportMetric(genThroughput, "gen_tok/s")

    // Token counts
    b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
    b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
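
    // ttft_ms is measured client-side, from request start to the first
    // non-empty streamed chunk; the remaining values come from the
    // api.Metrics attached to the final (Done) response. All of them appear
    // as extra columns in the regular `go test -bench` output.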

    if err != nil {
        b.Fatal(err)
    }
}

// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
    client := setup(b)
    tests := []TestCase{
        {"short_prompt", "Write a long story", 100},
        {"medium_prompt", "Write a detailed economic analysis", 500},
        {"long_prompt", "Write a comprehensive AI research paper", 1000},
    }
    m := modelName(b)

    for _, tt := range tests {
        b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
            ctx := context.Background()

            // Set number of tokens as our throughput metric
            b.SetBytes(int64(tt.maxTokens))

            for b.Loop() {
                b.StopTimer()
                // Ensure model is unloaded before each iteration
                unload(client, m, b)
                b.StartTimer()

                req := &api.GenerateRequest{
                    Model:   m,
                    Prompt:  tt.prompt,
                    Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
                }

                runGenerateBenchmark(b, ctx, client, req)
            }
        })
    }
}

// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
    client := setup(b)
    tests := []TestCase{
        {"short_prompt", "Write a long story", 100},
        {"medium_prompt", "Write a detailed economic analysis", 500},
        {"long_prompt", "Write a comprehensive AI research paper", 1000},
    }
    m := modelName(b)

    for _, tt := range tests {
        b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
            ctx := context.Background()

            // Pre-warm the model
            warmup(client, m, tt.prompt, b)

            // Set number of tokens as our throughput metric
            b.SetBytes(int64(tt.maxTokens))

            for b.Loop() {
                req := &api.GenerateRequest{
                    Model:   m,
                    Prompt:  tt.prompt,
                    Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
                }

                runGenerateBenchmark(b, ctx, client, req)
            }
        })
    }
}

// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
    client, err := api.ClientFromEnvironment()
    if err != nil {
        b.Fatal(err)
    }
    if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
        b.Fatalf("Model unavailable: %v", err)
    }
    return client
}

// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
    for range 3 {
        err := client.Generate(
            context.Background(),
            &api.GenerateRequest{
                Model:   model,
                Prompt:  prompt,
                Options: map[string]any{"num_predict": 50, "temperature": 0.1},
            },
            func(api.GenerateResponse) error { return nil },
        )
        if err != nil {
            b.Logf("Error during model warm-up: %v", err)
        }
    }
}

// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
    req := &api.GenerateRequest{
        Model:     model,
        KeepAlive: &api.Duration{Duration: 0},
    }
    if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
        b.Logf("Unload error: %v", err)
    }
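    // A KeepAlive of 0 asks the server to unload the model immediately; the
    // short sleep below gives it a moment to finish releasing memory before
    // the next timed iteration starts.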
    time.Sleep(1 * time.Second)
}