@@ -20,7 +20,7 @@ type LlamaModel struct {
 	ModelData
 }
 
-func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
+func llamaTorchLayerHandler(w io.Writer, r torchWriterTo) error {
 	slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
 
 	data := r.storage.(*pytorch.HalfStorage).Data
@@ -105,9 +105,16 @@ func (m *LlamaModel) GetTensors() error {
 		matches := re.FindAllStringSubmatch(l.Name, -1)
 		if len(matches) > 0 {
 			slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
-			wt := l.WriterTo.(torchWriterTo)
-			wt.handler = llamaLayerHandler
-			l.WriterTo = wt
+			switch l.WriterTo.(type) {
+			case torchWriterTo:
+				wt := l.WriterTo.(torchWriterTo)
+				wt.handler = llamaTorchLayerHandler
+				l.WriterTo = wt
+			case safetensorWriterTo:
+				wt := l.WriterTo.(safetensorWriterTo)
+				wt.handler = mistralLayerHandler
+				l.WriterTo = wt
+			}
 		}
 		m.Tensors = append(m.Tensors, l)
 	}
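
The type switch in the second hunk dispatches on the concrete type behind l.WriterTo so that pytorch checkpoints get llamaTorchLayerHandler while safetensors files get mistralLayerHandler. Go's bound form of the type switch, `switch wt := x.(type)`, would make the per-case re-assertion unnecessary. Below is a minimal, self-contained sketch of that idiom; the struct definitions, handler bodies, and setHandler function are hypothetical stand-ins, not the converter's real types, which carry more fields.

package main

import (
	"fmt"
	"io"
	"os"
)

// Hypothetical stand-ins for the converter's writer types; the real
// torchWriterTo and safetensorWriterTo are defined elsewhere in the
// convert package.
type torchWriterTo struct{ handler func(io.Writer) error }

func (t torchWriterTo) WriteTo(w io.Writer) (int64, error) { return 0, t.handler(w) }

type safetensorWriterTo struct{ handler func(io.Writer) error }

func (s safetensorWriterTo) WriteTo(w io.Writer) (int64, error) { return 0, s.handler(w) }

// setHandler mirrors the diff's dispatch, but binds the concrete value in
// the switch itself: inside each case, wt already has the concrete type,
// so no second type assertion is needed.
func setHandler(writerTo io.WriterTo) io.WriterTo {
	switch wt := writerTo.(type) {
	case torchWriterTo:
		wt.handler = func(io.Writer) error { fmt.Println("repack torch layer"); return nil }
		return wt
	case safetensorWriterTo:
		wt.handler = func(io.Writer) error { fmt.Println("repack safetensor layer"); return nil }
		return wt
	default:
		return writerTo // unknown writer type: leave it untouched
	}
}

func main() {
	w := setHandler(torchWriterTo{})
	w.WriteTo(os.Stdout) // prints "repack torch layer"
}

Either form works; the diff's version (bare type switch plus an assertion per case) is valid Go, and the bound form shown here is simply the more compact spelling of the same dispatch.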