浏览代码

add `llava` to `runner`

jmorganca 11 月之前
父节点
当前提交
fbc8572859
共有 3 个文件被更改,包括 135 次插入42 次删除
  1. 20 16
      llama/llama.go
  2. 7 7
      llama/llava/main.go
  3. 108 19
      llama/runner/runner.go

+ 20 - 16
llama/llama.go

@@ -38,10 +38,6 @@ import (
 	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/llm"
 )
 )
 
 
-type Token int32
-type Pos int32
-type SeqId int32
-
 // SystemInfo is an unused example of calling llama.cpp functions using CGo
 // SystemInfo is an unused example of calling llama.cpp functions using CGo
 func PrintSystemInfo() string {
 func PrintSystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 	return C.GoString(C.llama_print_system_info())
@@ -78,6 +74,10 @@ type Context struct {
 	c *C.struct_llama_context
 	c *C.struct_llama_context
 }
 }
 
 
+func (c *Context) KvCacheClear() {
+	C.llama_kv_cache_clear(c.c)
+}
+
 func (c *Context) Decode(batch Batch) error {
 func (c *Context) Decode(batch Batch) error {
 	// Positive return values does not mean a fatal error, but rather a warning.
 	// Positive return values does not mean a fatal error, but rather a warning.
 	//   0 - success
 	//   0 - success
@@ -90,18 +90,18 @@ func (c *Context) Decode(batch Batch) error {
 	}
 	}
 
 
 	if code > 0 {
 	if code > 0 {
-		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d\n", code)
+		return fmt.Errorf("could not find a KV slot for the batch - try reducing the size of the batch or increase the context. code: %d", code)
 	}
 	}
 
 
 	return nil
 	return nil
 }
 }
 
 
-func (c *Context) GetModel() *Model {
+func (c *Context) Model() *Model {
 	return &Model{c: C.llama_get_model(c.c)}
 	return &Model{c: C.llama_get_model(c.c)}
 }
 }
 
 
-func (c *Context) SampleTokenGreedy(batch Batch) Token {
-	nv := c.GetModel().NumVocab()
+func (c *Context) SampleTokenGreedy(batch Batch) int {
+	nv := c.Model().NumVocab()
 
 
 	// TODO(jmorganca): split this up into different functions
 	// TODO(jmorganca): split this up into different functions
 	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
 	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
@@ -116,7 +116,7 @@ func (c *Context) SampleTokenGreedy(batch Batch) Token {
 		ptr.p = 0.0
 		ptr.p = 0.0
 	}
 	}
 
 
-	return Token(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
+	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
 		data:   candidates,
 		data:   candidates,
 		size:   C.size_t(nv),
 		size:   C.size_t(nv),
 		sorted: C.bool(false),
 		sorted: C.bool(false),
@@ -135,7 +135,7 @@ func (m *Model) NumVocab() int {
 	return int(C.llama_n_vocab(m.c))
 	return int(C.llama_n_vocab(m.c))
 }
 }
 
 
-func (m *Model) TokenIsEog(token Token) bool {
+func (m *Model) TokenIsEog(token int) bool {
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 	return bool(C.llama_token_is_eog(m.c, C.llama_token(token)))
 }
 }
 
 
@@ -151,7 +151,7 @@ func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 	return int(b.c.n_tokens)
 }
 }
 
 
-func (b *Batch) Add(token Token, pos Pos, seqIds []SeqId, logits bool) {
+func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
 	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
 	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
 	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
 	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
 	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
 	unsafe.Slice(b.c.n_seq_id, 512)[b.c.n_tokens] = C.int(len(seqIds))
@@ -171,13 +171,17 @@ func (b *Batch) Clear() {
 	b.c.n_tokens = 0
 	b.c.n_tokens = 0
 }
 }
 
 
+func (b *Batch) Free() {
+	C.llama_batch_free(b.c)
+}
+
 // LLAMA_API struct llama_batch llama_batch_get_one(
 // LLAMA_API struct llama_batch llama_batch_get_one(
 //
 //
 //		llama_token * tokens,
 //		llama_token * tokens,
 //			int32_t   n_tokens,
 //			int32_t   n_tokens,
 //		  llama_pos   pos_0,
 //		  llama_pos   pos_0,
 //	   llama_seq_id   seq_id);
 //	   llama_seq_id   seq_id);
-func BatchGetOne(tokens []Token, pos0 Pos, seqId SeqId) Batch {
+func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
 	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
 	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
 }
 }
 
 
@@ -185,7 +189,7 @@ type Model struct {
 	c *C.struct_llama_model
 	c *C.struct_llama_model
 }
 }
 
 
-func (m *Model) TokenToPiece(token Token) string {
+func (m *Model) TokenToPiece(token int) string {
 	buf := make([]byte, 12)
 	buf := make([]byte, 12)
 	C.llama_token_to_piece(
 	C.llama_token_to_piece(
 		m.c,
 		m.c,
@@ -197,7 +201,7 @@ func (m *Model) TokenToPiece(token Token) string {
 	return strings.TrimRight(string(buf), "\x00")
 	return strings.TrimRight(string(buf), "\x00")
 }
 }
 
 
-func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]Token, error) {
+func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpecial bool) ([]int, error) {
 	cTokens := make([]C.llama_token, maxTokens)
 	cTokens := make([]C.llama_token, maxTokens)
 	cText := C.CString(text)
 	cText := C.CString(text)
 	defer C.free(unsafe.Pointer(cText))
 	defer C.free(unsafe.Pointer(cText))
@@ -216,9 +220,9 @@ func (m *Model) Tokenize(text string, maxTokens int, addSpecial bool, parseSpeci
 		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
 		return nil, fmt.Errorf("tokenization failed, required %d tokens", -result)
 	}
 	}
 
 
-	tokens := make([]Token, result)
+	tokens := make([]int, result)
 	for i := 0; i < int(result); i++ {
 	for i := 0; i < int(result); i++ {
-		tokens[i] = Token(cTokens[i])
+		tokens[i] = int(cTokens[i])
 	}
 	}
 
 
 	return tokens, nil
 	return tokens, nil

+ 7 - 7
llama/llava/main.go

@@ -56,12 +56,12 @@ func main() {
 }
 }
 
 
 func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
 func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
-	beforeTokens, err := lc.GetModel().Tokenize(before, 2048, true, true)
+	beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
 
 
-	afterTokens, err := lc.GetModel().Tokenize(after, 2048, true, true)
+	afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
 	if err != nil {
 	if err != nil {
 		return err
 		return err
 	}
 	}
@@ -73,7 +73,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
 
 
 	// prompt eval
 	// prompt eval
 	for _, t := range beforeTokens {
 	for _, t := range beforeTokens {
-		batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
+		batch.Add(t, nPast, []int{0}, true)
 		nPast++
 		nPast++
 	}
 	}
 
 
@@ -88,7 +88,7 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
 
 
 	batch = llama.NewBatch(512, 0, 1)
 	batch = llama.NewBatch(512, 0, 1)
 	for _, t := range afterTokens {
 	for _, t := range afterTokens {
-		batch.Add(t, llama.Pos(nPast), []llama.SeqId{0}, true)
+		batch.Add(t, nPast, []int{0}, true)
 	}
 	}
 
 
 	// main loop
 	// main loop
@@ -102,15 +102,15 @@ func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, af
 		token := lc.SampleTokenGreedy(batch)
 		token := lc.SampleTokenGreedy(batch)
 
 
 		// if it's an end of sequence token, break
 		// if it's an end of sequence token, break
-		if lc.GetModel().TokenIsEog(token) {
+		if lc.Model().TokenIsEog(token) {
 			break
 			break
 		}
 		}
 
 
 		// print the token
 		// print the token
-		str := lc.GetModel().TokenToPiece(token)
+		str := lc.Model().TokenToPiece(token)
 		fmt.Print(str)
 		fmt.Print(str)
 		batch.Clear()
 		batch.Clear()
-		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
+		batch.Add(token, n, []int{0}, true)
 	}
 	}
 
 
 	return nil
 	return nil

+ 108 - 19
llama/runner/runner.go

@@ -1,19 +1,24 @@
 package main
 package main
 
 
 import (
 import (
+	"encoding/base64"
 	"encoding/json"
 	"encoding/json"
 	"flag"
 	"flag"
 	"fmt"
 	"fmt"
 	"log"
 	"log"
+	"log/slog"
 	"net"
 	"net"
 	"net/http"
 	"net/http"
+	"regexp"
+	"strconv"
 	"sync"
 	"sync"
 
 
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/llama"
 )
 )
 
 
 type Request struct {
 type Request struct {
-	Prompt string `json:"prompt"`
+	Prompt string   `json:"prompt"`
+	Images []string `json:"images"`
 }
 }
 
 
 type Response struct {
 type Response struct {
@@ -23,6 +28,7 @@ type Response struct {
 type Server struct {
 type Server struct {
 	model *llama.Model
 	model *llama.Model
 	lc    *llama.Context
 	lc    *llama.Context
+	cc    *llama.ClipContext
 }
 }
 
 
 var mu sync.Mutex
 var mu sync.Mutex
@@ -34,6 +40,9 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 		return
 		return
 	}
 	}
 
 
+	mu.Lock()
+	defer mu.Unlock()
+
 	// Set the headers to indicate streaming
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -41,30 +50,69 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 
 
 	enc := json.NewEncoder(w)
 	enc := json.NewEncoder(w)
 
 
-	// main loop
-	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
-	if err != nil {
-		panic(err)
+	// create embeddings for each image
+	var embeddings []*llama.LlavaImageEmbed
+	if s.cc != nil {
+		for _, img := range request.Images {
+			data, err := base64.StdEncoding.DecodeString(img)
+			if err != nil {
+				http.Error(w, "Failed to decode image", http.StatusBadRequest)
+				return
+			}
+
+			embd := llama.NewLlavaImageEmbed(s.cc, data)
+			embeddings = append(embeddings, embd)
+		}
 	}
 	}
 
 
-	batch := llama.NewBatch(512, 0, 1)
+	var nPast int
 
 
-	// prompt eval
-	for i, t := range tokens {
-		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
-	}
+	// eval the prompt
+	re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`)
+	matches := re.FindAllStringSubmatchIndex(request.Prompt, -1)
 
 
-	// main loop
-	for n := batch.NumTokens(); n < 2048; n++ {
-		mu.Lock()
-		err = s.lc.Decode(batch)
+	// eval each chunk including images
+	pos := 0
+	for _, match := range matches {
+		part := request.Prompt[pos:match[0]]
+		fmt.Println("Text part:", part)
+
+		// eval text before image
+		err := s.evalText(part, &nPast)
 		if err != nil {
 		if err != nil {
-			panic("Failed to decode")
+			log.Println("Failed to eval text:", err)
+			return
 		}
 		}
 
 
+		// eval image
+		imgIndexStr := request.Prompt[match[2]:match[3]]
+		imgIndex, err := strconv.Atoi(imgIndexStr)
+		if err != nil {
+			slog.Warn("Failed to parse image index", "index", imgIndexStr)
+			continue
+		}
+
+		fmt.Println("Tag index:", imgIndex)
+		if imgIndex <= len(embeddings) {
+			slog.Info("evaluating image", "index", imgIndex)
+			llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast)
+		}
+
+		pos = match[1]
+	}
+
+	// eval remaining text
+	if pos < len(request.Prompt) {
+		s.evalText(request.Prompt[pos:], &nPast)
+	}
+
+	batch := llama.NewBatch(512, 0, 1)
+	defer batch.Free()
+
+	// main loop
+	for n := nPast; n < 2048; n++ {
 		// sample a token
 		// sample a token
 		token := s.lc.SampleTokenGreedy(batch)
 		token := s.lc.SampleTokenGreedy(batch)
-		mu.Unlock()
 
 
 		// if it's an end of sequence token, break
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
 		if s.model.TokenIsEog(token) {
@@ -81,27 +129,44 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 		w.(http.Flusher).Flush()
 		w.(http.Flusher).Flush()
 
 
 		batch.Clear()
 		batch.Clear()
-		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
+		batch.Add(token, n, []int{0}, true)
+
+		err := s.lc.Decode(batch)
+		if err != nil {
+			panic("Failed to decode")
+		}
 	}
 	}
+
+	s.lc.KvCacheClear()
 }
 }
 
 
 func main() {
 func main() {
-	mp := flag.String("model", "", "Path to model binary file")
+	mpath := flag.String("model", "", "Path to model binary file")
+	ppath := flag.String("projector", "", "Path to projector binary file")
 	flag.Parse()
 	flag.Parse()
 
 
 	// load the model
 	// load the model
 	llama.BackendInit()
 	llama.BackendInit()
 	params := llama.NewModelParams()
 	params := llama.NewModelParams()
-	model := llama.LoadModelFromFile(*mp, params)
+	model := llama.LoadModelFromFile(*mpath, params)
 	ctxParams := llama.NewContextParams()
 	ctxParams := llama.NewContextParams()
 	lc := llama.NewContextWithModel(model, ctxParams)
 	lc := llama.NewContextWithModel(model, ctxParams)
 	if lc == nil {
 	if lc == nil {
 		panic("Failed to create context")
 		panic("Failed to create context")
 	}
 	}
 
 
+	var cc *llama.ClipContext
+	if ppath != nil {
+		cc = llama.NewClipContext(*ppath)
+		if cc == nil {
+			panic("Failed to create clip context")
+		}
+	}
+
 	server := &Server{
 	server := &Server{
 		model: model,
 		model: model,
 		lc:    lc,
 		lc:    lc,
+		cc:    cc,
 	}
 	}
 
 
 	addr := "127.0.0.1:8080"
 	addr := "127.0.0.1:8080"
@@ -121,3 +186,27 @@ func main() {
 		log.Fatal("server error:", err)
 		log.Fatal("server error:", err)
 	}
 	}
 }
 }
+
+func (s *Server) evalText(text string, nPast *int) error {
+	// eval before
+	batch := llama.NewBatch(512, 0, 1)
+	defer batch.Free()
+
+	tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
+	if err != nil {
+		return fmt.Errorf("tokenize failed: %w", err)
+	}
+
+	// prompt eval
+	for _, t := range tokens {
+		batch.Add(t, *nPast, []int{0}, true)
+		*nPast++
+	}
+
+	err = s.lc.Decode(batch)
+	if err != nil {
+		return fmt.Errorf("decode failed: %w", err)
+	}
+
+	return nil
+}