package main

import (
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"runtime"
	"strings"

	"github.com/ollama/ollama/llama"
)

func main() {
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("projector", "", "Path to projector binary file")
	image := flag.String("image", "", "Path to image file")
	prompt := flag.String("prompt", "", "Prompt including <image> tag")
	flag.Parse()

	if *mpath == "" {
		panic("model path is required")
	}

	if *prompt == "" {
		panic("prompt is required")
	}
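
	// example invocation (the file names are illustrative):
	//
	//	go run . -model model.gguf -projector mmproj.gguf \
	//		-image photo.jpg -prompt "describe this image: <image>"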

	// load the model
	llama.BackendInit()

	// 999 requests that (effectively) all layers be offloaded to the GPU;
	// the callback reports loading progress
	params := llama.NewModelParams(999, 0, func(p float32) {
		fmt.Printf("loading... %f\n", p)
	})
	model := llama.LoadModelFromFile(*mpath, params)

	// 2048-token context window, one thread per CPU core
	ctxParams := llama.NewContextParams(2048, runtime.NumCPU(), false)

	// language model context
	lc := llama.NewContextWithModel(model, ctxParams)
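
	// note: the generation loop below advances positions up to 4096, beyond
	// this 2048-token window; Decode will start failing if output runs that long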

	// batch of up to 512 tokens for prompt evaluation (single sequence)
	batch := llama.NewBatch(512, 0, 1)

	// nPast counts the positions evaluated into the context so far
	var nPast int

	// clip context, used only when a projector is supplied
	var clipCtx *llama.ClipContext
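
	// llava-style evaluation: the prompt is split at the <image> tag, the text
	// before it is decoded first, then the image embedding is evaluated in its
	// place, then the text after it, so the model attends to the image patches
	// as if they were ordinary prompt tokens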

	// multi-modal path: requires a projector model
	if *ppath != "" {
		clipCtx = llama.NewClipContext(*ppath)

		// open the image file and read it into memory
		file, err := os.Open(*image)
		if err != nil {
			panic(err)
		}
		defer file.Close()

		data, err := io.ReadAll(file)
		if err != nil {
			log.Fatal(err)
		}

		// encode the image into an embedding with the CLIP projector
		embedding := llama.NewLlavaImageEmbed(clipCtx, data)

		parts := strings.Split(*prompt, "<image>")
		if len(parts) != 2 {
			panic("prompt must contain exactly one <image>")
		}

		// decode the prompt text before the image
		beforeTokens, err := lc.Model().Tokenize(parts[0], true, true)
		if err != nil {
			panic(err)
		}

		for _, t := range beforeTokens {
			batch.Add(t, nPast, []int{0}, true)
			nPast++
		}

		err = lc.Decode(batch)
		if err != nil {
			panic(err)
		}

		// clear the already-decoded tokens from the batch so they are not
		// evaluated a second time alongside the text after the image
		batch.Clear()

		// evaluate the image embedding, in chunks of up to 512 positions
		llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)

		// queue the prompt text that follows the image
		afterTokens, err := lc.Model().Tokenize(parts[1], true, true)
		if err != nil {
			panic(err)
		}

		for _, t := range afterTokens {
			batch.Add(t, nPast, []int{0}, true)
			nPast++
		}
	} else {
		// text-only path: tokenize and queue the whole prompt
		tokens, err := lc.Model().Tokenize(*prompt, true, true)
		if err != nil {
			panic(err)
		}

		for _, t := range tokens {
			batch.Add(t, nPast, []int{0}, true)
			nPast++
		}
	}
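
	// generation: each Decode evaluates whatever is queued in the batch; after
	// the first pass over the prompt, the batch holds only the single token
	// sampled in the previous iteration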

	// main loop
	for n := nPast; n < 4096; n++ {
		err := lc.Decode(batch)
		if err != nil {
			panic(err)
		}

		// sample the next token greedily from the logits of the last batch entry
		logits := lc.GetLogitsIth(batch.NumTokens() - 1)
		token := lc.SampleTokenGreedy(logits)

		// if it's an end of sequence token, break
		if lc.Model().TokenIsEog(token) {
			break
		}

		// print the token
		str := lc.Model().TokenToPiece(token)
		fmt.Print(str)

		// feed the sampled token back in for the next decode
		batch.Clear()
		batch.Add(token, n, []int{0}, true)
	}
}