Procházet zdrojové kódy

model and file type as strings

Michael Yang před 1 rokem
rodič
revize
a894cc792d
4 změnil soubory, kde provedl 133 přidání a 48 odebrání
  1. 28 25
      llm/ggml.go
  2. 93 11
      llm/llama.go
  3. 6 6
      llm/llm.go
  4. 6 6
      server/images.go

+ 28 - 25
llm/ggml.go

@@ -9,8 +9,6 @@ import (
 
 type ModelFamily string
 
-const ModelFamilyLlama ModelFamily = "llama"
-
 type ModelType uint32
 
 const (
@@ -21,32 +19,37 @@ const (
 	ModelType65B ModelType = 80
 )
 
-type FileType uint32
+func (mt ModelType) String() string {
+	switch mt {
+	case ModelType3B:
+		return "3B"
+	case ModelType7B:
+		return "7B"
+	case ModelType13B:
+		return "13B"
+	case ModelType30B:
+		return "30B"
+	case ModelType65B:
+		return "65B"
+	default:
+		return "Unknown"
+	}
+}
 
-const (
-	FileTypeF32 FileType = iota
-	FileTypeF16
-	FileTypeQ4_0
-	FileTypeQ4_1
-	FileTypeQ4_1_F16
-	FileTypeQ8_0 = iota + 2
-	FileTypeQ5_0
-	FileTypeQ5_1
-	FileTypeQ2_K
-	FileTypeQ3_K
-	FileTypeQ4_K
-	FileTypeQ5_K
-	FileTypeQ6_K
-)
+type FileType interface {
+	String() string
+}
 
 type GGML struct {
-	ModelFamily
-	ModelType
-
 	magic uint32
 	container
+	model
+}
 
-	llamaHyperparameters
+type model interface {
+	ModelFamily() ModelFamily
+	ModelType() ModelType
+	FileType() FileType
 }
 
 type container interface {
@@ -166,14 +169,14 @@ func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
 	// different model types may have different layouts for hyperparameters
 	switch hint {
 	case ModelFamilyLlama:
-		binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
+		var llama llamaModel
+		binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
+		ggml.model = &llama
 		// TODO: sanity check hyperparameters
 	default:
 		return nil, fmt.Errorf("unsupported model type: %s", hint)
 	}
 
 	// final model type
-	ggml.ModelFamily = hint
-	ggml.ModelType = ModelType(ggml.NumLayer)
 	return &ggml, nil
 }

+ 93 - 11
llm/llama.go

@@ -106,19 +106,22 @@ import (
 //go:embed ggml-metal.metal
 var fs embed.FS
 
-type llama struct {
-	params *C.struct_llama_context_params
-	model  *C.struct_llama_model
-	ctx    *C.struct_llama_context
+const ModelFamilyLlama ModelFamily = "llama"
 
-	last   []C.llama_token
-	embd   []C.llama_token
-	cursor int
+type llamaModel struct {
+	hyperparameters llamaHyperparameters
+}
 
-	mu sync.Mutex
-	gc bool
+func (llm *llamaModel) ModelFamily() ModelFamily {
+	return ModelFamilyLlama
+}
 
-	api.Options
+func (llm *llamaModel) ModelType() ModelType {
+	return ModelType30B
+}
+
+func (llm *llamaModel) FileType() FileType {
+	return llm.hyperparameters.FileType
 }
 
 type llamaHyperparameters struct {
@@ -133,8 +136,87 @@ type llamaHyperparameters struct {
 	// NumLayer is the number of layers in the model.
 	NumLayer uint32
 	NumRot   uint32
+
 	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
-	FileType
+	FileType llamaFileType
+}
+
+type llamaFileType uint32
+
+const (
+	llamaFileTypeF32 llamaFileType = iota
+	llamaFileTypeF16
+	llamaFileTypeQ4_0
+	llamaFileTypeQ4_1
+	llamaFileTypeQ4_1_F16
+	llamaFileTypeQ8_0 llamaFileType = iota + 2
+	llamaFileTypeQ5_0
+	llamaFileTypeQ5_1
+	llamaFileTypeQ2_K
+	llamaFileTypeQ3_K_S
+	llamaFileTypeQ3_K_M
+	llamaFileTypeQ3_K_L
+	llamaFileTypeQ4_K_S
+	llamaFileTypeQ4_K_M
+	llamaFileTypeQ5_K_S
+	llamaFileTypeQ5_K_M
+	llamaFileTypeQ6_K
+)
+
+func (ft llamaFileType) String() string {
+	switch ft {
+	case llamaFileTypeF32:
+		return "F32"
+	case llamaFileTypeF16:
+		return "F16"
+	case llamaFileTypeQ4_0:
+		return "Q4_0"
+	case llamaFileTypeQ4_1:
+		return "Q4_1"
+	case llamaFileTypeQ4_1_F16:
+		return "Q4_1_F16"
+	case llamaFileTypeQ8_0:
+		return "Q8_0"
+	case llamaFileTypeQ5_0:
+		return "Q5_0"
+	case llamaFileTypeQ5_1:
+		return "Q5_1"
+	case llamaFileTypeQ2_K:
+		return "Q2_K"
+	case llamaFileTypeQ3_K_S:
+		return "Q3_K_S"
+	case llamaFileTypeQ3_K_M:
+		return "Q3_K_M"
+	case llamaFileTypeQ3_K_L:
+		return "Q3_K_L"
+	case llamaFileTypeQ4_K_S:
+		return "Q4_K_S"
+	case llamaFileTypeQ4_K_M:
+		return "Q4_K_M"
+	case llamaFileTypeQ5_K_S:
+		return "Q5_K_S"
+	case llamaFileTypeQ5_K_M:
+		return "Q5_K_M"
+	case llamaFileTypeQ6_K:
+		return "Q6_K"
+	default:
+		return "Unknown"
+	}
+}
+
+type llama struct {
+	params *C.struct_llama_context_params
+	model  *C.struct_llama_model
+	ctx    *C.struct_llama_context
+
+	last   []C.llama_token
+	embd   []C.llama_token
+	cursor int
+
+	mu sync.Mutex
+	gc bool
+
+	api.Options
 }
 
 func newLlama(model string, adapters []string, opts api.Options) (*llama, error) {

+ 6 - 6
llm/llm.go

@@ -35,10 +35,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		return nil, err
 	}
 
-	switch ggml.FileType {
-	case FileTypeF32, FileTypeF16, FileTypeQ5_0, FileTypeQ5_1, FileTypeQ8_0:
+	switch ggml.FileType().String() {
+	case "F32", "F16", "Q5_0", "Q5_1", "Q8_0":
 		if opts.NumGPU != 0 {
-			// Q5_0, Q5_1, and Q8_0 do not support Metal API and will
+			// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
 			// cause the runner to segmentation fault so disable GPU
 			log.Printf("WARNING: GPU disabled for F32, F16, Q5_0, Q5_1, and Q8_0")
 			opts.NumGPU = 0
@@ -46,7 +46,7 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 	}
 
 	totalResidentMemory := memory.TotalMemory()
-	switch ggml.ModelType {
+	switch ggml.ModelType() {
 	case ModelType3B, ModelType7B:
 		if totalResidentMemory < 8*1024*1024 {
 			return nil, fmt.Errorf("model requires at least 8GB of memory")
@@ -65,10 +65,10 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		}
 	}
 
-	switch ggml.ModelFamily {
+	switch ggml.ModelFamily() {
 	case ModelFamilyLlama:
 		return newLlama(model, adapters, opts)
 	default:
-		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
+		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
 	}
 }

+ 6 - 6
server/images.go

@@ -105,9 +105,9 @@ type LayerReader struct {
 
 type ConfigV2 struct {
 	ModelFamily llm.ModelFamily `json:"model_family"`
-	ModelType   llm.ModelType   `json:"model_type"`
-	FileType    llm.FileType    `json:"file_type"`
-	RootFS      RootFS          `json:"rootfs"`
+	ModelType   string      `json:"model_type"`
+	FileType    string      `json:"file_type"`
+	RootFS      RootFS      `json:"rootfs"`
 
 	// required by spec
 	Architecture string `json:"architecture"`
@@ -308,9 +308,9 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 						return err
 					}
 
-					config.ModelFamily = ggml.ModelFamily
-					config.ModelType = ggml.ModelType
-					config.FileType = ggml.FileType
+					config.ModelFamily = ggml.ModelFamily()
+					config.ModelType = ggml.ModelType().String()
+					config.FileType = ggml.FileType().String()
 
 					// reset the file
 					file.Seek(0, io.SeekStart)