Browse Source

add model IDs (#439)

Patrick Devine 1 year ago
parent
commit
8bbff2df98
4 changed files with 27 additions and 19 deletions
  1. 1 0
      api/types.go
  2. 3 3
      cmd/cmd.go
  3. 21 15
      server/images.go
  4. 2 1
      server/routes.go

+ 1 - 0
api/types.go

@@ -96,6 +96,7 @@ type ListResponseModel struct {
 	Name       string    `json:"name"`
 	ModifiedAt time.Time `json:"modified_at"`
 	Size       int       `json:"size"`
+	Digest     string    `json:"digest"`
 }
 
 type TokenResponse struct {

+ 3 - 3
cmd/cmd.go

@@ -196,12 +196,12 @@ func ListHandler(cmd *cobra.Command, args []string) error {
 
 	for _, m := range models.Models {
 		if len(args) == 0 || strings.HasPrefix(m.Name, args[0]) {
-			data = append(data, []string{m.Name, humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
+			data = append(data, []string{m.Name, m.Digest[:12], humanize.Bytes(uint64(m.Size)), format.HumanTime(m.ModifiedAt, "Never")})
 		}
 	}
 
 	table := tablewriter.NewWriter(os.Stdout)
-	table.SetHeader([]string{"NAME", "SIZE", "MODIFIED"})
+	table.SetHeader([]string{"NAME", "ID", "SIZE", "MODIFIED"})
 	table.SetHeaderAlignment(tablewriter.ALIGN_LEFT)
 	table.SetAlignment(tablewriter.ALIGN_LEFT)
 	table.SetHeaderLine(false)
@@ -527,7 +527,7 @@ func generateInteractive(cmd *cobra.Command, model string) error {
 					return err
 				}
 
-				manifest, err := server.GetManifest(mp)
+				manifest, _, err := server.GetManifest(mp)
 				if err != nil {
 					fmt.Println("error: couldn't get a manifest for this model")
 					continue

+ 21 - 15
server/images.go

@@ -5,6 +5,7 @@ import (
 	"bytes"
 	"context"
 	"crypto/sha256"
+	"encoding/hex"
 	"encoding/json"
 	"errors"
 	"fmt"
@@ -44,6 +45,7 @@ type Model struct {
 	Template     string
 	System       string
 	Digest       string
+	ConfigDigest string
 	Options      map[string]interface{}
 	Embeddings   []vector.Embedding
 }
@@ -131,41 +133,45 @@ func (m *ManifestV2) GetTotalSize() int {
 	return total
 }
 
-func GetManifest(mp ModelPath) (*ManifestV2, error) {
+func GetManifest(mp ModelPath) (*ManifestV2, string, error) {
 	fp, err := mp.GetManifestPath(false)
 	if err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
 	if _, err = os.Stat(fp); err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
 	var manifest *ManifestV2
 
 	bts, err := os.ReadFile(fp)
 	if err != nil {
-		return nil, fmt.Errorf("couldn't open file '%s'", fp)
+		return nil, "", fmt.Errorf("couldn't open file '%s'", fp)
 	}
 
+	shaSum := sha256.Sum256(bts)
+	shaStr := hex.EncodeToString(shaSum[:])
+
 	if err := json.Unmarshal(bts, &manifest); err != nil {
-		return nil, err
+		return nil, "", err
 	}
 
-	return manifest, nil
+	return manifest, shaStr, nil
 }
 
 func GetModel(name string) (*Model, error) {
 	mp := ParseModelPath(name)
-	manifest, err := GetManifest(mp)
+	manifest, digest, err := GetManifest(mp)
 	if err != nil {
 		return nil, err
 	}
 
 	model := &Model{
-		Name:     mp.GetFullTagname(),
-		Digest:   manifest.Config.Digest,
-		Template: "{{ .Prompt }}",
+		Name:         mp.GetFullTagname(),
+		Digest:       digest,
+		ConfigDigest: manifest.Config.Digest,
+		Template:     "{{ .Prompt }}",
 	}
 
 	for _, layer := range manifest.Layers {
@@ -277,7 +283,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			embed.model = c.Args
 
 			mp := ParseModelPath(c.Args)
-			mf, err := GetManifest(mp)
+			mf, _, err := GetManifest(mp)
 			if err != nil {
 				modelFile, err := filenameWithPath(path, c.Args)
 				if err != nil {
@@ -290,7 +296,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 						if err := PullModel(ctx, c.Args, &RegistryOptions{}, fn); err != nil {
 							return err
 						}
-						mf, err = GetManifest(mp)
+						mf, _, err = GetManifest(mp)
 						if err != nil {
 							return fmt.Errorf("failed to open file after pull: %v", err)
 						}
@@ -839,7 +845,7 @@ func CopyModel(src, dest string) error {
 
 func DeleteModel(name string) error {
 	mp := ParseModelPath(name)
-	manifest, err := GetManifest(mp)
+	manifest, _, err := GetManifest(mp)
 	if err != nil {
 		return err
 	}
@@ -872,7 +878,7 @@ func DeleteModel(name string) error {
 			}
 
 			// save (i.e. delete from the deleteMap) any files used in other manifests
-			manifest, err := GetManifest(fmp)
+			manifest, _, err := GetManifest(fmp)
 			if err != nil {
 				log.Printf("skipping file: %s", fp)
 				return nil
@@ -924,7 +930,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
 		return fmt.Errorf("insecure protocol http")
 	}
 
-	manifest, err := GetManifest(mp)
+	manifest, _, err := GetManifest(mp)
 	if err != nil {
 		fn(api.ProgressResponse{Status: "couldn't retrieve manifest"})
 		return err

+ 2 - 1
server/routes.go

@@ -373,7 +373,7 @@ func ListModelsHandler(c *gin.Context) {
 			tag := path[:slashIndex] + ":" + path[slashIndex+1:]
 
 			mp := ParseModelPath(tag)
-			manifest, err := GetManifest(mp)
+			manifest, digest, err := GetManifest(mp)
 			if err != nil {
 				log.Printf("skipping file: %s", fp)
 				return nil
@@ -381,6 +381,7 @@ func ListModelsHandler(c *gin.Context) {
 			model := api.ListResponseModel{
 				Name:       mp.GetShortTagname(),
 				Size:       manifest.GetTotalSize(),
+				Digest:     digest,
 				ModifiedAt: fi.ModTime(),
 			}
 			models = append(models, model)