main.go

package main

import (
	"flag"
	"fmt"
	"io"
	"os"
	"runtime"
	"strings"

	"github.com/ollama/ollama/llama"
)
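
// This program loads a model (and, optionally, a LLaVA-style CLIP projector),
// evaluates the prompt, then greedily samples tokens until the model emits an
// end-of-generation token.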
func main() {
	mpath := flag.String("model", "", "Path to model binary file")
	ppath := flag.String("projector", "", "Path to projector binary file")
	image := flag.String("image", "", "Path to image file")
	prompt := flag.String("prompt", "", "Prompt including <image> tag")
	flag.Parse()

	if *mpath == "" {
		panic("model path is required")
	}

	if *prompt == "" {
		panic("prompt is required")
	}
	// load the model
	llama.BackendInit()
	params := llama.ModelParams{
		NumGpuLayers: 999, // request all layers on the GPU; clamped to the model's layer count
		MainGpu:      0,
		UseMmap:      true,
		Progress: func(p float32) {
			fmt.Printf("loading... %f\n", p)
		},
	}
	model := llama.LoadModelFromFile(*mpath, params)
	ctxParams := llama.NewContextParams(2048, runtime.NumCPU(), false)
	// language model context
	lc := llama.NewContextWithModel(model, ctxParams)

	// batch of tokens queued for the next decode (capacity 512, one sequence)
	batch := llama.NewBatch(512, 0, 1)

	// nPast counts the positions already committed to the KV cache
	var nPast int
	// multi-modal: if a projector is supplied, encode the image with CLIP and
	// splice its embedding between the two halves of the prompt
	if *ppath != "" {
		if *image == "" {
			panic("image path is required when a projector is provided")
		}
		clipCtx := llama.NewClipContext(*ppath)
		// read the image file into memory
		file, err := os.Open(*image)
		if err != nil {
			panic(err)
		}
		defer file.Close()

		data, err := io.ReadAll(file)
		if err != nil {
			panic(err)
		}
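
		// encode the raw image bytes into an embedding via the CLIP projector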
		embedding := llama.NewLlavaImageEmbed(clipCtx, data)

		parts := strings.Split(*prompt, "<image>")
		if len(parts) != 2 {
			panic("prompt must contain exactly one <image>")
		}
		beforeTokens, err := lc.Model().Tokenize(parts[0], true, true)
		if err != nil {
			panic(err)
		}
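
		// queue the tokens preceding the image tag in sequence 0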
		for _, t := range beforeTokens {
			batch.Add(t, nPast, []int{0}, true)
			nPast++
		}
		err = lc.Decode(batch)
		if err != nil {
			panic(err)
		}

		// the pre-image tokens are now in the KV cache; clear the batch so they
		// are not decoded a second time alongside the post-image tokens
		batch.Clear()
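
		// evaluate the image embedding in batches of up to 512 positions,
		// advancing nPast past the image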
		llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)

		afterTokens, err := lc.Model().Tokenize(parts[1], true, true)
		if err != nil {
			panic(err)
		}

		// queue the tokens following the image tag
		for _, t := range afterTokens {
			batch.Add(t, nPast, []int{0}, true)
			nPast++
		}
	} else {
		// text-only: tokenize and queue the entire prompt
		tokens, err := lc.Model().Tokenize(*prompt, true, true)
		if err != nil {
			panic(err)
		}

		for _, t := range tokens {
			batch.Add(t, nPast, []int{0}, true)
			nPast++
		}
	}
	// main loop: decode, sample greedily, and stop at end-of-generation or at
	// the 4096-position limit
	for n := nPast; n < 4096; n++ {
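		// decode whatever is queued: the prompt tokens on the first pass, a
		// single sampled token on every pass after that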
		err := lc.Decode(batch)
		if err != nil {
			panic(err)
		}

		// sample a token
		logits := lc.GetLogitsIth(batch.NumTokens() - 1)
		token := lc.SampleTokenGreedy(logits)

		// if it's an end of sequence token, break
		if lc.Model().TokenIsEog(token) {
			break
		}
		// print the token
		str := lc.Model().TokenToPiece(token)
		fmt.Print(str)
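
		// queue only the newly sampled token for the next decode, at position n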
		batch.Clear()
		batch.Add(token, n, []int{0}, true)
	}
}