main.go

package main

import (
    "flag"
    "fmt"
    "io"
    "os"
    "strings"

    "github.com/ollama/ollama/llama"
)
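// This program runs a minimal LLaVA-style completion: it loads a language
// model (and, optionally, a CLIP projector for image input), evaluates the
// prompt, and greedily samples tokens until an end-of-sequence token.
//
// Example invocation (the file names here are illustrative, not part of the
// repository):
//
//	go run . -model model.gguf -projector mmproj.gguf \
//		-image photo.jpg -prompt "describe this picture: <image>"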
func main() {
    mpath := flag.String("model", "", "Path to model binary file")
    ppath := flag.String("projector", "", "Path to projector binary file")
    image := flag.String("image", "", "Path to image file")
    prompt := flag.String("prompt", "", "Prompt including <image> tag")
    flag.Parse()

    if *mpath == "" {
        panic("model path is required")
    }

    if *prompt == "" {
        panic("prompt is required")
    }
    // load the model
    llama.BackendInit()
    params := llama.NewModelParams()
    model := llama.LoadModelFromFile(*mpath, params)
    ctxParams := llama.NewContextParams()

    // language model context
    lc := llama.NewContextWithModel(model, ctxParams)

    // batch used to submit prompt tokens to the model:
    // capacity for 512 tokens, token (not embedding) input, one sequence
    batch := llama.NewBatch(512, 0, 1)
    var nPast int

    // clip context, created only when a projector is supplied
    var clipCtx *llama.ClipContext
    // multi-modal: with a projector, embed the image and splice it into the prompt
    if *ppath != "" {
        clipCtx = llama.NewClipContext(*ppath)

        // read the image file
        file, err := os.Open(*image)
        if err != nil {
            panic(err)
        }
        defer file.Close()

        data, err := io.ReadAll(file)
        if err != nil {
            panic(err)
        }
        // compute the image embedding
        embedding := llama.NewLlavaImageEmbed(clipCtx, data)

        parts := strings.Split(*prompt, "<image>")
        if len(parts) != 2 {
            panic("prompt must contain exactly one <image>")
        }
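        // parts[0] (the text before <image>) is evaluated first, then the
        // image embedding, then parts[1] (the text after it), so the model
        // sees the image at the position the tag marks in the prompt.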
        beforeTokens, err := lc.Model().Tokenize(parts[0], 2048, true, true)
        if err != nil {
            panic(err)
        }

        for _, t := range beforeTokens {
            batch.Add(t, nPast, []int{0}, true)
            nPast++
        }

        err = lc.Decode(batch)
        if err != nil {
            panic(err)
        }

        llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)

        // the tokens before the image have already been decoded above,
        // so start a fresh batch for the tokens that follow it
        batch.Clear()

        afterTokens, err := lc.Model().Tokenize(parts[1], 2048, true, true)
        if err != nil {
            panic(err)
        }

        for _, t := range afterTokens {
            batch.Add(t, nPast, []int{0}, true)
            nPast++
        }
    } else {
        // text-only prompt
        tokens, err := lc.Model().Tokenize(*prompt, 2048, true, true)
        if err != nil {
            panic(err)
        }

        for _, t := range tokens {
            batch.Add(t, nPast, []int{0}, true)
            nPast++
        }
    }
    // main loop: decode, sample greedily, and feed the sampled token back in
    for n := nPast; n < 4096; n++ {
        err := lc.Decode(batch)
        if err != nil {
            panic(err)
        }

        // sample a token from the logits of the last batch entry
        logits := lc.GetLogitsIth(batch.NumTokens() - 1)
        token := lc.SampleTokenGreedy(logits)

        // if it's an end of sequence token, break
        if lc.Model().TokenIsEog(token) {
            break
        }

        // print the token
        str := lc.Model().TokenToPiece(token)
        fmt.Print(str)

        // the next batch contains only the sampled token, at position n
        batch.Clear()
        batch.Add(token, n, []int{0}, true)
    }
}