瀏覽代碼

Merge pull request #1971 from jmorganca/mxyng/max-context-length

add max context length check
Michael Yang 1 年之前
父節點
當前提交
356d178f6e
共有 3 個文件被更改,包括 15 次插入0 次删除
  1. 1 0
      llm/ggml.go
  2. 9 0
      llm/gguf.go
  3. 5 0
      llm/llm.go

+ 1 - 0
llm/ggml.go

@@ -83,6 +83,7 @@ type model interface {
 	NumEmbed() uint32
 	NumEmbed() uint32
 	NumHead() uint32
 	NumHead() uint32
 	NumHeadKv() uint32
 	NumHeadKv() uint32
+	NumCtx() uint32
 }
 }
 
 
 type container interface {
 type container interface {

+ 9 - 0
llm/gguf.go

@@ -308,6 +308,15 @@ func (llm *ggufModel) NumHeadKv() uint32 {
 	return value.(uint32)
 	return value.(uint32)
 }
 }
 
 
+func (llm *ggufModel) NumCtx() uint32 {
+	value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
+	if !exists {
+		return 0
+	}
+
+	return value.(uint32)
+}
+
 func (llm *ggufModel) NumGQA() uint32 {
 func (llm *ggufModel) NumGQA() uint32 {
 	numHeadKv := llm.NumHeadKv()
 	numHeadKv := llm.NumHeadKv()
 	if numHeadKv == 0 {
 	if numHeadKv == 0 {

+ 5 - 0
llm/llm.go

@@ -35,6 +35,11 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	if opts.NumCtx > int(ggml.NumCtx()) {
+		log.Printf("WARNING: requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx())
+		opts.NumCtx = int(ggml.NumCtx())
+	}
+
 	if opts.NumCtx < 4 {
 	if opts.NumCtx < 4 {
 		opts.NumCtx = 4
 		opts.NumCtx = 4
 	}
 	}