ggml_backend_benchmark_test.go 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. package backend
  2. import (
  3. "flag"
  4. "fmt"
  5. "io"
  6. "log"
  7. "os"
  8. "testing"
  9. "github.com/ollama/ollama/ml"
  10. "github.com/ollama/ollama/model"
  11. "github.com/ollama/ollama/server"
  12. _ "github.com/ollama/ollama/model/models/llama"
  13. )
  14. var modelName = flag.String("m", "", "Name of the model to benchmark")
  15. func suppressOutput() (cleanup func()) {
  16. oldStdout, oldStderr := os.Stdout, os.Stderr
  17. os.Stdout, os.Stderr = nil, nil
  18. log.SetOutput(io.Discard)
  19. return func() {
  20. os.Stdout, os.Stderr = oldStdout, oldStderr
  21. log.SetOutput(os.Stderr)
  22. }
  23. }
  24. func setupModel(b *testing.B) model.Model {
  25. if *modelName == "" {
  26. b.Fatal("Error: -m flag is required for benchmark tests")
  27. }
  28. sm, err := server.GetModel(*modelName)
  29. if err != nil {
  30. b.Fatal(err)
  31. }
  32. m, err := model.New(sm.ModelPath)
  33. if err != nil {
  34. b.Fatal(err)
  35. }
  36. m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
  37. return m
  38. }
  39. func BenchmarkGGMLOperations(b *testing.B) {
  40. // loading the GGML back-end logs to standard out and makes the bench output messy
  41. cleanup := suppressOutput()
  42. defer cleanup()
  43. b.Setenv("OLLAMA_BENCHMARK", "1")
  44. b.Setenv("OLLAMA_BACKEND", "ggml")
  45. m := setupModel(b)
  46. // Sample input data
  47. inputIDs := []int32{1, 2, 3, 4, 5}
  48. options := model.Options{
  49. Inputs: inputIDs,
  50. Positions: []int32{1, 2, 3, 4, 5},
  51. Sequences: []int{1, 1, 1, 1, 1},
  52. Outputs: []int32{int32(len(inputIDs) - 1)},
  53. }
  54. b.ResetTimer()
  55. for range b.N {
  56. ctx := m.Backend().NewContext()
  57. defer ctx.Close()
  58. modelOutput, err := model.Forward(ctx, m, options)
  59. if err != nil {
  60. b.Fatal(fmt.Errorf("forward pass failed: %v", err))
  61. }
  62. ctx.Compute(modelOutput)
  63. for _, op := range ctx.Timing() {
  64. b.ReportMetric(op.Duration, fmt.Sprintf("%s_ms", op.Type))
  65. }
  66. }
  67. }