main.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167
  1. package main
  2. import (
  3. "errors"
  4. "flag"
  5. "fmt"
  6. "image"
  7. "io"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "github.com/ollama/ollama/cache"
  13. "github.com/ollama/ollama/ml"
  14. "github.com/ollama/ollama/model"
  15. _ "github.com/ollama/ollama/model/llama"
  16. _ "github.com/ollama/ollama/model/mllama"
  17. "github.com/ollama/ollama/sample"
  18. )
  19. var args struct {
  20. n int
  21. debug bool
  22. image string
  23. cache bool
  24. }
  25. func temp() error {
  26. flag.IntVar(&args.n, "n", 10, "number of samples")
  27. flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
  28. flag.StringVar(&args.image, "image", "", "path to image file")
  29. flag.BoolVar(&args.cache, "cache", false, "enable KV cache")
  30. flag.Parse()
  31. var prompt string
  32. if n := len(flag.Args()); n == 1 {
  33. bts, err := io.ReadAll(os.Stdin)
  34. if err != nil {
  35. return err
  36. }
  37. prompt = string(bts)
  38. } else if n > 1 {
  39. prompt = strings.Join(flag.Args()[1:], " ")
  40. } else {
  41. return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
  42. }
  43. level := slog.LevelInfo
  44. if args.debug {
  45. level = slog.LevelDebug
  46. }
  47. slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  48. Level: level,
  49. AddSource: true,
  50. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  51. if attr.Key == slog.SourceKey {
  52. source := attr.Value.Any().(*slog.Source)
  53. source.File = filepath.Base(source.File)
  54. }
  55. return attr
  56. },
  57. })))
  58. m, err := model.New(flag.Arg(0))
  59. if err != nil {
  60. return err
  61. }
  62. inputIDs, err := m.(model.TextProcessor).Encode(prompt)
  63. if err != nil {
  64. return err
  65. }
  66. var opts []model.OptionsFunc
  67. if args.cache {
  68. opts = append(opts, model.WithCache(&cache.Simple{
  69. Capacity: 2048,
  70. DType: ml.DTypeF32,
  71. }))
  72. }
  73. if args.image != "" {
  74. if err := func() error {
  75. f, err := os.Open(args.image)
  76. if err != nil {
  77. return err
  78. }
  79. defer f.Close()
  80. img, _, err := image.Decode(f)
  81. if err != nil {
  82. return err
  83. }
  84. opts = append(opts, model.WithImage(img))
  85. return nil
  86. }(); err != nil {
  87. return err
  88. }
  89. }
  90. pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
  91. var stringBuffer string
  92. var offset int
  93. for range args.n {
  94. logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
  95. if err != nil {
  96. return err
  97. }
  98. f32s := logit.Floats()
  99. f64s := make([]float64, len(f32s))
  100. for i, f32 := range f32s {
  101. f64s[i] = float64(f32)
  102. }
  103. // do sampling
  104. // []ints back
  105. // ints map to sampled logits
  106. f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
  107. if err != nil {
  108. return err
  109. }
  110. var outputIDs []int32
  111. for _, f64 := range f64s {
  112. if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
  113. outputIDs = append(outputIDs, int32(f64))
  114. }
  115. }
  116. pdaSampler.UpdateState(outputIDs)
  117. if len(outputIDs) == 0 {
  118. break
  119. }
  120. s, err := m.(model.TextProcessor).Decode(outputIDs)
  121. if errors.Is(err, io.EOF) {
  122. break
  123. } else if err != nil {
  124. return err
  125. }
  126. // fmt.Print(s)
  127. stringBuffer += s
  128. fmt.Println("--- stringBuffer", stringBuffer)
  129. inputIDs = append(inputIDs, outputIDs...)
  130. if args.cache {
  131. offset = len(inputIDs) - 1
  132. }
  133. }
  134. return nil
  135. }
  136. func main() {
  137. if err := temp(); err != nil {
  138. fmt.Println("err", err)
  139. os.Exit(1)
  140. }
  141. }