Explorar el Código

server/.../safetensors: fix offsets and include all model parts (#9427)

Also, require the -as flag to be set when importing a model. This
prevents the confusing error message "invalid name".

Also, allow short names to be used when importing a model and
auto-complete the name with the default mask.
Blake Mizerany hace 2 meses
padre
commit
eed11ded30

+ 13 - 4
server/internal/client/ollama/registry.go

@@ -147,14 +147,23 @@ func (e *Error) UnmarshalJSON(b []byte) error {
 	return nil
 	return nil
 }
 }
 
 
-var defaultName = func() names.Name {
-	n := names.Parse("registry.ollama.ai/library/_:latest")
+const DefaultMask = "registry.ollama.ai/library/_:latest"
+
+var defaultMask = func() names.Name {
+	n := names.Parse(DefaultMask)
 	if !n.IsFullyQualified() {
 	if !n.IsFullyQualified() {
-		panic("default name is not fully qualified")
+		panic("default mask is not fully qualified")
 	}
 	}
 	return n
 	return n
 }()
 }()
 
 
+// CompleteName returns a fully qualified name by merging the given name with
+// the default mask. If the name is already fully qualified, it is returned
+// unchanged.
+func CompleteName(name string) string {
+	return names.Merge(names.Parse(name), defaultMask).String()
+}
+
 // Registry is a client for performing push and pull operations against an
 // Registry is a client for performing push and pull operations against an
 // Ollama registry.
 // Ollama registry.
 type Registry struct {
 type Registry struct {
@@ -249,7 +258,7 @@ type PushParams struct {
 //
 //
 // The scheme is returned as provided by [names.ParseExtended].
 // The scheme is returned as provided by [names.ParseExtended].
 func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
 func parseName(s, mask string) (scheme string, n names.Name, d blob.Digest, err error) {
-	maskName := defaultName
+	maskName := defaultMask
 	if mask != "" {
 	if mask != "" {
 		maskName = names.Parse(mask)
 		maskName = names.Parse(mask)
 		if !maskName.IsFullyQualified() {
 		if !maskName.IsFullyQualified() {

+ 6 - 2
server/internal/cmd/opp/internal/safetensors/safetensors.go

@@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
 		return nil, err
 		return nil, err
 	}
 	}
 
 
+	endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
+
 	// TODO(bmizerany): do something with metadata? This could be another
 	// TODO(bmizerany): do something with metadata? This could be another
 	// header read if needed. We also need to figure out if the metadata is
 	// header read if needed. We also need to figure out if the metadata is
 	// present in only one .safetensors file or if each file may have their
 	// present in only one .safetensors file or if each file may have their
@@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
 
 
 	tt := make([]*Tensor, 0, len(raws))
 	tt := make([]*Tensor, 0, len(raws))
 	for name, raw := range raws {
 	for name, raw := range raws {
-		if !strings.HasPrefix(name, "model.layer") {
+		if name == "__metadata__" {
+			// TODO(bmizerany): do something with metadata?
 			continue
 			continue
 		}
 		}
 		var v struct {
 		var v struct {
@@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
 
 
 		// TODO(bmizerany): after collecting, validate all offsets make
 		// TODO(bmizerany): after collecting, validate all offsets make
 		// tensors contiguous?
 		// tensors contiguous?
-		begin, end := v.Offsets[0], v.Offsets[1]
+		begin := endOfHeader + v.Offsets[0]
+		end := endOfHeader + v.Offsets[1]
 		if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
 		if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
 			return nil, err
 			return nil, err
 		}
 		}

+ 7 - 1
server/internal/cmd/opp/opp.go

@@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
 		flag.PrintDefaults()
 		flag.PrintDefaults()
 	}
 	}
 	flag.Parse(args)
 	flag.Parse(args)
+	if *flagAs == "" {
+		return fmt.Errorf("missing -as flag")
+	}
+	as := ollama.CompleteName(*flagAs)
 
 
 	dir := cmp.Or(flag.Arg(0), ".")
 	dir := cmp.Or(flag.Arg(0), ".")
 	fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
 	fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
@@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
 			if err != nil {
 			if err != nil {
 				return err
 				return err
 			}
 			}
-			return c.Link(*flagAs, d)
+			return c.Link(as, d)
 		}()
 		}()
 	}()
 	}()
 
 
@@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
 			writeProgress()
 			writeProgress()
 		case err := <-done:
 		case err := <-done:
 			writeProgress()
 			writeProgress()
+			fmt.Println()
+			fmt.Println("Successfully imported", as)
 			return err
 			return err
 		}
 		}
 	}
 	}