Browse Source

fix ModelType()

Michael Yang 1 year ago
parent
commit
5ca05c2e88
1 changed files with 15 additions and 1 deletions
  1. 15 1
      llm/llama.go

+ 15 - 1
llm/llama.go

@@ -117,7 +117,21 @@ func (llm *llamaModel) ModelFamily() ModelFamily {
 }
 
 func (llm *llamaModel) ModelType() ModelType {
-	return ModelType30B
+	switch llm.hyperparameters.NumLayer {
+	case 26:
+		return ModelType3B
+	case 32:
+		return ModelType7B
+	case 40:
+		return ModelType13B
+	case 60:
+		return ModelType30B
+	case 80:
+		return ModelType65B
+	}
+
+	// TODO: find a better default
+	return ModelType7B
 }
 
 func (llm *llamaModel) FileType() FileType {