// main.go — llava example: embed an image with CLIP and generate a caption.
  1. package main
  2. import (
  3. "flag"
  4. "fmt"
  5. "io"
  6. "log"
  7. "os"
  8. "strings"
  9. "github.com/ollama/ollama/llama"
  10. )
  11. func main() {
  12. mp := flag.String("model", "", "Path to model binary file")
  13. pp := flag.String("projector", "", "Path to projector binary file")
  14. image := flag.String("image", "", "Path to image file")
  15. prompt := flag.String("prompt", " [INST] What is in the picture? <image> [/INST]", "Prompt including <image> tag")
  16. flag.Parse()
  17. // load the model
  18. llama.BackendInit()
  19. params := llama.NewModelParams()
  20. model := llama.LoadModelFromFile(*mp, params)
  21. ctxParams := llama.NewContextParams()
  22. // language model context
  23. lc := llama.NewContextWithModel(model, ctxParams)
  24. // clip context
  25. clipCtx := llama.NewClipContext(*pp)
  26. // open image file
  27. file, err := os.Open(*image)
  28. if err != nil {
  29. panic(err)
  30. }
  31. defer file.Close()
  32. data, err := io.ReadAll(file)
  33. if err != nil {
  34. log.Fatal(err)
  35. }
  36. embedding := llama.NewLlavaImageEmbed(clipCtx, data)
  37. parts := strings.Split(*prompt, "<image>")
  38. if len(parts) != 2 {
  39. panic("prompt must contain exactly one <image>")
  40. }
  41. err = eval(lc, parts[0], embedding, parts[1])
  42. if err != nil {
  43. panic(err)
  44. }
  45. }
  46. func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
  47. beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
  48. if err != nil {
  49. return err
  50. }
  51. afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
  52. if err != nil {
  53. return err
  54. }
  55. // eval before
  56. batch := llama.NewBatch(512, 0, 1)
  57. var nPast int
  58. // prompt eval
  59. for _, t := range beforeTokens {
  60. batch.Add(t, nPast, []int{0}, true)
  61. nPast++
  62. }
  63. err = lc.Decode(batch)
  64. if err != nil {
  65. return err
  66. }
  67. // batch.Clear()
  68. llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)
  69. batch = llama.NewBatch(512, 0, 1)
  70. for _, t := range afterTokens {
  71. batch.Add(t, nPast, []int{0}, true)
  72. }
  73. // main loop
  74. for n := nPast; n < 4096; n++ {
  75. err = lc.Decode(batch)
  76. if err != nil {
  77. panic("Failed to decode")
  78. }
  79. // sample a token
  80. token := lc.SampleTokenGreedy(batch)
  81. // if it's an end of sequence token, break
  82. if lc.Model().TokenIsEog(token) {
  83. break
  84. }
  85. // print the token
  86. str := lc.Model().TokenToPiece(token)
  87. fmt.Print(str)
  88. batch.Clear()
  89. batch.Add(token, n, []int{0}, true)
  90. }
  91. return nil
  92. }