main.go 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. package main
  2. import (
  3. "errors"
  4. "flag"
  5. "fmt"
  6. "image"
  7. "io"
  8. "log/slog"
  9. "os"
  10. "path/filepath"
  11. "strings"
  12. "time"
  13. "github.com/ollama/ollama/cache"
  14. "github.com/ollama/ollama/ml"
  15. "github.com/ollama/ollama/model"
  16. _ "github.com/ollama/ollama/model/llama"
  17. _ "github.com/ollama/ollama/model/mllama"
  18. "github.com/ollama/ollama/sample"
  19. )
  20. var args struct {
  21. n int
  22. debug bool
  23. image string
  24. cache bool
  25. }
  26. func temp() error {
  27. // start := time.Now()
  28. flag.IntVar(&args.n, "n", 10, "number of samples")
  29. flag.BoolVar(&args.debug, "debug", false, "enable debug logging")
  30. flag.StringVar(&args.image, "image", "", "path to image file")
  31. flag.BoolVar(&args.cache, "cache", false, "enable KV cache")
  32. flag.Parse()
  33. var prompt string
  34. if n := len(flag.Args()); n == 1 {
  35. bts, err := io.ReadAll(os.Stdin)
  36. if err != nil {
  37. return err
  38. }
  39. prompt = string(bts)
  40. } else if n > 1 {
  41. prompt = strings.Join(flag.Args()[1:], " ")
  42. } else {
  43. return fmt.Errorf("usage: %s path/to/file <prompt\n", filepath.Base(os.Args[0]))
  44. }
  45. level := slog.LevelInfo
  46. if args.debug {
  47. level = slog.LevelDebug
  48. }
  49. slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
  50. Level: level,
  51. AddSource: true,
  52. ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
  53. if attr.Key == slog.SourceKey {
  54. source := attr.Value.Any().(*slog.Source)
  55. source.File = filepath.Base(source.File)
  56. }
  57. return attr
  58. },
  59. })))
  60. m, err := model.New(flag.Arg(0))
  61. if err != nil {
  62. return err
  63. }
  64. inputIDs, err := m.(model.TextProcessor).Encode(prompt)
  65. if err != nil {
  66. return err
  67. }
  68. var opts []model.OptionsFunc
  69. if args.cache {
  70. opts = append(opts, model.WithCache(&cache.Simple{
  71. Capacity: 2048,
  72. DType: ml.DTypeF32,
  73. }))
  74. }
  75. if args.image != "" {
  76. if err := func() error {
  77. f, err := os.Open(args.image)
  78. if err != nil {
  79. return err
  80. }
  81. defer f.Close()
  82. img, _, err := image.Decode(f)
  83. if err != nil {
  84. return err
  85. }
  86. opts = append(opts, model.WithImage(img))
  87. return nil
  88. }(); err != nil {
  89. return err
  90. }
  91. }
  92. // Schema for a list of friends with their info
  93. // Maps to JSON like:
  94. // {
  95. // "name": "string",
  96. // "age": integer,
  97. // "is_available": boolean
  98. // }
  99. schema := &sample.Schema{
  100. Name: "root",
  101. Type: "object",
  102. Properties: []*sample.Schema{
  103. {Name: "name", Type: "string"},
  104. {Name: "age", Type: "integer"},
  105. {Name: "is_available", Type: "boolean"},
  106. },
  107. }
  108. // fmt.Println("schema", schema)
  109. // schema = nil
  110. jsonTransform, err := sample.NewJSONSampler(m.(model.TextProcessor), schema)
  111. if err != nil {
  112. return err
  113. }
  114. transforms := []sample.Transform{
  115. jsonTransform,
  116. }
  117. var offset int
  118. var stringBuffer string
  119. // var ttft time.Duration
  120. var totalSamplingTime time.Duration
  121. count := 0
  122. for range args.n {
  123. logits, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
  124. if err != nil {
  125. return err
  126. }
  127. samplingStart := time.Now()
  128. sampler := sample.Greedy()
  129. sampledIdx, err := sampler.Sample(logits.Floats(), transforms...)
  130. if err != nil {
  131. return err
  132. }
  133. samplingTime := time.Since(samplingStart)
  134. totalSamplingTime += samplingTime
  135. // fmt.Println("sampling time", samplingTime)
  136. // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
  137. var outputIDs []int32
  138. if !m.(model.TextProcessor).Is(uint32(sampledIdx), model.SpecialEOS) {
  139. outputIDs = append(outputIDs, int32(sampledIdx))
  140. }
  141. if len(outputIDs) == 0 {
  142. break
  143. }
  144. s, err := m.(model.TextProcessor).Decode(outputIDs)
  145. if errors.Is(err, io.EOF) {
  146. break
  147. } else if err != nil {
  148. return err
  149. }
  150. // if ttft == 0 {
  151. // ttft = time.Since(start)
  152. // fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
  153. // }
  154. // fmt.Printf("--- token: %q\n", s)
  155. // fmt.Printf("--- outputIDs: %v\n", outputIDs)
  156. stringBuffer += s
  157. count++
  158. fmt.Println("--- stringBuffer", stringBuffer)
  159. outputIDs, err = jsonTransform.UpdateState(outputIDs)
  160. if err != nil {
  161. return err
  162. }
  163. // can do fun shifting stuff here if needed
  164. inputIDs = append(inputIDs, outputIDs...)
  165. if args.cache {
  166. offset = len(inputIDs) - 1
  167. }
  168. }
  169. fmt.Println("\n------ Output: ------")
  170. fmt.Println(stringBuffer)
  171. fmt.Println("--------------------")
  172. fmt.Println("sample average time", totalSamplingTime/time.Duration(count))
  173. return nil
  174. }
  175. func main() {
  176. if err := temp(); err != nil {
  177. fmt.Println("err", err)
  178. os.Exit(1)
  179. }
  180. }