123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183 |
- package llamarunner
- import (
- "errors"
- "fmt"
- "hash/maphash"
- "log/slog"
- "slices"
- "sync"
- "time"
- "github.com/ollama/ollama/llama"
- )
// imageCacheSize is the number of slots in the image embedding cache that
// NewImageContext allocates for each ImageContext.
const imageCacheSize = 4
- type ImageContext struct {
- // mu is required to be held when generating embeddings or accessing the cache
- mu sync.Mutex
- clip *llama.ClipContext
- mllama *llama.MllamaContext
- // cache of images to embeddings
- images []imageCache
- imageHash maphash.Hash
- }
- func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
- arch, err := llama.GetModelArch(modelPath)
- if err != nil {
- return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
- }
- var c ImageContext
- if arch == "clip" {
- c.clip, err = llama.NewClipContext(llamaContext, modelPath)
- } else if arch == "mllama" {
- c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
- } else {
- return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
- }
- if err != nil {
- return nil, err
- }
- c.images = make([]imageCache, imageCacheSize)
- return &c, nil
- }
- func (c *ImageContext) Free(modelPath string) {
- if c == nil {
- return
- }
- if c.clip != nil {
- c.clip.Free()
- }
- if c.mllama != nil {
- c.mllama.Free()
- }
- }
- func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) ([][]float32, error) {
- if c == nil {
- return nil, nil
- }
- if len(data) <= 0 {
- return nil, errors.New("received zero length image")
- }
- hash := c.hashImage(data)
- c.mu.Lock()
- defer c.mu.Unlock()
- embed, err := c.findImage(hash)
- if err != nil {
- if c.mllama != nil {
- embed, err = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
- if err != nil {
- return nil, err
- }
- } else if c.clip != nil {
- embed, err = c.clip.NewEmbed(llamaContext, data)
- if err != nil {
- return nil, err
- }
- } else {
- return nil, errors.New("received image but vision model not loaded")
- }
- c.addImage(hash, embed)
- }
- return embed, nil
- }
- func (c *ImageContext) BatchSize(configuredBatchSize int) int {
- // If images are not supported, we don't need to allocate embedding batches
- if c == nil {
- return 0
- }
- // Mllama maps an image to 1 embedding token (llava creates many tokens)
- // and doesn't support more than a single image per request.
- // The embeddings are large (100 MB), so allocating a big batch can fail
- // on some systems
- if c.mllama != nil {
- return 1
- }
- return configuredBatchSize
- }
- func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
- if c != nil && c.mllama != nil {
- return c.mllama.EmbedSize(llamaContext)
- } else {
- return llamaContext.Model().NEmbd()
- }
- }
- func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
- if c == nil || c.mllama == nil {
- return false
- }
- return slices.ContainsFunc(inputs, func(input input) bool {
- return input.embed != nil
- })
- }
// imageCache is a single slot of the image embedding cache.
type imageCache struct {
	// key is the hash of the image bytes stored in this slot.
	key uint64

	// val holds the embeddings computed for that image.
	val [][]float32

	// lastUsed is the time the slot was last stored or read; the slot with
	// the oldest value is the one replaced when a new image is added.
	lastUsed time.Time
}
- func (c *ImageContext) hashImage(image []byte) uint64 {
- c.imageHash.Reset()
- _, _ = c.imageHash.Write(image)
- return c.imageHash.Sum64()
- }
// errImageNotFound is returned by findImage when no cache slot holds the
// requested image hash.
var errImageNotFound = errors.New("image not found in cache")
- func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
- for i := range c.images {
- if c.images[i].key == hash {
- slog.Debug("loading image embeddings from cache", "entry", i)
- c.images[i].lastUsed = time.Now()
- return c.images[i].val, nil
- }
- }
- return nil, errImageNotFound
- }
- func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
- best := time.Now()
- var bestImage int
- for i := range c.images {
- if c.images[i].key == hash {
- bestImage = i
- break
- }
- if c.images[i].lastUsed.Compare(best) < 0 {
- best = c.images[i].lastUsed
- bestImage = i
- }
- }
- slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
- c.images[bestImage].key = hash
- c.images[bestImage].val = embed
- c.images[bestImage].lastUsed = time.Now()
- }
|