Browse Source

fix model type for 70b

Michael Yang 1 year ago
parent
commit
0c5a454361
2 changed files with 15 additions and 3 deletions
  1. 6 0
      llm/gguf.go
  2. 9 3
      server/images.go

+ 6 - 0
llm/gguf.go

@@ -99,6 +99,12 @@ func (llm *ggufModel) ModelType() string {
 	switch llm.ModelFamily() {
 	case "llama":
 		if blocks, ok := llm.kv["llama.block_count"].(uint32); ok {
+			heads, headsOK := llm.kv["llama.head_count"].(uint32)
+			headKVs, headsKVsOK := llm.kv["llama.head_count_kv"].(uint32)
+			if headsOK && headsKVsOK && heads/headKVs == 8 {
+				return "70B"
+			}
+
 			return llamaModelType(blocks)
 		}
 	case "falcon":

+ 9 - 3
server/images.go

@@ -498,6 +498,12 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			}
 		}
 
+		if config.ModelType == "65B" {
+			if numGQA, ok := formattedParams["num_gqa"].(int); ok && numGQA == 8 {
+				config.ModelType = "70B"
+			}
+		}
+
 		bts, err := json.Marshal(formattedParams)
 		if err != nil {
 			return err
@@ -815,14 +821,14 @@ func formatParams(params map[string][]string) (map[string]interface{}, error) {
 						return nil, fmt.Errorf("invalid float value %s", vals)
 					}
 
-					out[key] = floatVal
+					out[key] = float32(floatVal)
 				case reflect.Int:
-					intVal, err := strconv.ParseInt(vals[0], 10, 0)
+					intVal, err := strconv.ParseInt(vals[0], 10, 64)
 					if err != nil {
 						return nil, fmt.Errorf("invalid int value %s", vals)
 					}
 
-					out[key] = intVal
+					out[key] = int(intVal)
 				case reflect.Bool:
 					boolVal, err := strconv.ParseBool(vals[0])
 					if err != nil {