@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"log"
 	"log/slog"
+	"math"
 	"net"
 	"net/http"
 	"runtime"
@@ -28,6 +29,9 @@ type Sequence struct {
 	// channel to send responses over
 	responses chan string
 
+	// number of tokens to predict
+	numPredict int
+
 	samplingCtx *llama.SamplingContext
 
 	// channel to send back the embedding if embedding only
@@ -38,6 +42,8 @@ type Sequence struct {
 
 	// true if an embedding is to be returned instead of text generation
 	embeddingOnly bool
+
+	doneReason string
 }
 
 // prompt returns true if the prompt is still being processed
@@ -46,11 +52,18 @@ func (s *Sequence) prompt() bool {
 }
 
 func (s *Server) NewSequence(prompt string, stop []string, params *llama.SamplingParams, embedding bool) *Sequence {
-	tokens, err := s.lc.Model().Tokenize(prompt, 2048, false, true)
+	tokens, err := s.lc.Model().Tokenize(prompt, false, true)
 	if err != nil {
 		panic(err)
 	}
 
+	// truncate to last n tokens
+	// TODO: this shouldn't happen and will severely impact generation
+	// quality. instead we should ensure to cut prompt in the API.
+	if len(tokens) > s.numCtx {
+		tokens = tokens[:s.numCtx]
+	}
+
 	var sc *llama.SamplingContext
 	if params != nil {
 		sc = llama.NewSamplingContext(*params)
@@ -83,9 +96,16 @@ type Server struct {
 	// TODO (jmorganca): this can probably be moved into run()
 	seqs []*Sequence
 
+	// context window size
+	numCtx int
+
 	mu sync.Mutex
 
 	cond *sync.Cond
+
+	progress float32
+
+	status string
 }
 
 func (s *Server) allNil() bool {
@@ -183,6 +203,15 @@ func (s *Server) run(ctx context.Context) {
 				continue
 			}
 
+			// we've reached the context limit
+			if seq.nPast > s.numCtx {
+				seq.doneReason = "limit"
+				close(seq.responses)
+				s.lc.KvCacheSeqRm(i, 0, -1)
+				s.seqs[i] = nil
+				continue
+			}
+
 			for j, t := range seq.tokens {
 				// todo: make this n_batch
 				if j > s.batchSize {
@@ -252,6 +281,7 @@ func (s *Server) run(ctx context.Context) {
 				// as it's important for the /api/generate context
 				// seq.responses <- piece
 
+				seq.doneReason = "stop"
 				close(seq.responses)
 				seq.samplingCtx.Free()
 				pieces[i] = []string{}
@@ -273,6 +303,7 @@ func (s *Server) run(ctx context.Context) {
 			}
 
 			s.lc.KvCacheSeqRm(i, 0, -1)
+			seq.doneReason = "stop"
 			close(seq.responses)
 			seq.samplingCtx.Free()
 			pieces[i] = []string{}
@@ -411,6 +442,24 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
 	}
 }
 
+type HealthResponse struct {
+	Status   string  `json:"status"`
+	Progress float32 `json:"progress"`
+}
+
+// TODO (jmorganca): is it safe to do this concurrently with decoding?
+func (s *Server) health(w http.ResponseWriter, r *http.Request) {
+	w.Header().Set("Content-Type", "application/json")
+
+	if err := json.NewEncoder(w).Encode(&HealthResponse{
+		Status:   s.status,
+		Progress: s.progress,
+	}); err != nil {
+		log.Println("Failed to encode result:", err)
+		return
+	}
+}
+
 func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
@@ -425,36 +474,31 @@ func main() {
 	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 	flag.Parse()
 
+	server := &Server{
+		numCtx:    *numCtx,
+		batchSize: *batchSize,
+		parallel:  *parallel,
+		seqs:      make([]*Sequence, *parallel),
+		status:    "loading",
+	}
+
 	// load the model
 	llama.BackendInit()
-	params := llama.NewModelParams(*nGpuLayers, *mainGpu)
-	model := llama.LoadModelFromFile(*mpath, params)
+	params := llama.NewModelParams(*nGpuLayers, *mainGpu, func(progress float32) {
+		slog.Info("Loading model", "progress %", math.Round(float64(progress*100)))
+		server.progress = progress
+	})
+	server.model = llama.LoadModelFromFile(*mpath, params)
 
 	if *lpath != "" {
-		model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
+		server.model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
 	}
 
 	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
-	lc := llama.NewContextWithModel(model, ctxParams)
-	if lc == nil {
-		panic("Failed to create context")
-	}
+	server.lc = llama.NewContextWithModel(server.model, ctxParams)
 
-	var cc *llama.ClipContext
 	if *ppath != "" {
-		cc = llama.NewClipContext(*ppath)
-		if cc == nil {
-			panic("Failed to create clip context")
-		}
-	}
-
-	server := &Server{
-		model:     model,
-		lc:        lc,
-		cc:        cc,
-		batchSize: *batchSize,
-		parallel:  *parallel,
-		seqs:      make([]*Sequence, *parallel),
+		server.cc = llama.NewClipContext(*ppath)
 	}
 
 	server.cond = sync.NewCond(&server.mu)
@@ -473,11 +517,14 @@ func main() {
 	mux := http.NewServeMux()
 	mux.HandleFunc("/embeddings", server.embeddings)
 	mux.HandleFunc("/completion", server.completion)
+	mux.HandleFunc("/health", server.health)
 
 	httpServer := http.Server{
 		Handler: mux,
 	}
 
+	server.status = "ready"
+
 	log.Println("Server listening on", addr)
 	if err := httpServer.Serve(listener); err != nil {
 		log.Fatal("server error:", err)
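
Note (not part of the diff): the new /health endpoint returns the fields defined in HealthResponse above, i.e. "status" ("loading" until the model is loaded, then "ready") and "progress" (a 0..1 load fraction). As an illustration only, a minimal Go client that polls it until the runner reports "ready" could look like the sketch below; the listen address is hypothetical and must match whatever address the server actually binds to.

package main

import (
	"encoding/json"
	"fmt"
	"net/http"
	"time"
)

// healthResponse mirrors the HealthResponse struct added in the diff.
type healthResponse struct {
	Status   string  `json:"status"`
	Progress float32 `json:"progress"`
}

func main() {
	// hypothetical address; use the address printed by the runner at startup
	const url = "http://127.0.0.1:8080/health"

	for {
		resp, err := http.Get(url)
		if err != nil {
			// server not up yet; retry
			time.Sleep(time.Second)
			continue
		}

		var hr healthResponse
		err = json.NewDecoder(resp.Body).Decode(&hr)
		resp.Body.Close()
		if err == nil {
			fmt.Printf("status=%s progress=%.0f%%\n", hr.Status, hr.Progress*100)
			if hr.Status == "ready" {
				return
			}
		}

		time.Sleep(time.Second)
	}
}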