@@ -9,6 +9,7 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
+	"runtime"
 	"strconv"
 	"strings"
 	"sync"
@@ -73,6 +74,8 @@ type Server struct {
 	lc *llama.Context
 	cc *llama.ClipContext
 
+	batchSize int
+
 	// parallel is the number of parallel requests to handle
 	parallel int
 
@@ -154,7 +157,7 @@ func truncateStop(pieces []string, stop string) []string {
 }
 
 func (s *Server) run(ctx context.Context) {
-	batch := llama.NewBatch(512, 0, s.parallel)
+	batch := llama.NewBatch(s.batchSize, 0, s.parallel)
 	defer batch.Free()
 
 	// build up stop sequences as we recognize them
@@ -182,7 +185,7 @@ func (s *Server) run(ctx context.Context) {
 
 		for j, t := range seq.tokens {
 			// todo: make this n_batch
-			if j > 512 {
+			if j > s.batchSize {
 				break
 			}
 
@@ -207,10 +210,10 @@ func (s *Server) run(ctx context.Context) {
 
 			// don't sample prompt processing
 			if seq.prompt() {
-				if len(seq.tokens) < 512 {
+				if len(seq.tokens) < s.batchSize {
 					seq.tokens = []int{}
 				} else {
-					seq.tokens = seq.tokens[512:]
+					seq.tokens = seq.tokens[s.batchSize:]
 				}
 
 				continue
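The three s.batchSize sites above implement chunked prompt processing: each pass of run() feeds at most one batch worth of prompt tokens, then either clears the sequence (prompt fully consumed) or slices off the consumed prefix so the remainder is carried into the next batch. A minimal standalone sketch of that carry-over pattern follows; consumePromptChunk is a hypothetical helper, not the runner's actual code, and it uses a strict batchSize capacity rather than reproducing the `j > s.batchSize` bound above:

package main

import "fmt"

// consumePromptChunk feeds at most batchSize prompt tokens per pass
// and returns the leftover tokens for the next scheduling pass.
func consumePromptChunk(tokens []int, batchSize int) (fed, rest []int) {
	n := len(tokens)
	if n > batchSize {
		n = batchSize
	}
	fed = tokens[:n] // these would be added to the llama batch
	if len(tokens) < batchSize {
		return fed, nil // prompt fully consumed; sampling starts next pass
	}
	return fed, tokens[batchSize:] // remainder waits for the next batch
}

func main() {
	prompt := make([]int, 1300) // e.g. a 1300-token prompt
	for pass := 1; len(prompt) > 0; pass++ {
		var fed []int
		fed, prompt = consumePromptChunk(prompt, 512)
		fmt.Printf("pass %d: fed %d tokens, %d remaining\n", pass, len(fed), len(prompt))
	}
}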
@@ -412,14 +415,26 @@ func main() {
 	mpath := flag.String("model", "", "Path to model binary file")
 	ppath := flag.String("projector", "", "Path to projector binary file")
 	parallel := flag.Int("parallel", 1, "Number of sequences to handle simultaneously")
+	batchSize := flag.Int("batch-size", 512, "Batch size")
+	nGpuLayers := flag.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
+	mainGpu := flag.Int("main-gpu", 0, "Main GPU")
+	flashAttention := flag.Bool("flash-attention", false, "Enable flash attention")
+	numCtx := flag.Int("num-ctx", 2048, "Context (or KV cache) size")
+	lpath := flag.String("lora", "", "Path to lora layer file")
 	port := flag.Int("port", 8080, "Port to expose the server on")
+	threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
 	flag.Parse()
 
 	// load the model
 	llama.BackendInit()
-	params := llama.NewModelParams()
+	params := llama.NewModelParams(*nGpuLayers, *mainGpu)
 	model := llama.LoadModelFromFile(*mpath, params)
-	ctxParams := llama.NewContextParams()
+
+	if *lpath != "" {
+		model.ApplyLoraFromFile(*lpath, 1.0, "", *threads)
+	}
+
+	ctxParams := llama.NewContextParams(*numCtx, *threads, *flashAttention)
 	lc := llama.NewContextWithModel(model, ctxParams)
 	if lc == nil {
 		panic("Failed to create context")
@@ -434,11 +449,12 @@ func main() {
 	}
 
 	server := &Server{
-		model:    model,
-		lc:       lc,
-		cc:       cc,
-		parallel: *parallel,
-		seqs:     make([]*Sequence, *parallel),
+		model:     model,
+		lc:        lc,
+		cc:        cc,
+		batchSize: *batchSize,
+		parallel:  *parallel,
+		seqs:      make([]*Sequence, *parallel),
 	}
 
 	server.cond = sync.NewCond(&server.mu)
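With the new flags wired into main(), a launch of the runner might look like the following; the binary name, paths, and values are illustrative only, while the flag names come from the diff above:

	./runner --model model.gguf --projector projector.gguf \
		--batch-size 512 --n-gpu-layers 33 --main-gpu 0 \
		--flash-attention --num-ctx 4096 --lora adapter.gguf \
		--threads 8 --parallel 2 --port 8080

Go's flag package accepts both single- and double-dash forms, and a bare --flash-attention sets the boolean flag to true.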