Ver Fonte

Merge pull request #519 from jmorganca/mxyng/decode

Mxyng/decode
Michael Yang há 1 ano atrás
pai
commit
949553db23
5 ficheiros alterados com 134 adições e 157 exclusões
  1. 61 38
      llm/ggml.go
  2. 23 19
      llm/gguf.go
  3. 18 80
      llm/llama.go
  4. 15 9
      llm/llm.go
  5. 17 11
      server/images.go

+ 61 - 38
llm/ggml.go

@@ -8,54 +8,77 @@ import (
 	"sync"
 )
 
-type ModelFamily string
-
-const ModelFamilyUnknown ModelFamily = "unknown"
-
-type ModelType uint32
+type GGML struct {
+	magic uint32
+	container
+	model
+}
 
 const (
-	ModelType3B  ModelType = 26
-	ModelType7B  ModelType = 32
-	ModelType13B ModelType = 40
-	ModelType34B ModelType = 48
-	ModelType30B ModelType = 60
-	ModelType65B ModelType = 80
+	fileTypeF32 uint32 = iota
+	fileTypeF16
+	fileTypeQ4_0
+	fileTypeQ4_1
+	fileTypeQ4_1_F16
+	fileTypeQ8_0 uint32 = iota + 2
+	fileTypeQ5_0
+	fileTypeQ5_1
+	fileTypeQ2_K
+	fileTypeQ3_K_S
+	fileTypeQ3_K_M
+	fileTypeQ3_K_L
+	fileTypeQ4_K_S
+	fileTypeQ4_K_M
+	fileTypeQ5_K_S
+	fileTypeQ5_K_M
+	fileTypeQ6_K
 )
 
-func (mt ModelType) String() string {
-	switch mt {
-	case ModelType3B:
-		return "3B"
-	case ModelType7B:
-		return "7B"
-	case ModelType13B:
-		return "13B"
-	case ModelType34B:
-		return "34B"
-	case ModelType30B:
-		return "30B"
-	case ModelType65B:
-		return "65B"
+func fileType(fileType uint32) string {
+	switch fileType {
+	case fileTypeF32:
+		return "F32"
+	case fileTypeF16:
+		return "F16"
+	case fileTypeQ4_0:
+		return "Q4_0"
+	case fileTypeQ4_1:
+		return "Q4_1"
+	case fileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case fileTypeQ8_0:
+		return "Q8_0"
+	case fileTypeQ5_0:
+		return "Q5_0"
+	case fileTypeQ5_1:
+		return "Q5_1"
+	case fileTypeQ2_K:
+		return "Q2_K"
+	case fileTypeQ3_K_S:
+		return "Q3_K_S"
+	case fileTypeQ3_K_M:
+		return "Q3_K_M"
+	case fileTypeQ3_K_L:
+		return "Q3_K_L"
+	case fileTypeQ4_K_S:
+		return "Q4_K_S"
+	case fileTypeQ4_K_M:
+		return "Q4_K_M"
+	case fileTypeQ5_K_S:
+		return "Q5_K_S"
+	case fileTypeQ5_K_M:
+		return "Q5_K_M"
+	case fileTypeQ6_K:
+		return "Q6_K"
 	default:
 		return "Unknown"
 	}
 }
 
-type FileType interface {
-	String() string
-}
-
-type GGML struct {
-	magic uint32
-	container
-	model
-}
-
 type model interface {
-	ModelFamily() ModelFamily
-	ModelType() ModelType
-	FileType() FileType
+	ModelFamily() string
+	ModelType() string
+	FileType() string
 }
 
 type container interface {

+ 23 - 19
llm/gguf.go

@@ -6,7 +6,6 @@ import (
 	"errors"
 	"fmt"
 	"io"
-	"log"
 	"path"
 	"sync"
 )
@@ -87,38 +86,43 @@ func (llm *ggufModel) NumKV() uint64 {
 	return llm.V2.NumKV
 }
 
-func (llm *ggufModel) ModelFamily() ModelFamily {
+func (llm *ggufModel) ModelFamily() string {
 	t, ok := llm.kv["general.architecture"].(string)
 	if ok {
-		return ModelFamily(t)
+		return t
 	}
 
-	log.Printf("unknown model family: %T", t)
-	return ModelFamilyUnknown
+	return "unknown"
 }
 
-func (llm *ggufModel) ModelType() ModelType {
+func (llm *ggufModel) ModelType() string {
 	switch llm.ModelFamily() {
-	case ModelFamilyLlama:
-		blocks, ok := llm.kv["llama.block_count"].(uint32)
-		if ok {
-			return ModelType(blocks)
+	case "llama":
+		if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
+			heads, headsOK := llm.kv["llama.head_count"].(uint32)
+			headKVs, headsKVsOK := llm.kv["llama.head_count_kv"].(uint32)
+			if headsOK && headsKVsOK && heads/headKVs == 8 {
+				return "70B"
+			}
+
+			return llamaModelType(blocks)
+		}
+	case "falcon":
+		if blocks, ok := llm.kv["falcon.block_count"].(uint32); ok {
+			return falconModelType(blocks)
 		}
 	}
 
-	return ModelType7B
+	return "Unknown"
 }
 
-func (llm *ggufModel) FileType() FileType {
-	switch llm.ModelFamily() {
-	case ModelFamilyLlama:
-		t, ok := llm.kv["general.file_type"].(uint32)
-		if ok {
-			return llamaFileType(t)
-		}
+func (llm *ggufModel) FileType() string {
+	t, ok := llm.kv["general.file_type"].(uint32)
+	if ok {
+		return fileType(t)
 	}
 
-	return llamaFileTypeF16
+	return "Unknown"
 }
 
 func (llm *ggufModel) Decode(r io.Reader) error {

+ 18 - 80
llm/llama.go

@@ -95,38 +95,39 @@ func chooseRunner(gpuPath, cpuPath string) string {
 	return runPath
 }
 
-const ModelFamilyLlama ModelFamily = "llama"
-
 type llamaModel struct {
 	hyperparameters llamaHyperparameters
 }
 
-func (llm *llamaModel) ModelFamily() ModelFamily {
-	return ModelFamilyLlama
+func (llm *llamaModel) ModelFamily() string {
+	return "llama"
 }
 
-func (llm *llamaModel) ModelType() ModelType {
-	switch llm.hyperparameters.NumLayer {
+func llamaModelType(numLayer uint32) string {
+	switch numLayer {
 	case 26:
-		return ModelType3B
+		return "3B"
 	case 32:
-		return ModelType7B
+		return "7B"
 	case 40:
-		return ModelType13B
+		return "13B"
 	case 48:
-		return ModelType34B
+		return "34B"
 	case 60:
-		return ModelType30B
+		return "30B"
 	case 80:
-		return ModelType65B
+		return "65B"
+	default:
+		return "Unknown"
 	}
+}
 
-	// TODO: find a better default
-	return ModelType7B
+func (llm *llamaModel) ModelType() string {
+	return llamaModelType(llm.hyperparameters.NumLayer)
 }
 
-func (llm *llamaModel) FileType() FileType {
-	return llm.hyperparameters.FileType
+func (llm *llamaModel) FileType() string {
+	return fileType(llm.hyperparameters.FileType)
 }
 
 type llamaHyperparameters struct {
@@ -143,70 +144,7 @@ type llamaHyperparameters struct {
 	NumRot   uint32
 
 	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
-	FileType llamaFileType
-}
-
-type llamaFileType uint32
-
-const (
-	llamaFileTypeF32 llamaFileType = iota
-	llamaFileTypeF16
-	llamaFileTypeQ4_0
-	llamaFileTypeQ4_1
-	llamaFileTypeQ4_1_F16
-	llamaFileTypeQ8_0 llamaFileType = iota + 2
-	llamaFileTypeQ5_0
-	llamaFileTypeQ5_1
-	llamaFileTypeQ2_K
-	llamaFileTypeQ3_K_S
-	llamaFileTypeQ3_K_M
-	llamaFileTypeQ3_K_L
-	llamaFileTypeQ4_K_S
-	llamaFileTypeQ4_K_M
-	llamaFileTypeQ5_K_S
-	llamaFileTypeQ5_K_M
-	llamaFileTypeQ6_K
-)
-
-func (ft llamaFileType) String() string {
-	switch ft {
-	case llamaFileTypeF32:
-		return "F32"
-	case llamaFileTypeF16:
-		return "F16"
-	case llamaFileTypeQ4_0:
-		return "Q4_0"
-	case llamaFileTypeQ4_1:
-		return "Q4_1"
-	case llamaFileTypeQ4_1_F16:
-		return "Q4_1_F16"
-	case llamaFileTypeQ8_0:
-		return "Q8_0"
-	case llamaFileTypeQ5_0:
-		return "Q5_0"
-	case llamaFileTypeQ5_1:
-		return "Q5_1"
-	case llamaFileTypeQ2_K:
-		return "Q2_K"
-	case llamaFileTypeQ3_K_S:
-		return "Q3_K_S"
-	case llamaFileTypeQ3_K_M:
-		return "Q3_K_M"
-	case llamaFileTypeQ3_K_L:
-		return "Q3_K_L"
-	case llamaFileTypeQ4_K_S:
-		return "Q4_K_S"
-	case llamaFileTypeQ4_K_M:
-		return "Q4_K_M"
-	case llamaFileTypeQ5_K_S:
-		return "Q5_K_S"
-	case llamaFileTypeQ5_K_M:
-		return "Q5_K_M"
-	case llamaFileTypeQ6_K:
-		return "Q6_K"
-	default:
-		return "Unknown"
-	}
+	FileType uint32
 }
 
 type Running struct {

+ 15 - 9
llm/llm.go

@@ -37,7 +37,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		return nil, err
 	}
 
-	switch ggml.FileType().String() {
+	switch ggml.FileType() {
 	case "Q8_0":
 		if ggml.Name() != "gguf" && opts.NumGPU != 0 {
 			// GGML Q8_0 do not support Metal API and will
@@ -56,30 +56,36 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 
 	totalResidentMemory := memory.TotalMemory()
 	switch ggml.ModelType() {
-	case ModelType3B, ModelType7B:
-		if ggml.FileType().String() == "F16" && totalResidentMemory < 16*1024*1024 {
+	case "3B", "7B":
+		if ggml.FileType() == "F16" && totalResidentMemory < 16*1024*1024 {
 			return nil, fmt.Errorf("F16 model requires at least 16GB of memory")
 		} else if totalResidentMemory < 8*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 8GB of memory")
 		}
-	case ModelType13B:
-		if ggml.FileType().String() == "F16" && totalResidentMemory < 32*1024*1024 {
+	case "13B":
+		if ggml.FileType() == "F16" && totalResidentMemory < 32*1024*1024 {
 			return nil, fmt.Errorf("F16 model requires at least 32GB of memory")
 		} else if totalResidentMemory < 16*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 16GB of memory")
 		}
-	case ModelType30B, ModelType34B:
-		if ggml.FileType().String() == "F16" && totalResidentMemory < 64*1024*1024 {
+	case "30B", "34B", "40B":
+		if ggml.FileType() == "F16" && totalResidentMemory < 64*1024*1024 {
 			return nil, fmt.Errorf("F16 model requires at least 64GB of memory")
 		} else if totalResidentMemory < 32*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 32GB of memory")
 		}
-	case ModelType65B:
-		if ggml.FileType().String() == "F16" && totalResidentMemory < 128*1024*1024 {
+	case "65B", "70B":
+		if ggml.FileType() == "F16" && totalResidentMemory < 128*1024*1024 {
 			return nil, fmt.Errorf("F16 model requires at least 128GB of memory")
 		} else if totalResidentMemory < 64*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 64GB of memory")
 		}
+	case "180B":
+		if ggml.FileType() == "F16" && totalResidentMemory < 512*1024*1024 {
+			return nil, fmt.Errorf("F16 model requires at least 512GB of memory")
+		} else if totalResidentMemory < 128*1024*1024 {
+			return nil, fmt.Errorf("model requires at least 128GB of memory")
+		}
 	}
 
 	switch ggml.Name() {

+ 17 - 11
server/images.go

@@ -114,11 +114,11 @@ type LayerReader struct {
 }
 
 type ConfigV2 struct {
-	ModelFamily llm.ModelFamily `json:"model_family"`
-	ModelType   string          `json:"model_type"`
-	ModelFormat string          `json:"model_format"`
-	FileType    string          `json:"file_type"`
-	RootFS      RootFS          `json:"rootfs"`
+	ModelFormat string `json:"model_format"`
+	ModelFamily string `json:"model_family"`
+	ModelType   string `json:"model_type"`
+	FileType    string `json:"file_type"`
+	RootFS      RootFS `json:"rootfs"`
 
 	// required by spec
 	Architecture string `json:"architecture"`
@@ -357,10 +357,10 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 						return err
 					}
 
-					config.ModelFamily = ggml.ModelFamily()
-					config.ModelType = ggml.ModelType().String()
 					config.ModelFormat = ggml.Name()
-					config.FileType = ggml.FileType().String()
+					config.ModelFamily = ggml.ModelFamily()
+					config.ModelType = ggml.ModelType()
+					config.FileType = ggml.FileType()
 
 					// reset the file
 					file.Seek(0, io.SeekStart)
@@ -498,6 +498,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			}
 		}
 
+		if config.ModelType == "65B" {
+			if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
+				config.ModelType = "70B"
+			}
+		}
+
 		bts, err := json.Marshal(formattedParams)
 		if err != nil {
 			return err
@@ -815,14 +821,14 @@ func formatParams(params map[string][]string) (map[string]interface{}, error) {
 						return nil, fmt.Errorf("invalid float value %s", vals)
 					}
 
-					out[key] = floatVal
+					out[key] = float32(floatVal)
 				case reflect.Int:
-					intVal, err := strconv.ParseInt(vals[0], 10, 0)
+					intVal, err := strconv.ParseInt(vals[0], 10, 64)
 					if err != nil {
 						return nil, fmt.Errorf("invalid int value %s", vals)
 					}
 
-					out[key] = intVal
+					out[key] = int(intVal)
 				case reflect.Bool:
 					boolVal, err := strconv.ParseBool(vals[0])
 					if err != nil {