Browse Source

add safetensors version

Patrick Devine 1 year ago
parent
commit
4730762e5c
2 changed files with 20 additions and 4 deletions
  1. 11 4
      convert/llama.go
  2. 9 0
      convert/safetensors.go

+ 11 - 4
convert/llama.go

@@ -20,7 +20,7 @@ type LlamaModel struct {
 	ModelData
 }
 
-func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
+func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
 	slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
 
 	data := r.storage.(*pytorch.HalfStorage).Data
@@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
 			slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
-			wt := l.WriterTo.(torchWriterTo)
-			wt.handler = llamaLayerHandler
-			l.WriterTo = wt
+			switch l.WriterTo.(type) {
+			case torchWriterTo:
+				wt := l.WriterTo.(torchWriterTo)
+				wt.handler = llamaTorchLayerHandler
+				l.WriterTo = wt
+			case safetensorWriterTo:
+				wt := l.WriterTo.(safetensorWriterTo)
+				wt.handler = mistralLayerHandler
+				l.WriterTo = wt
+			}
 		}
 		m.Tensors = append(m.Tensors, l)
 	}

+ 9 - 0
convert/safetensors.go

@@ -281,6 +281,15 @@ func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (M
 		return nil, fmt.Errorf("No architecture specified to convert")
 	case 1:
 		switch params.Architectures[0] {
+		case "LlamaForCausalLM":
+			return &LlamaModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+					Format: m,
+				},
+			}, nil
 		case "MistralForCausalLM":
 			return &MistralModel{
 				ModelData{