main.go 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. package main
  2. import (
  3. "errors"
  4. "flag"
  5. "fmt"
  6. "image"
  7. "io"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "github.com/ollama/ollama/cache"
  13. "github.com/ollama/ollama/ml"
  14. "github.com/ollama/ollama/model"
  15. _ "github.com/ollama/ollama/model/llama"
  16. _ "github.com/ollama/ollama/model/mllama"
  17. "github.com/ollama/ollama/sample"
  18. )
  19. var args struct {
  20. n int
  21. debug bool
  22. image string
  23. cache bool
  24. }
  25. func temp() error {
  26. flag.IntVar(&args.n, "n", 10, "number of samples")
  27. flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
  28. flag.StringVar(&args.image, "image", "", "path to image file")
  29. flag.BoolVar(&args.cache, "cache", false, "enable KV cache")
  30. flag.Parse()
  31. var prompt string
  32. if n := len(flag.Args()); n == 1 {
  33. bts, err := io.ReadAll(os.Stdin)
  34. if err != nil {
  35. return err
  36. }
  37. prompt = string(bts)
  38. } else if n > 1 {
  39. prompt = strings.Join(flag.Args()[1:], " ")
  40. } else {
  41. return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
  42. }
  43. level := slog.LevelInfo
  44. if args.debug {
  45. level = slog.LevelDebug
  46. }
  47. slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  48. Level: level,
  49. AddSource: true,
  50. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  51. if attr.Key == slog.SourceKey {
  52. source := attr.Value.Any().(*slog.Source)
  53. source.File = filepath.Base(source.File)
  54. }
  55. return attr
  56. },
  57. })))
  58. m, err := model.New(flag.Arg(0))
  59. if err != nil {
  60. return err
  61. }
  62. inputIDs, err := m.(model.TextProcessor).Encode(prompt)
  63. if err != nil {
  64. return err
  65. }
  66. var opts []model.OptionsFunc
  67. if args.cache {
  68. opts = append(opts, model.WithCache(&cache.Simple{
  69. Capacity: 2048,
  70. DType: ml.DTypeF32,
  71. }))
  72. }
  73. if args.image != "" {
  74. if err := func() error {
  75. f, err := os.Open(args.image)
  76. if err != nil {
  77. return err
  78. }
  79. defer f.Close()
  80. img, _, err := image.Decode(f)
  81. if err != nil {
  82. return err
  83. }
  84. opts = append(opts, model.WithImage(img))
  85. return nil
  86. }(); err != nil {
  87. return err
  88. }
  89. }
  90. var offset int
  91. for range args.n {
  92. logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
  93. if err != nil {
  94. return err
  95. }
  96. f32s := logit.Floats()
  97. f64s := make([]float64, len(f32s))
  98. for i, f32 := range f32s {
  99. f64s[i] = float64(f32)
  100. }
  101. // do sampling
  102. f64s, err = sample.Sample(f64s, sample.Greedy())
  103. if err != nil {
  104. return err
  105. }
  106. var outputIDs []int32
  107. for _, f64 := range f64s {
  108. if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
  109. outputIDs = append(outputIDs, int32(f64))
  110. }
  111. }
  112. if len(outputIDs) == 0 {
  113. break
  114. }
  115. s, err := m.(model.TextProcessor).Decode(outputIDs)
  116. if errors.Is(err, io.EOF) {
  117. break
  118. } else if err != nil {
  119. return err
  120. }
  121. fmt.Print(s)
  122. inputIDs = append(inputIDs, outputIDs...)
  123. if args.cache {
  124. offset = len(inputIDs) - 1
  125. }
  126. }
  127. return nil
  128. }
  129. func main() {
  130. if err := temp(); err != nil {
  131. fmt.Println("err", err)
  132. os.Exit(1)
  133. }
  134. }