Przeglądaj źródła

model: Load tensors behind an interface

Currently, if a model uses an interface for its data structures (as mllama
does) then the tensor data in the structs implementing that interface will
not get loaded.
Jesse Gross 3 miesięcy temu
rodzic
commit
d650ad398f
2 zmienionych plików z 31 dodań i 16 usunięć
  1. 29 14
      model/model.go
  2. 2 2
      model/model_test.go

+ 29 - 14
model/model.go

@@ -147,15 +147,12 @@ func New(s string) (Model, error) {
 	}
 
 	v := reflect.ValueOf(m)
-	v.Elem().Set(populateFields(b, v))
+	v.Elem().Set(populateFields(b, v.Elem()))
 	return m, nil
 }
 
 func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 	t := v.Type()
-	if t.Kind() == reflect.Pointer {
-		t, v = t.Elem(), v.Elem()
-	}
 
 	if t.Kind() == reflect.Struct {
 		allNil := true
@@ -205,18 +202,16 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 						break
 					}
 				}
-			} else if tt.Kind() == reflect.Pointer {
-				vvv := vv.Elem()
-				if vv.IsNil() {
-					vvv = reflect.New(tt.Elem())
-				}
-
-				if f := populateFields(b, vvv, tagsCopy...); f.CanAddr() {
-					vv.Set(f.Addr())
-				}
+			} else if tt.Kind() == reflect.Pointer || tt.Kind() == reflect.Interface {
+				setPointer(b, vv, tagsCopy)
 			} else if tt.Kind() == reflect.Slice || tt.Kind() == reflect.Array {
 				for i := range vv.Len() {
-					vv.Index(i).Set(populateFields(b, vv.Index(i), append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
+					vvv := vv.Index(i)
+					if vvv.Kind() == reflect.Pointer || vvv.Kind() == reflect.Interface {
+						setPointer(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)}))
+					} else {
+						vvv.Set(populateFields(b, vvv, append(tagsCopy, Tag{Name: strconv.Itoa(i)})...))
+					}
 				}
 			}
 
@@ -233,6 +228,26 @@ func populateFields(b ml.Backend, v reflect.Value, tags ...Tag) reflect.Value {
 	return v
 }
 
+func setPointer(b ml.Backend, v reflect.Value, tags []Tag) {
+	vv := v
+	if v.Kind() == reflect.Interface {
+		if v.IsNil() {
+			return
+		}
+
+		vv = vv.Elem()
+	}
+
+	vv = vv.Elem()
+	if v.IsNil() {
+		vv = reflect.New(v.Type().Elem()).Elem()
+	}
+
+	if f := populateFields(b, vv, tags...); f.CanAddr() {
+		v.Set(f.Addr())
+	}
+}
+
 type Tag struct {
 	Name      string
 	Alternate []string

+ 2 - 2
model/model_test.go

@@ -90,7 +90,7 @@ func TestPopulateFields(t *testing.T) {
 			"output_norm.weight",
 			"output.weight",
 		},
-	}, v))
+	}, v.Elem()))
 
 	if diff := cmp.Diff(fakeModel{
 		Input:      &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},
@@ -125,7 +125,7 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
 		names: []string{
 			"input.weight",
 		},
-	}, v))
+	}, v.Elem()))
 
 	if diff := cmp.Diff(fakeModel{
 		Input:  &nn.Embedding{Weight: &fakeTensor{Name: "input.weight"}},