@@ -1,19 +1,24 @@
 package main
 
 import (
+	"encoding/base64"
 	"encoding/json"
 	"flag"
 	"fmt"
 	"log"
+	"log/slog"
 	"net"
 	"net/http"
+	"regexp"
+	"strconv"
 	"sync"
 
 	"github.com/ollama/ollama/llama"
 )
 
 type Request struct {
-	Prompt string `json:"prompt"`
+	Prompt string   `json:"prompt"`
+	Images []string `json:"images"`
 }
 
 type Response struct {
@@ -23,6 +28,7 @@ type Response struct {
 type Server struct {
 	model *llama.Model
 	lc    *llama.Context
+	cc    *llama.ClipContext
 }
 
 var mu sync.Mutex
@@ -34,6 +40,10 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 		return
 	}
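 
+	// hold the lock for the whole request; the runner processes one request at a time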
+	mu.Lock()
+	defer mu.Unlock()
+
 	// Set the headers to indicate streaming
 	w.Header().Set("Content-Type", "application/json")
 	w.Header().Set("Transfer-Encoding", "chunked")
@@ -41,30 +51,76 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 
 	enc := json.NewEncoder(w)
 
-	// main loop
-	tokens, err := s.model.Tokenize(request.Prompt, 2048, true, true)
-	if err != nil {
-		panic(err)
+	// create embeddings for each image
+	var embeddings []*llama.LlavaImageEmbed
+	if s.cc != nil {
+		for _, img := range request.Images {
+			data, err := base64.StdEncoding.DecodeString(img)
+			if err != nil {
+				http.Error(w, "Failed to decode image", http.StatusBadRequest)
+				return
+			}
+
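+			// encode the raw image bytes into an embedding with the CLIP projector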
+			embd := llama.NewLlavaImageEmbed(s.cc, data)
+			embeddings = append(embeddings, embd)
+		}
 	}
 
-	batch := llama.NewBatch(512, 0, 1)
+	var nPast int
 
-	// prompt eval
-	for i, t := range tokens {
-		batch.Add(t, llama.Pos(i), []llama.SeqId{0}, true)
-	}
+	// eval the prompt
+	re := regexp.MustCompile(`\[\s*img-(\d+)\s*\]`)
+	matches := re.FindAllStringSubmatchIndex(request.Prompt, -1)
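+	// each [img-N] tag marks where request.Images[N] should be evaluated in the prompt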
 
-	// main loop
-	for n := batch.NumTokens(); n < 2048; n++ {
-		mu.Lock()
-		err = s.lc.Decode(batch)
+	// eval each chunk including images
+	pos := 0
+	for _, match := range matches {
+		part := request.Prompt[pos:match[0]]
+		fmt.Println("Text part:", part)
+
+		// eval text before image
+		err := s.evalText(part, &nPast)
 		if err != nil {
-			panic("Failed to decode")
+			log.Println("Failed to eval text:", err)
+			return
 		}
 
+		// eval image
+		imgIndexStr := request.Prompt[match[2]:match[3]]
+		imgIndex, err := strconv.Atoi(imgIndexStr)
+		if err != nil {
+			slog.Warn("Failed to parse image index", "index", imgIndexStr)
+			continue
+		}
+
+		fmt.Println("Tag index:", imgIndex)
+		if imgIndex < len(embeddings) {
+			slog.Info("evaluating image", "index", imgIndex)
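+			// evaluate the image embedding into the context in chunks of up to 512 tokens, advancing nPast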
+			llama.LlavaEvalImageEmbed(s.lc, embeddings[imgIndex], 512, &nPast)
+		}
+
+		pos = match[1]
+	}
+
+	// eval remaining text
+	if pos < len(request.Prompt) {
+		if err := s.evalText(request.Prompt[pos:], &nPast); err != nil {
+			log.Println("Failed to eval text:", err)
+			return
+		}
+	}
+
+	batch := llama.NewBatch(512, 0, 1)
+	defer batch.Free()
+
+	// main loop
+	for n := nPast; n < 2048; n++ {
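+		// generate one token per iteration until EOS or the 2048-token context limit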
 		// sample a token
 		token := s.lc.SampleTokenGreedy(batch)
-		mu.Unlock()
 
 		// if it's an end of sequence token, break
 		if s.model.TokenIsEog(token) {
@@ -81,27 +137,45 @@ func (s *Server) stream(w http.ResponseWriter, r *http.Request) {
 		w.(http.Flusher).Flush()
 
 		batch.Clear()
-		batch.Add(token, llama.Pos(n), []llama.SeqId{0}, true)
+		batch.Add(token, n, []int{0}, true)
+
+		err := s.lc.Decode(batch)
+		if err != nil {
+			panic("Failed to decode")
+		}
 	}
+
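+	// clear the KV cache so the next request starts from a clean context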
+	s.lc.KvCacheClear()
 }
 
 func main() {
-	mp := flag.String("model", "", "Path to model binary file")
+	mpath := flag.String("model", "", "Path to model binary file")
+	ppath := flag.String("projector", "", "Path to projector binary file")
 	flag.Parse()
 
 	// load the model
 	llama.BackendInit()
 	params := llama.NewModelParams()
-	model := llama.LoadModelFromFile(*mp, params)
+	model := llama.LoadModelFromFile(*mpath, params)
 	ctxParams := llama.NewContextParams()
 	lc := llama.NewContextWithModel(model, ctxParams)
 	if lc == nil {
 		panic("Failed to create context")
 	}
 
+	var cc *llama.ClipContext
+	if *ppath != "" {
+		cc = llama.NewClipContext(*ppath)
+		if cc == nil {
+			panic("Failed to create clip context")
+		}
+	}
+
 	server := &Server{
 		model: model,
 		lc:    lc,
+		cc:    cc,
 	}
 
 	addr := "127.0.0.1:8080"
@@ -121,3 +195,28 @@ func main() {
 		log.Fatal("server error:", err)
 	}
 }
+
+func (s *Server) evalText(text string, nPast *int) error {
+	// use a scratch batch for this chunk of text
+	batch := llama.NewBatch(512, 0, 1)
+	defer batch.Free()
+
+	tokens, err := s.lc.Model().Tokenize(text, 2048, true, true)
+	if err != nil {
+		return fmt.Errorf("tokenize failed: %w", err)
+	}
+
+	// prompt eval
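+	// add each token at position *nPast in sequence 0; the final argument requests logits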
+	for _, t := range tokens {
+		batch.Add(t, *nPast, []int{0}, true)
+		*nPast++
+	}
+
+	err = s.lc.Decode(batch)
+	if err != nil {
+		return fmt.Errorf("decode failed: %w", err)
+	}
+
+	return nil
+}