Explorar o código

Don't clamp ctx size in `PredictServerFit` (#4317)

* dont clamp ctx size in `PredictServerFit`

* minimum 4 context

* remove context warning
Jeffrey Morgan hai 11 meses
pai
achega
bb6fd02298
Modificáronse 3 ficheiros con 6 adicións e 19 borrados
  1. 1 10
      llm/memory.go
  2. 1 9
      llm/server.go
  3. 4 0
      server/sched.go

+ 1 - 10
llm/memory.go

@@ -12,17 +12,8 @@ import (
 
 // This algorithm looks for a complete fit to determine if we need to unload other models
 func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
-	var estimatedVRAM uint64
-	if opts.NumCtx > int(ggml.KV().ContextLength()) {
-		slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
-		opts.NumCtx = int(ggml.KV().ContextLength())
-	}
-
-	if opts.NumCtx < 4 {
-		opts.NumCtx = 4
-	}
-
 	// Split up the GPUs by type and try them
+	var estimatedVRAM uint64
 	for _, gpus := range allGpus.ByLibrary() {
 		var layerCount int
 		layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)

+ 1 - 9
llm/server.go

@@ -77,15 +77,7 @@ func LoadModel(model string) (*GGML, error) {
 // The gpu list must be a single family.
 func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
 	var err error
-	if opts.NumCtx > int(ggml.KV().ContextLength()) {
-		slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
-	}
-
-	if opts.NumCtx < 4 {
-		opts.NumCtx = 4
-	}
-
-	cpuRunner := ""
+	var cpuRunner string
 	var estimatedVRAM uint64
 	var estimatedTotal uint64
 	var systemMemory uint64

+ 4 - 0
server/sched.go

@@ -61,6 +61,10 @@ func InitScheduler(ctx context.Context) *Scheduler {
 // context must be canceled to decrement ref count and release the runner
 func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration time.Duration) (chan *runnerRef, chan error) {
 	// allocate a large enough kv cache for all parallel requests
+	if opts.NumCtx < 4 {
+		opts.NumCtx = 4
+	}
+
 	opts.NumCtx = opts.NumCtx * envconfig.NumParallel
 
 	req := &LlmRequest{