|
@@ -850,6 +850,7 @@ func (s *Server) loadModel(
|
|
lpath multiLPath,
|
|
lpath multiLPath,
|
|
ppath string,
|
|
ppath string,
|
|
kvSize int,
|
|
kvSize int,
|
|
|
|
+ kvCacheType string,
|
|
flashAttention bool,
|
|
flashAttention bool,
|
|
threads int,
|
|
threads int,
|
|
multiUserCache bool,
|
|
multiUserCache bool,
|
|
@@ -862,7 +863,7 @@ func (s *Server) loadModel(
|
|
panic(err)
|
|
panic(err)
|
|
}
|
|
}
|
|
|
|
|
|
- ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention)
|
|
|
|
|
|
+ ctxParams := llama.NewContextParams(kvSize, s.batchSize*s.parallel, s.parallel, threads, flashAttention, kvCacheType)
|
|
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
|
s.lc, err = llama.NewContextWithModel(s.model, ctxParams)
|
|
if err != nil {
|
|
if err != nil {
|
|
panic(err)
|
|
panic(err)
|
|
@@ -903,6 +904,7 @@ func main() {
|
|
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
|
|
mainGpu := flag.Int("main-gpu", 0, "Main GPU")
|
|
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
|
|
flashAttention := flag.Bool("flash-attn", false, "Enable flash attention")
|
|
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
|
|
kvSize := flag.Int("ctx-size", 2048, "Context (or KV cache) size")
|
|
|
|
+ kvCacheType := flag.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
|
port := flag.Int("port", 8080, "Port to expose the server on")
|
|
port := flag.Int("port", 8080, "Port to expose the server on")
|
|
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
|
threads := flag.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
|
|
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
|
|
verbose := flag.Bool("verbose", false, "verbose output (default: disabled)")
|
|
@@ -970,7 +972,7 @@ func main() {
|
|
}
|
|
}
|
|
|
|
|
|
server.ready.Add(1)
|
|
server.ready.Add(1)
|
|
- go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *flashAttention, *threads, *multiUserCache)
|
|
|
|
|
|
+ go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache)
|
|
|
|
|
|
server.cond = sync.NewCond(&server.mu)
|
|
server.cond = sync.NewCond(&server.mu)
|
|
|
|
|