ソースを参照

better `example` module, add port

jmorganca 11 ヶ月 前
コミット
d12db0568e

+ 31 - 0
llama/example/README.md

@@ -0,0 +1,31 @@
+# `example`
+
+Demo app for the `llama` package
+
+Pull a model:
+
+```
+ollama pull mistral:7b-instruct-v0.3-q4_0
+```
+
+Then run it:
+
+```
+go run -x . \
+    -model ~/.ollama/models/blobs/sha256-ff82381e2bea77d91c1b824c7afb83f6fb73e9f7de9dda631bcdbca564aa5435 \
+    -prompt "[ISNT] Why is the sky blue? [/INST]"
+```
+
+## Vision
+
+```
+ollama pull llava:7b-v1.6-mistral-q4_0
+```
+
+```
+go run -x . \
+    -model ~/.ollama/models/blobs/sha256-170370233dd5c5415250a2ecd5c71586352850729062ccef1496385647293868 \
+    -projector ~/.ollama/models/blobs/sha256-72d6f08a42f656d36b356dbe0920675899a99ce21192fd66266fb7d82ed07539 \
+    -image ./alonso.jpg \
+    -prompt "[ISNT] What is in this image? <image> [/INST]"
+```

+ 0 - 0
llama/llava/alonso.jpg → llama/example/alonso.jpg


+ 128 - 0
llama/example/main.go

@@ -0,0 +1,128 @@
+package main
+
+import (
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"strings"
+
+	"github.com/ollama/ollama/llama"
+)
+
+func main() {
+	mpath := flag.String("model", "", "Path to model binary file")
+	ppath := flag.String("projector", "", "Path to projector binary file")
+	image := flag.String("image", "", "Path to image file")
+	prompt := flag.String("prompt", "", "Prompt including <image> tag")
+	flag.Parse()
+
+	if *mpath == "" {
+		panic("model path is required")
+	}
+
+	if *prompt == "" {
+		panic("prompt is required")
+	}
+
+	// load the model
+	llama.BackendInit()
+	params := llama.NewModelParams()
+	model := llama.LoadModelFromFile(*mpath, params)
+	ctxParams := llama.NewContextParams()
+
+	// language model context
+	lc := llama.NewContextWithModel(model, ctxParams)
+
+	// eval before
+	batch := llama.NewBatch(512, 0, 1)
+	var nPast int
+
+	// clip context
+	var clipCtx *llama.ClipContext
+
+	// multi-modal
+	if *ppath == "" {
+		clipCtx = llama.NewClipContext(*ppath)
+
+		// open image file
+		file, err := os.Open(*image)
+		if err != nil {
+			panic(err)
+		}
+		defer file.Close()
+
+		data, err := io.ReadAll(file)
+		if err != nil {
+			log.Fatal(err)
+		}
+
+		embedding := llama.NewLlavaImageEmbed(clipCtx, data)
+
+		parts := strings.Split(*prompt, "<image>")
+		if len(parts) != 2 {
+			panic("prompt must contain exactly one <image>")
+		}
+
+		beforeTokens, err := lc.Model().Tokenize(parts[0], 2048, true, true)
+		if err != nil {
+			panic(err)
+		}
+
+		for _, t := range beforeTokens {
+			batch.Add(t, nPast, []int{0}, true)
+			nPast++
+		}
+
+		err = lc.Decode(batch)
+		if err != nil {
+			panic(err)
+		}
+
+		llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)
+
+		afterTokens, err := lc.Model().Tokenize(parts[1], 2048, true, true)
+		if err != nil {
+			panic(err)
+		}
+
+		for _, t := range afterTokens {
+			batch.Add(t, nPast, []int{0}, true)
+			nPast++
+		}
+	} else {
+		tokens, err := lc.Model().Tokenize(*prompt, 2048, true, true)
+		if err != nil {
+			panic(err)
+		}
+
+		for _, t := range tokens {
+			batch.Add(t, nPast, []int{0}, true)
+			nPast++
+		}
+	}
+
+	// main loop
+	for n := nPast; n < 4096; n++ {
+		err := lc.Decode(batch)
+		if err != nil {
+			panic(err)
+		}
+
+		// sample a token
+		logits := lc.GetLogitsIth(batch.NumTokens() - 1)
+		token := lc.SampleTokenGreedy(logits)
+
+		// if it's an end of sequence token, break
+		if lc.Model().TokenIsEog(token) {
+			break
+		}
+
+		// print the token
+		str := lc.Model().TokenToPiece(token)
+		fmt.Print(str)
+		batch.Clear()
+		batch.Add(token, n, []int{0}, true)
+	}
+}

+ 10 - 16
llama/llama.go

@@ -99,26 +99,24 @@ func (c *Context) Model() *Model {
 	return &Model{c: C.llama_get_model(c.c)}
 }
 
-// TODO: break this up
-func (c *Context) SampleTokenGreedy(batch Batch, i int) int {
-	nv := c.Model().NumVocab()
+func (c *Context) GetLogitsIth(i int) []float32 {
+	return unsafe.Slice((*float32)(unsafe.Pointer(C.llama_get_logits_ith(c.c, C.int(i)))), c.Model().NumVocab())
+}
 
-	// TODO(jmorganca): split this up into different functions
-	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(nv) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
+func (c *Context) SampleTokenGreedy(logits []float32) int {
+	candidates := (*C.struct_llama_token_data)(C.malloc(C.size_t(len(logits)) * C.size_t(unsafe.Sizeof(C.struct_llama_token_data{}))))
 	defer C.free(unsafe.Pointer(candidates))
 
-	// get most recent logits
-	logits := C.llama_get_logits_ith(c.c, C.int(i))
-	for i := 0; i < int(nv); i++ {
+	for i, logit := range logits {
 		ptr := (*C.struct_llama_token_data)(unsafe.Pointer(uintptr(unsafe.Pointer(candidates)) + uintptr(i)*unsafe.Sizeof(C.struct_llama_token_data{})))
 		ptr.id = C.int(i)
-		ptr.logit = unsafe.Slice(logits, nv)[i]
+		ptr.logit = C.float(logit)
 		ptr.p = 0.0
 	}
 
 	return int(C.llama_sample_token_greedy(c.c, &C.llama_token_data_array{
 		data:   candidates,
-		size:   C.size_t(nv),
+		size:   C.size_t(len(logits)),
 		sorted: C.bool(false),
 	}))
 }
@@ -155,6 +153,8 @@ func (b *Batch) NumTokens() int {
 	return int(b.c.n_tokens)
 }
 
+// Add adds a token to the batch with the given position for the given
+// sequence ids, and optionally instructs to include logits.
 func (b *Batch) Add(token int, pos int, seqIds []int, logits bool) {
 	unsafe.Slice(b.c.token, 512)[b.c.n_tokens] = C.llama_token(token)
 	unsafe.Slice(b.c.pos, 512)[b.c.n_tokens] = C.llama_pos(pos)
@@ -179,12 +179,6 @@ func (b *Batch) Free() {
 	C.llama_batch_free(b.c)
 }
 
-// LLAMA_API struct llama_batch llama_batch_get_one(
-//
-//		llama_token * tokens,
-//			int32_t   n_tokens,
-//		  llama_pos   pos_0,
-//	   llama_seq_id   seq_id);
 func BatchGetOne(tokens []int, pos0 int, seqId int) Batch {
 	return Batch{c: C.llama_batch_get_one((*C.int)(unsafe.Pointer(&tokens[0])), C.int32_t(len(tokens)), C.int(pos0), C.int(seqId))}
 }

+ 0 - 14
llama/llava/README.md

@@ -1,14 +0,0 @@
-# `llava`
-
-Demo app for running Llava and other clip-based vision models.
-
-```
-ollama pull llava
-```
-
-```
-go run -x . \
-    -model ~/.ollama/models/blobs/sha256-170370233dd5c5415250a2ecd5c71586352850729062ccef1496385647293868 \
-    -projector ~/.ollama/models/blobs/sha256-72d6f08a42f656d36b356dbe0920675899a99ce21192fd66266fb7d82ed07539 \
-    -image ./alonso.jpg
-```

+ 0 - 117
llama/llava/main.go

@@ -1,117 +0,0 @@
-package main
-
-import (
-	"flag"
-	"fmt"
-	"io"
-	"log"
-	"os"
-	"strings"
-
-	"github.com/ollama/ollama/llama"
-)
-
-func main() {
-	mp := flag.String("model", "", "Path to model binary file")
-	pp := flag.String("projector", "", "Path to projector binary file")
-	image := flag.String("image", "", "Path to image file")
-	prompt := flag.String("prompt", " [INST] What is in the picture? <image> [/INST]", "Prompt including <image> tag")
-	flag.Parse()
-
-	// load the model
-	llama.BackendInit()
-	params := llama.NewModelParams()
-	model := llama.LoadModelFromFile(*mp, params)
-	ctxParams := llama.NewContextParams()
-
-	// language model context
-	lc := llama.NewContextWithModel(model, ctxParams)
-
-	// clip context
-	clipCtx := llama.NewClipContext(*pp)
-
-	// open image file
-	file, err := os.Open(*image)
-	if err != nil {
-		panic(err)
-	}
-	defer file.Close()
-
-	data, err := io.ReadAll(file)
-	if err != nil {
-		log.Fatal(err)
-	}
-
-	embedding := llama.NewLlavaImageEmbed(clipCtx, data)
-
-	parts := strings.Split(*prompt, "<image>")
-	if len(parts) != 2 {
-		panic("prompt must contain exactly one <image>")
-	}
-
-	err = eval(lc, parts[0], embedding, parts[1])
-	if err != nil {
-		panic(err)
-	}
-}
-
-func eval(lc *llama.Context, before string, embedding *llama.LlavaImageEmbed, after string) error {
-	beforeTokens, err := lc.Model().Tokenize(before, 2048, true, true)
-	if err != nil {
-		return err
-	}
-
-	afterTokens, err := lc.Model().Tokenize(after, 2048, true, true)
-	if err != nil {
-		return err
-	}
-
-	// eval before
-	batch := llama.NewBatch(512, 0, 1)
-
-	var nPast int
-
-	// prompt eval
-	for _, t := range beforeTokens {
-		batch.Add(t, nPast, []int{0}, true)
-		nPast++
-	}
-
-	err = lc.Decode(batch)
-	if err != nil {
-		return err
-	}
-
-	// batch.Clear()
-
-	llama.LlavaEvalImageEmbed(lc, embedding, 512, &nPast)
-
-	batch = llama.NewBatch(512, 0, 1)
-	for _, t := range afterTokens {
-		batch.Add(t, nPast, []int{0}, true)
-	}
-
-	// main loop
-	for n := nPast; n < 4096; n++ {
-		err = lc.Decode(batch)
-		if err != nil {
-			panic("Failed to decode")
-		}
-
-		// sample a token
-		token := lc.SampleTokenGreedy(batch)
-
-		// if it's an end of sequence token, break
-		if lc.Model().TokenIsEog(token) {
-			break
-		}
-
-		// print the token
-		str := lc.Model().TokenToPiece(token)
-		fmt.Print(str)
-		batch.Clear()
-		batch.Add(token, n, []int{0}, true)
-	}
-
-	return nil
-}

+ 2 - 0
llama/runner/README.md

@@ -1,5 +1,7 @@
 # `runner`
 
+A subprocess runner for loading a model and running inference via a small http web server.
+
 ```
 ./runner -model <model binary>
 ```

+ 8 - 2
llama/runner/runner.go

@@ -9,8 +9,10 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
+	"strconv"
 	"sync"
 
+	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/llama"
 )
 
@@ -131,7 +133,8 @@ func (s *Server) run(ctx context.Context) {
 				// sample a token
 				// TODO: sample based on the sequence
 				fmt.Println("Sampling token", i, ibatch[i])
-				token := s.lc.SampleTokenGreedy(batch, ibatch[i])
+				logits := s.lc.GetLogitsIth(ibatch[i])
+				token := s.lc.SampleTokenGreedy(logits)
 
 				// if it's an end of sequence token, break
 				// TODO: just end this sequence
@@ -155,6 +158,8 @@ func (s *Server) run(ctx context.Context) {
 type Request struct {
 	Prompt string   `json:"prompt"`
 	Images []string `json:"images"`
+
+	api.Options
 }
 
 type Response struct {
@@ -208,6 +213,7 @@ func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
 	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
+	port := flag.Int("port", 8080, "Port to expose the server on")
 	flag.Parse()
 
 	// load the model
@@ -241,7 +247,7 @@ func main() {
 	ctx, cancel := context.WithCancel(context.Background())
 	go server.run(ctx)
 
-	addr := "127.0.0.1:8080"
+	addr := "127.0.0.1:" + strconv.Itoa(*port)
 	listener, err := net.Listen("tcp", addr)
 	if err != nil {
 		fmt.Println("Listen error:", err)