main.go 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. package main
import (
	"errors"
	"flag"
	"fmt"
	"image"
	_ "image/jpeg" // register JPEG decoding for the -image flag
	_ "image/png"  // register PNG decoding for the -image flag
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/ollama/ollama/cache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	_ "github.com/ollama/ollama/model/llama"
	_ "github.com/ollama/ollama/model/mllama"
	"github.com/ollama/ollama/sample"
)
  20. var args struct {
  21. n int
  22. debug bool
  23. image string
  24. cache bool
  25. }
  26. func temp() error {
  27. start := time.Now()
  28. flag.IntVar(&args.n, "n", 10, "number of samples")
  29. flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
  30. flag.StringVar(&args.image, "image", "", "path to image file")
  31. flag.BoolVar(&args.cache, "cache", false, "enable KV cache")
  32. flag.Parse()
  33. var prompt string
  34. if n := len(flag.Args()); n == 1 {
  35. bts, err := io.ReadAll(os.Stdin)
  36. if err != nil {
  37. return err
  38. }
  39. prompt = string(bts)
  40. } else if n > 1 {
  41. prompt = strings.Join(flag.Args()[1:], " ")
  42. } else {
  43. return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
  44. }
  45. level := slog.LevelInfo
  46. if args.debug {
  47. level = slog.LevelDebug
  48. }
  49. slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  50. Level: level,
  51. AddSource: true,
  52. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  53. if attr.Key == slog.SourceKey {
  54. source := attr.Value.Any().(*slog.Source)
  55. source.File = filepath.Base(source.File)
  56. }
  57. return attr
  58. },
  59. })))
  60. m, err := model.New(flag.Arg(0))
  61. if err != nil {
  62. return err
  63. }
  64. inputIDs, err := m.(model.TextProcessor).Encode(prompt)
  65. if err != nil {
  66. return err
  67. }
  68. var opts []model.OptionsFunc
  69. if args.cache {
  70. opts = append(opts, model.WithCache(&cache.Simple{
  71. Capacity: 2048,
  72. DType: ml.DTypeF32,
  73. }))
  74. }
  75. if args.image != "" {
  76. if err := func() error {
  77. f, err := os.Open(args.image)
  78. if err != nil {
  79. return err
  80. }
  81. defer f.Close()
  82. img, _, err := image.Decode(f)
  83. if err != nil {
  84. return err
  85. }
  86. opts = append(opts, model.WithImage(img))
  87. return nil
  88. }(); err != nil {
  89. return err
  90. }
  91. }
  92. pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
  93. var offset int
  94. var stringBuffer string
  95. var firstTokenTime time.Duration
  96. for range args.n {
  97. logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
  98. if err != nil {
  99. return err
  100. }
  101. f32s := logit.Floats()
  102. f64s := make([]float64, len(f32s))
  103. for i, f32 := range f32s {
  104. f64s[i] = float64(f32)
  105. }
  106. sampleTime := time.Now()
  107. samplers := []sample.Sampler{
  108. pushdownSampler,
  109. // sample.Weighed(),
  110. // sample.TopP(0.9),
  111. // sample.Weighed(),
  112. sample.Greedy(),
  113. }
  114. f64s, err = sample.Sample(f64s, samplers...)
  115. if err != nil {
  116. return err
  117. }
  118. finishTime := time.Now()
  119. fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
  120. var outputIDs []int32
  121. for _, f64 := range f64s {
  122. if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
  123. outputIDs = append(outputIDs, int32(f64))
  124. }
  125. }
  126. if len(outputIDs) == 0 {
  127. break
  128. }
  129. s, err := m.(model.TextProcessor).Decode(outputIDs)
  130. if errors.Is(err, io.EOF) {
  131. break
  132. } else if err != nil {
  133. return err
  134. }
  135. if firstTokenTime == 0 {
  136. firstTokenTime = time.Since(start)
  137. fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
  138. }
  139. // fmt.Printf("--- token: %q\n", s)
  140. // fmt.Printf("--- outputIDs: %v\n", outputIDs)
  141. stringBuffer += s
  142. fmt.Println("--- stringBuffer", stringBuffer)
  143. err = pushdownSampler.UpdateState(outputIDs)
  144. if err != nil {
  145. return err
  146. }
  147. inputIDs = append(inputIDs, outputIDs...)
  148. if args.cache {
  149. offset = len(inputIDs) - 1
  150. }
  151. }
  152. fmt.Println("\n------ Output: ------")
  153. fmt.Println(stringBuffer)
  154. fmt.Println("--------------------")
  155. return nil
  156. }
  157. func main() {
  158. if err := temp(); err != nil {
  159. fmt.Println("err", err)
  160. os.Exit(1)
  161. }
  162. }