Parcourir la source

basic distribution w/ push/pull (#78)

* basic distribution w/ push/pull

* add the parser

* add create, pull, and push

* changes to the parser, FROM line, and fix commands

* mkdirp new manifest directories

* make `blobs` directory if it does not exist

* fix go warnings

* add progressbar for model pulls

* move model struct

---------

Co-authored-by: Jeffrey Morgan <jmorganca@gmail.com>
Patrick Devine il y a 1 an
Parent
commit
2fb52261ad
7 fichiers modifiés avec 1154 ajouts et 214 suppressions
  1. 26 0
      api/client.go
  2. 36 9
      api/types.go
  3. 87 13
      cmd/cmd.go
  4. 77 0
      parser/parser.go
  5. 842 0
      server/images.go
  6. 0 128
      server/models.go
  7. 86 64
      server/routes.go

+ 26 - 0
api/client.go

@@ -116,3 +116,29 @@ func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc
 		return fn(resp)
 	})
 }
+
+// PushProgressFunc is the callback Push invokes once per decoded
+// PushProgress message; a non-nil return value is handed back to c.stream.
+type PushProgressFunc func(PushProgress) error
+
+// Push sends req as a POST to /api/push via c.stream. Each chunk passed to
+// the stream handler is decoded as a JSON PushProgress and forwarded to fn.
+// Decode errors and errors returned by fn are returned from the handler.
+func (c *Client) Push(ctx context.Context, req *PushRequest, fn PushProgressFunc) error {
+	handle := func(raw []byte) error {
+		progress := PushProgress{}
+		err := json.Unmarshal(raw, &progress)
+		if err != nil {
+			return err
+		}
+		return fn(progress)
+	}
+
+	return c.stream(ctx, http.MethodPost, "/api/push", req, handle)
+}
+
+// CreateProgressFunc is the callback Create invokes once per decoded
+// CreateProgress message; a non-nil return value is handed back to c.stream.
+type CreateProgressFunc func(CreateProgress) error
+
+// Create sends req as a POST to /api/create via c.stream. Each chunk passed
+// to the stream handler is decoded as a JSON CreateProgress and forwarded to
+// fn. Decode errors and errors returned by fn are returned from the handler.
+func (c *Client) Create(ctx context.Context, req *CreateRequest, fn CreateProgressFunc) error {
+	handle := func(raw []byte) error {
+		progress := CreateProgress{}
+		err := json.Unmarshal(raw, &progress)
+		if err != nil {
+			return err
+		}
+		return fn(progress)
+	}
+
+	return c.stream(ctx, http.MethodPost, "/api/create", req, handle)
+}

+ 36 - 9
api/types.go

@@ -7,22 +7,49 @@ import (
 	"time"
 )
 
+// GenerateRequest is the JSON body of a generate call: the model name, the
+// user prompt, an optional context token list (omitted from JSON when empty),
+// and the embedded Options serialised under the "options" key.
+type GenerateRequest struct {
+	Model   string `json:"model"`
+	Prompt  string `json:"prompt"`
+	Context []int  `json:"context,omitempty"`
+
+	Options `json:"options"`
+}
+
+// CreateRequest is the JSON body sent by Client.Create to /api/create.
+type CreateRequest struct {
+	// Name is the name the new model is created under.
+	Name string `json:"name"`
+	// Path is the Modelfile path exactly as given by the caller.
+	// NOTE(review): presumably resolved on the server side — confirm against
+	// the /api/create handler.
+	Path string `json:"path"`
+}
+
+// CreateProgress is one streamed status message decoded by Client.Create.
+type CreateProgress struct {
+	Status string `json:"status"`
+}
+
+// PullRequest is the JSON body sent by Client.Pull: the model name plus
+// optional registry credentials.
 type PullRequest struct {
-	Model string `json:"model"`
+	Name     string `json:"name"`
+	Username string `json:"username"`
+	Password string `json:"password"`
 }
 
+// PullProgress is one streamed progress message of a pull. Every field except
+// Status is omitted from the JSON when it holds its zero value.
+// NOTE(review): Total and Completed look like byte counts — confirm against
+// the server's pull handler.
 type PullProgress struct {
-	Total     int64   `json:"total"`
-	Completed int64   `json:"completed"`
-	Percent   float64 `json:"percent"`
+	Status    string  `json:"status"`
+	Digest    string  `json:"digest,omitempty"`
+	Total     int     `json:"total,omitempty"`
+	Completed int     `json:"completed,omitempty"`
+	Percent   float64 `json:"percent,omitempty"`
 }
 
-type GenerateRequest struct {
-	Model   string `json:"model"`
-	Prompt  string `json:"prompt"`
-	Context []int  `json:"context,omitempty"`
+// PushRequest is the JSON body sent by Client.Push: the model name plus
+// optional registry credentials.
+type PushRequest struct {
+	Name     string `json:"name"`
+	Username string `json:"username"`
+	Password string `json:"password"`
+}
 
-	Options `json:"options"`
+// PushProgress is one streamed progress message of a push. Every field except
+// Status is omitted from the JSON when it holds its zero value.
+type PushProgress struct {
+	Status    string  `json:"status"`
+	Digest    string  `json:"digest,omitempty"`
+	Total     int     `json:"total,omitempty"`
+	Completed int     `json:"completed,omitempty"`
+	Percent   float64 `json:"percent,omitempty"`
 }
 
 type GenerateResponse struct {

+ 87 - 13
cmd/cmd.go

@@ -30,6 +30,23 @@ func cacheDir() string {
 	return filepath.Join(home, ".ollama")
 }
 
+// create is the handler for the "create" command. It sends the model name
+// (args[0]) and the value of the --file flag to the API and prints every
+// status message the server streams back.
+func create(cmd *cobra.Command, args []string) error {
+	modelfile, _ := cmd.Flags().GetString("file")
+	client := api.NewClient()
+
+	req := api.CreateRequest{Name: args[0], Path: modelfile}
+	printStatus := func(progress api.CreateProgress) error {
+		fmt.Println(progress.Status)
+		return nil
+	}
+
+	return client.Create(context.Background(), &req, printStatus)
+}
+
 func RunRun(cmd *cobra.Command, args []string) error {
 	_, err := os.Stat(args[0])
 	switch {
@@ -51,25 +68,56 @@ func RunRun(cmd *cobra.Command, args []string) error {
 	return RunGenerate(cmd, args)
 }
 
+// push is the handler for the "push" command. It asks the API to push the
+// model named by args[0] and prints every status message streamed back.
+func push(cmd *cobra.Command, args []string) error {
+	client := api.NewClient()
+
+	req := api.PushRequest{Name: args[0]}
+	printStatus := func(progress api.PushProgress) error {
+		fmt.Println(progress.Status)
+		return nil
+	}
+
+	return client.Push(context.Background(), &req, printStatus)
+}
+
+// RunPull is the handler for the "pull" command; it pulls the model named by
+// the first positional argument.
+func RunPull(cmd *cobra.Command, args []string) error {
+	return pull(args[0])
+}
+
 func pull(model string) error {
 	client := api.NewClient()
+
 	var bar *progressbar.ProgressBar
-	return client.Pull(
-		context.Background(),
-		&api.PullRequest{Model: model},
-		func(progress api.PullProgress) error {
-			if bar == nil {
-				if progress.Percent >= 100 {
-					// already downloaded
-					return nil
-				}
 
-				bar = progressbar.DefaultBytes(progress.Total)
+	currentLayer := ""
+	request := api.PullRequest{Name: model}
+	fn := func(resp api.PullProgress) error {
+		if resp.Digest != currentLayer && resp.Digest != "" {
+			if currentLayer != "" {
+				fmt.Println()
 			}
+			currentLayer = resp.Digest
+			layerStr := resp.Digest[7:23] + "..."
+			bar = progressbar.DefaultBytes(
+				int64(resp.Total),
+				"pulling "+layerStr,
+			)
+		} else if resp.Digest == currentLayer && resp.Digest != "" {
+			bar.Set(resp.Completed)
+		} else {
+			currentLayer = ""
+			fmt.Println(resp.Status)
+		}
+		return nil
+	}
 
-			return bar.Set64(progress.Completed)
-		},
-	)
+	if err := client.Pull(context.Background(), &request, fn); err != nil {
+		return err
+	}
+	return nil
 }
 
 func RunGenerate(cmd *cobra.Command, args []string) error {
@@ -215,6 +263,15 @@ func NewCLI() *cobra.Command {
 
 	cobra.EnableCommandSorting = false
 
+	createCmd := &cobra.Command{
+		Use:   "create MODEL",
+		Short: "Create a model from a Modelfile",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  create,
+	}
+
+	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
+
 	runCmd := &cobra.Command{
 		Use:   "run MODEL [PROMPT]",
 		Short: "Run a model",
@@ -231,9 +288,26 @@ func NewCLI() *cobra.Command {
 		RunE:    RunServer,
 	}
 
+	pullCmd := &cobra.Command{
+		Use:   "pull MODEL",
+		Short: "Pull a model from a registry",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  RunPull,
+	}
+
+	pushCmd := &cobra.Command{
+		Use:   "push MODEL",
+		Short: "Push a model to a registry",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  push,
+	}
+
 	rootCmd.AddCommand(
 		serveCmd,
+		createCmd,
 		runCmd,
+		pullCmd,
+		pushCmd,
 	)
 
 	return rootCmd

+ 77 - 0
parser/parser.go

@@ -0,0 +1,77 @@
+package parser
+
+import (
+	"bufio"
+	"fmt"
+	"io"
+	"strings"
+)
+
+// Command is one instruction produced by Parse.
+type Command struct {
+	// Name is "model" for FROM lines, "prompt" for PROMPT lines, and the
+	// parameter name for PARAMETER lines.
+	Name string
+	// Arg is the instruction's argument text.
+	Arg string
+}
+
+// Parse reads a Modelfile from reader and returns its commands in file order.
+//
+// Recognised instructions (matched on the first whitespace-separated word):
+//   - FROM <model>              -> Command{Name: "model", Arg: <model>}
+//   - PROMPT <text...>          -> Command{Name: "prompt"}, words joined by one space
+//   - PROMPT """ ... """        -> Command{Name: "prompt"}; the body is every line
+//     up to a line containing only """, joined with "\n"
+//   - PARAMETER <name> <val...> -> Command{Name: <name>, Arg: values joined by one space}
+//
+// Blank lines and lines starting with any other word are ignored.
+//
+// An error is returned when reading fails, when an instruction is missing its
+// argument, when no FROM line is present, or when a multiline prompt is never
+// closed. On error the returned slice is nil.
+func Parse(reader io.Reader) ([]Command, error) {
+	var commands []Command
+	var foundModel bool
+
+	// multiline is non-nil while inside a PROMPT """ block; body collects the
+	// block's lines so they can be joined without a leading newline.
+	var multiline *Command
+	var body []string
+
+	scanner := bufio.NewScanner(reader)
+	for scanner.Scan() {
+		line := scanner.Text()
+		if multiline != nil {
+			if strings.TrimSpace(line) == `"""` {
+				multiline.Arg = strings.Join(body, "\n")
+				commands = append(commands, *multiline)
+				multiline, body = nil, nil
+			} else {
+				body = append(body, line)
+			}
+			continue
+		}
+
+		fields := strings.Fields(line)
+		if len(fields) == 0 {
+			continue
+		}
+
+		switch fields[0] {
+		case "FROM":
+			// strings.Fields never yields empty strings, so a missing model
+			// shows up as a missing field rather than an empty one.
+			if len(fields) < 2 {
+				return nil, fmt.Errorf("no model specified in FROM line")
+			}
+			commands = append(commands, Command{Name: "model", Arg: fields[1]})
+			foundModel = true
+		case "PROMPT":
+			if len(fields) < 2 {
+				return nil, fmt.Errorf("no prompt specified in PROMPT line")
+			}
+			if fields[1] == `"""` {
+				multiline = &Command{Name: "prompt"}
+				continue
+			}
+			commands = append(commands, Command{Name: "prompt", Arg: strings.Join(fields[1:], " ")})
+		case "PARAMETER":
+			if len(fields) < 2 {
+				return nil, fmt.Errorf("no parameter name specified in PARAMETER line")
+			}
+			commands = append(commands, Command{Name: fields[1], Arg: strings.Join(fields[2:], " ")})
+		}
+	}
+
+	// Report read failures first: a truncated read would otherwise be
+	// misreported as a missing FROM line or an unclosed string.
+	if err := scanner.Err(); err != nil {
+		return nil, err
+	}
+
+	if !foundModel {
+		return nil, fmt.Errorf("no FROM line for the model was specified")
+	}
+
+	if multiline != nil {
+		return nil, fmt.Errorf("unclosed multiline string")
+	}
+	return commands, nil
+}

+ 842 - 0
server/images.go

@@ -0,0 +1,842 @@
+package server
+
+import (
+	"bytes"
+	"crypto/sha256"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"io"
+	"io/ioutil"
+	"log"
+	"net/http"
+	"os"
+	"path"
+	"path/filepath"
+	"strconv"
+	"strings"
+
+	"github.com/jmorganca/ollama/api"
+	"github.com/jmorganca/ollama/parser"
+)
+
+// DefaultRegistry is the base URL of the registry that PushModel and
+// PullModel talk to.
+var DefaultRegistry string = "https://registry.ollama.ai"
+
+// Model is a locally available model as assembled by GetModel from a manifest.
+type Model struct {
+	// Name is the name the manifest was looked up under.
+	Name string `json:"name"`
+	// ModelPath is the blob path of the model layer.
+	ModelPath string
+	// Prompt is the content of the prompt layer.
+	Prompt string
+	// Options holds the model's parameters.
+	Options api.Options
+}
+
+// ManifestV2 is the manifest document stored under the manifests directory
+// and exchanged with the registry: a config descriptor plus the list of layer
+// descriptors.
+type ManifestV2 struct {
+	SchemaVersion int      `json:"schemaVersion"`
+	MediaType     string   `json:"mediaType"`
+	Config        Layer    `json:"config"`
+	Layers        []*Layer `json:"layers"`
+}
+
+// Layer describes one blob: its media type, its digest (also used as the
+// blob's file name), and its size in bytes.
+type Layer struct {
+	MediaType string `json:"mediaType"`
+	Digest    string `json:"digest"`
+	Size      int    `json:"size"`
+}
+
+// LayerWithBuffer is a Layer together with its content held in memory.
+type LayerWithBuffer struct {
+	Layer
+
+	// Buffer holds the layer content; SaveLayers copies it to disk.
+	Buffer *bytes.Buffer
+}
+
+// ConfigV2 is the JSON document that createConfigLayer serialises into the
+// config layer.
+type ConfigV2 struct {
+	Architecture string `json:"architecture"`
+	OS           string `json:"os"`
+	RootFS       RootFS `json:"rootfs"`
+}
+
+// RootFS is the "rootfs" section of ConfigV2; createConfigLayer sets Type to
+// "layers" and DiffIDs to the digests of the model's layers.
+type RootFS struct {
+	Type    string   `json:"type"`
+	DiffIDs []string `json:"diff_ids"`
+}
+
+// GetManifest loads and decodes the manifest stored for name under
+// ~/.ollama/models/manifests.
+//
+// It returns a "couldn't find model" error when the manifest file does not
+// exist, a wrapped error when the file cannot be opened for another reason,
+// and the JSON decoder's error when the content is not a valid manifest.
+func GetManifest(name string) (*ManifestV2, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, err
+	}
+
+	fp := filepath.Join(home, ".ollama/models/manifests", name)
+	f, err := os.Open(fp)
+	if err != nil {
+		if errors.Is(err, os.ErrNotExist) {
+			return nil, fmt.Errorf("couldn't find model '%s'", name)
+		}
+		return nil, fmt.Errorf("couldn't open file '%s': %w", fp, err)
+	}
+	// The manifest is read in full below, so release the handle on return.
+	defer f.Close()
+
+	var manifest *ManifestV2
+	if err := json.NewDecoder(f).Decode(&manifest); err != nil {
+		return nil, err
+	}
+
+	return manifest, nil
+}
+
+// GetModel builds a Model for name from its local manifest. For every layer
+// in the manifest it derives the blob path from the layer digest and then,
+// depending on the layer's media type, records that path as the model file,
+// reads the blob as the prompt text, or sets the options.
+//
+// Errors from the home-directory lookup, from GetManifest, and from reading
+// the prompt blob are returned unchanged.
+func GetModel(name string) (*Model, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, err
+	}
+
+	manifest, err := GetManifest(name)
+	if err != nil {
+		return nil, err
+	}
+
+	model := &Model{
+		Name: name,
+	}
+
+	for _, layer := range manifest.Layers {
+		filename := filepath.Join(home, ".ollama/models/blobs", layer.Digest)
+		switch layer.MediaType {
+		case "application/vnd.ollama.image.model":
+			model.ModelPath = filename
+		case "application/vnd.ollama.image.prompt":
+			data, err := os.ReadFile(filename)
+			if err != nil {
+				return nil, err
+			}
+			model.Prompt = string(data)
+		case "application/vnd.ollama.image.params":
+			// NOTE(review): decoding of the params blob is commented out, so
+			// model.Options is always the zero api.Options here regardless of
+			// what the layer contains.
+			/*
+				f, err = os.Open(filename)
+				if err != nil {
+					return nil, err
+				}
+			*/
+
+			var opts api.Options
+			/*
+				decoder = json.NewDecoder(f)
+				err = decoder.Decode(&opts)
+				if err != nil {
+					return nil, err
+				}
+			*/
+			model.Options = opts
+		}
+	}
+
+	return model, nil
+}
+
+// getAbsPath returns fn as an absolute path. A leading "~/" is expanded to
+// the current user's home directory first; a failure to determine the home
+// directory is logged and returned.
+func getAbsPath(fn string) (string, error) {
+	if !strings.HasPrefix(fn, "~/") {
+		return filepath.Abs(fn)
+	}
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		log.Printf("error getting home directory: %v", err)
+		return "", err
+	}
+
+	// Swap only the leading "~" for the home directory; the rest is kept.
+	expanded := home + strings.TrimPrefix(fn, "~")
+	return filepath.Abs(expanded)
+}
+
+// CreateModel builds a model called name from the Modelfile read from mf and
+// reports progress through fn.
+//
+// It parses the Modelfile, turns its commands into layers (a model layer from
+// an existing manifest or from a file on disk, a prompt layer, and a single
+// JSON parameter layer for all remaining commands), adds a config layer that
+// lists the layer digests, saves the layers as blobs and finally writes the
+// manifest. The first error encountered is returned; most are also reported
+// through fn first.
+func CreateModel(name string, mf io.Reader, fn func(status string)) error {
+	fn("parsing modelfile")
+	commands, err := parser.Parse(mf)
+	if err != nil {
+		fn(fmt.Sprintf("error: %v", err))
+		return err
+	}
+
+	var layers []*LayerWithBuffer
+	param := make(map[string]string)
+
+	for _, c := range commands {
+		log.Printf("[%s] - %s\n", c.Name, c.Arg)
+		switch c.Name {
+		case "model":
+			fn("looking for model")
+			// mf here shadows the io.Reader parameter with the parent manifest.
+			mf, err := GetManifest(c.Arg)
+			if err != nil {
+				// if we couldn't read the manifest, try getting the bin file
+				fp, err := getAbsPath(c.Arg)
+				if err != nil {
+					fn("error determing path. exiting.")
+					return err
+				}
+
+				fn("creating model layer")
+				file, err := os.Open(fp)
+				if err != nil {
+					fn(fmt.Sprintf("couldn't find model '%s'", c.Arg))
+					return fmt.Errorf("failed to open file: %v", err)
+				}
+				// Deferred inside the loop: the file stays open until
+				// CreateModel returns.
+				defer file.Close()
+
+				l, err := CreateLayer(file)
+				if err != nil {
+					fn(fmt.Sprintf("couldn't create model layer: %v", err))
+					return fmt.Errorf("failed to create layer: %v", err)
+				}
+				l.MediaType = "application/vnd.ollama.image.model"
+				layers = append(layers, l)
+			} else {
+				// The FROM argument names an existing model: start from all
+				// of its layers.
+				log.Printf("manifest = %#v", mf)
+				for _, l := range mf.Layers {
+					newLayer, err := GetLayerWithBufferFromLayer(l)
+					if err != nil {
+						fn(fmt.Sprintf("couldn't read layer: %v", err))
+						return err
+					}
+					layers = append(layers, newLayer)
+				}
+			}
+		case "prompt":
+			fn("creating prompt layer")
+			// remove the prompt layer if one exists
+			layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.prompt")
+
+			prompt := strings.NewReader(c.Arg)
+			l, err := CreateLayer(prompt)
+			if err != nil {
+				fn(fmt.Sprintf("couldn't create prompt layer: %v", err))
+				return fmt.Errorf("failed to create layer: %v", err)
+			}
+			l.MediaType = "application/vnd.ollama.image.prompt"
+			layers = append(layers, l)
+		default:
+			// Every other command is collected as a name/value parameter.
+			param[c.Name] = c.Arg
+		}
+	}
+
+	// Create a single layer for the parameters
+	fn("creating parameter layer")
+	if len(param) > 0 {
+		// New parameters replace an inherited params layer as a whole.
+		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
+		paramData, err := paramsToReader(param)
+		if err != nil {
+			return fmt.Errorf("couldn't create params json: %v", err)
+		}
+		l, err := CreateLayer(paramData)
+		if err != nil {
+			return fmt.Errorf("failed to create layer: %v", err)
+		}
+		l.MediaType = "application/vnd.ollama.image.params"
+		layers = append(layers, l)
+	}
+
+	digests, err := getLayerDigests(layers)
+	if err != nil {
+		return err
+	}
+
+	// Collect the manifest's layer list before the config layer is appended,
+	// so the config is not listed among the layers.
+	var manifestLayers []*Layer
+	for _, l := range layers {
+		manifestLayers = append(manifestLayers, &l.Layer)
+	}
+
+	// Create a layer for the config object
+	fn("creating config layer")
+	cfg, err := createConfigLayer(digests)
+	if err != nil {
+		return err
+	}
+	layers = append(layers, cfg)
+
+	err = SaveLayers(layers, fn, false)
+	if err != nil {
+		fn(fmt.Sprintf("error saving layers: %v", err))
+		return err
+	}
+
+	// Create the manifest
+	fn("writing manifest")
+	err = CreateManifest(name, cfg, manifestLayers)
+	if err != nil {
+		fn(fmt.Sprintf("error creating manifest: %v", err))
+		return err
+	}
+
+	fn("success")
+	return nil
+}
+
+// removeLayerFromLayers filters out every layer whose media type equals
+// mediaType. The filtering is done in place: the result shares (and
+// overwrites) the backing array of layers, keeping the original order.
+func removeLayerFromLayers(layers []*LayerWithBuffer, mediaType string) []*LayerWithBuffer {
+	kept := layers[:0]
+	for i := range layers {
+		if layers[i].MediaType == mediaType {
+			continue
+		}
+		kept = append(kept, layers[i])
+	}
+	return kept
+}
+
+// SaveLayers writes each layer's buffer to ~/.ollama/models/blobs/<digest>,
+// creating the blobs directory if needed, and reports what it does through fn.
+//
+// A layer whose blob file already exists is left untouched unless force is
+// true. Writing a layer drains its Buffer. Each blob file is closed as soon
+// as it has been written, and a failure to close (for example a delayed
+// write error) is returned.
+func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		log.Printf("error getting home directory: %v", err)
+		return err
+	}
+
+	dir := filepath.Join(home, ".ollama/models/blobs")
+
+	err = os.MkdirAll(dir, 0o700)
+	if err != nil {
+		return fmt.Errorf("make blobs directory: %w", err)
+	}
+
+	// Write each of the layers to disk
+	for _, layer := range layers {
+		fp := filepath.Join(dir, layer.Digest)
+
+		_, err = os.Stat(fp)
+		if os.IsNotExist(err) || force {
+			fn(fmt.Sprintf("writing layer %s", layer.Digest))
+			if err := writeLayerFile(fp, layer.Buffer); err != nil {
+				return err
+			}
+		} else {
+			fn(fmt.Sprintf("using already created layer %s", layer.Digest))
+		}
+	}
+
+	return nil
+}
+
+// writeLayerFile creates (or truncates) fp, copies r into it and closes the
+// file before returning, so no handle outlives the call.
+func writeLayerFile(fp string, r io.Reader) error {
+	out, err := os.Create(fp)
+	if err != nil {
+		log.Printf("couldn't create %s", fp)
+		return err
+	}
+
+	if _, err := io.Copy(out, r); err != nil {
+		out.Close()
+		return err
+	}
+
+	return out.Close()
+}
+
+// CreateManifest writes the manifest for name to
+// ~/.ollama/models/manifests/<name>. The manifest references cfg as its
+// config and layers as its layer list.
+//
+// The directory that will hold the manifest is created first, so this works
+// on a fresh installation and for names that contain path separators.
+func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		log.Printf("error getting home directory: %v", err)
+		return err
+	}
+
+	manifest := ManifestV2{
+		SchemaVersion: 2,
+		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
+		Config: Layer{
+			MediaType: cfg.MediaType,
+			Size:      cfg.Size,
+			Digest:    cfg.Digest,
+		},
+		Layers: layers,
+	}
+
+	manifestJSON, err := json.Marshal(manifest)
+	if err != nil {
+		return err
+	}
+
+	fp := filepath.Join(home, ".ollama/models/manifests", name)
+
+	// Same directory handling as PullModel uses before writing a manifest.
+	if err := os.MkdirAll(filepath.Dir(fp), 0o700); err != nil {
+		return fmt.Errorf("make manifests directory: %w", err)
+	}
+
+	err = os.WriteFile(fp, manifestJSON, 0644)
+	if err != nil {
+		log.Printf("couldn't write to %s", fp)
+		return err
+	}
+	return nil
+}
+
+// GetLayerWithBufferFromLayer loads the blob referenced by layer.Digest into
+// memory via CreateLayer and returns the result with layer's media type
+// applied. Digest and size of the result are whatever CreateLayer computes
+// from the blob content.
+func GetLayerWithBufferFromLayer(layer *Layer) (*LayerWithBuffer, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, err
+	}
+
+	blob, err := os.Open(filepath.Join(home, ".ollama/models/blobs", layer.Digest))
+	if err != nil {
+		return nil, fmt.Errorf("could not open blob: %w", err)
+	}
+	defer blob.Close()
+
+	loaded, err := CreateLayer(blob)
+	if err != nil {
+		return nil, err
+	}
+
+	loaded.MediaType = layer.MediaType
+	return loaded, nil
+}
+
+// paramsToReader encodes m as JSON indented by two spaces and returns a
+// reader over the encoded text.
+func paramsToReader(m map[string]string) (io.Reader, error) {
+	encoded, err := json.MarshalIndent(m, "", "  ")
+	if err != nil {
+		return nil, err
+	}
+
+	text := string(encoded)
+	return strings.NewReader(text), nil
+}
+
+// getLayerDigests returns the digests of layers in order. It fails as soon
+// as it meets a layer with an empty digest. For an empty input the result is
+// a nil slice.
+func getLayerDigests(layers []*LayerWithBuffer) ([]string, error) {
+	var digests []string
+	for i := range layers {
+		digest := layers[i].Digest
+		if digest == "" {
+			return nil, fmt.Errorf("layer is missing a digest")
+		}
+		digests = append(digests, digest)
+	}
+	return digests, nil
+}
+
+// CreateLayer reads f to the end into memory and returns a layer whose
+// digest and size are computed from that content. The media type is set to
+// the generic tar-layer type; callers overwrite it as needed.
+func CreateLayer(f io.Reader) (*LayerWithBuffer, error) {
+	content := new(bytes.Buffer)
+	if _, err := io.Copy(content, f); err != nil {
+		return nil, err
+	}
+
+	digest, size := GetSHA256Digest(content)
+
+	return &LayerWithBuffer{
+		Layer: Layer{
+			MediaType: "application/vnd.docker.image.rootfs.diff.tar",
+			Digest:    digest,
+			Size:      size,
+		},
+		Buffer: content,
+	}, nil
+}
+
+// PushModel uploads the locally stored model name to DefaultRegistry.
+//
+// name is "repo" or "repo:tag"; the tag defaults to "latest". Every layer of
+// the manifest plus its config blob is uploaded unless the registry already
+// has it, and the manifest itself is pushed last. username and password are
+// passed to every registry request.
+//
+// Progress is reported through fn as (status, digest, total bytes, completed
+// bytes, completed fraction). The fraction is 0 while the total is 0, so fn
+// never receives NaN.
+func PushModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
+	fn("retrieving manifest", "", 0, 0, 0)
+	manifest, err := GetManifest(name)
+	if err != nil {
+		fn("couldn't retrieve manifest", "", 0, 0, 0)
+		return err
+	}
+
+	var repoName string
+	var tag string
+
+	comps := strings.Split(name, ":")
+	switch len(comps) {
+	case 1:
+		repoName, tag = comps[0], "latest"
+	case 2:
+		repoName, tag = comps[0], comps[1]
+	default:
+		return fmt.Errorf("repository name was invalid")
+	}
+
+	var layers []*Layer
+	var total int
+	var completed int
+	for _, layer := range manifest.Layers {
+		layers = append(layers, layer)
+		total += layer.Size
+	}
+	layers = append(layers, &manifest.Config)
+	total += manifest.Config.Size
+
+	// fraction converts before dividing (an integer division would truncate
+	// to 0 or 1, and panic for a zero total) and guards the empty case.
+	fraction := func() float64 {
+		if total == 0 {
+			return 0
+		}
+		return float64(completed) / float64(total)
+	}
+
+	for _, layer := range layers {
+		exists, err := checkBlobExistence(DefaultRegistry, repoName, layer.Digest, username, password)
+		if err != nil {
+			return err
+		}
+
+		if exists {
+			completed += layer.Size
+			fn("using existing layer", layer.Digest, total, completed, fraction())
+			continue
+		}
+
+		fn("starting upload", layer.Digest, total, completed, fraction())
+
+		location, err := startUpload(DefaultRegistry, repoName, username, password)
+		if err != nil {
+			log.Printf("couldn't start upload: %v", err)
+			return err
+		}
+
+		err = uploadBlob(location, layer, username, password)
+		if err != nil {
+			log.Printf("error uploading blob: %v", err)
+			return err
+		}
+		completed += layer.Size
+		fn("upload complete", layer.Digest, total, completed, fraction())
+	}
+
+	fn("pushing manifest", "", total, completed, fraction())
+	url := fmt.Sprintf("%s/v2/%s/manifests/%s", DefaultRegistry, repoName, tag)
+	headers := map[string]string{
+		"Content-Type": "application/vnd.docker.distribution.manifest.v2+json",
+	}
+
+	manifestJSON, err := json.Marshal(manifest)
+	if err != nil {
+		return err
+	}
+
+	resp, err := makeRequest("PUT", url, headers, bytes.NewReader(manifestJSON), username, password)
+	if err != nil {
+		return err
+	}
+	defer resp.Body.Close()
+
+	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
+	if resp.StatusCode != http.StatusCreated {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
+	}
+
+	fn("success", "", total, completed, 1.0)
+
+	return nil
+}
+
+// PullModel downloads the model name from DefaultRegistry and stores its
+// manifest under ~/.ollama/models/manifests/<name>.
+//
+// name is "repo" or "repo:tag"; the tag defaults to "latest". The manifest is
+// fetched first, then every layer and the config blob are downloaded through
+// downloadBlob, and finally the manifest is written to disk. username and
+// password are passed to every registry request.
+//
+// Progress is reported through fn as (status, digest, total bytes, completed
+// bytes, completed fraction). The fraction is 0 while the total is 0, so fn
+// never receives NaN. A manifest fetch failure is wrapped with %w so callers
+// can inspect the cause.
+func PullModel(name, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
+	var repoName string
+	var tag string
+
+	comps := strings.Split(name, ":")
+	switch len(comps) {
+	case 1:
+		repoName, tag = comps[0], "latest"
+	case 2:
+		repoName, tag = comps[0], comps[1]
+	default:
+		return fmt.Errorf("repository name was invalid")
+	}
+
+	fn("pulling manifest", "", 0, 0, 0)
+
+	manifest, err := pullModelManifest(DefaultRegistry, repoName, tag, username, password)
+	if err != nil {
+		return fmt.Errorf("pull model manifest: %w", err)
+	}
+
+	log.Printf("manifest = %#v", manifest)
+
+	var layers []*Layer
+	var total int
+	var completed int
+	for _, layer := range manifest.Layers {
+		layers = append(layers, layer)
+		total += layer.Size
+	}
+	layers = append(layers, &manifest.Config)
+	total += manifest.Config.Size
+
+	fraction := func() float64 {
+		if total == 0 {
+			return 0
+		}
+		return float64(completed) / float64(total)
+	}
+
+	for _, layer := range layers {
+		fn("starting download", layer.Digest, total, completed, fraction())
+		if err := downloadBlob(DefaultRegistry, repoName, layer.Digest, username, password, fn); err != nil {
+			fn(fmt.Sprintf("error downloading: %v", err), layer.Digest, 0, 0, 0)
+			return err
+		}
+		completed += layer.Size
+		fn("download complete", layer.Digest, total, completed, fraction())
+	}
+
+	fn("writing manifest", "", total, completed, 1.0)
+
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return err
+	}
+
+	manifestJSON, err := json.Marshal(manifest)
+	if err != nil {
+		return err
+	}
+
+	fp := filepath.Join(home, ".ollama/models/manifests", name)
+
+	// NOTE(review): path.Dir only understands "/" separators; filepath.Dir
+	// would be the portable choice for this OS path.
+	err = os.MkdirAll(path.Dir(fp), 0o700)
+	if err != nil {
+		return fmt.Errorf("make manifests directory: %w", err)
+	}
+
+	err = os.WriteFile(fp, manifestJSON, 0644)
+	if err != nil {
+		log.Printf("couldn't write to %s", fp)
+		return err
+	}
+
+	fn("success", "", total, completed, 1.0)
+
+	return nil
+}
+
+// pullModelManifest fetches the manifest for repoName:tag from registryURL
+// and decodes it. Any response status other than 200 OK is turned into an
+// error that carries the status code and the response body.
+func pullModelManifest(registryURL, repoName, tag, username, password string) (*ManifestV2, error) {
+	manifestURL := fmt.Sprintf("%s/v2/%s/manifests/%s", registryURL, repoName, tag)
+	accept := map[string]string{
+		"Accept": "application/vnd.docker.distribution.manifest.v2+json",
+	}
+
+	resp, err := makeRequest("GET", manifestURL, accept, nil, username, password)
+	if err != nil {
+		log.Printf("couldn't get manifest: %v", err)
+		return nil, err
+	}
+	defer resp.Body.Close()
+
+	// A successful manifest fetch is answered with 200 OK.
+	if resp.StatusCode != http.StatusOK {
+		body, _ := io.ReadAll(resp.Body)
+		return nil, fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
+	}
+
+	var manifest *ManifestV2
+	decodeErr := json.NewDecoder(resp.Body).Decode(&manifest)
+	if decodeErr != nil {
+		return nil, decodeErr
+	}
+
+	return manifest, nil
+}
+
+// createConfigLayer serialises a ConfigV2 that lists layers as its diff IDs
+// and wraps the JSON in a layer whose digest and size are computed from it.
+func createConfigLayer(layers []string) (*LayerWithBuffer, error) {
+	// TODO change architecture and OS
+	rootfs := RootFS{Type: "layers", DiffIDs: layers}
+	config := ConfigV2{Architecture: "arm64", OS: "linux", RootFS: rootfs}
+
+	encoded, err := json.Marshal(config)
+	if err != nil {
+		return nil, err
+	}
+
+	content := bytes.NewBuffer(encoded)
+	digest, size := GetSHA256Digest(content)
+
+	return &LayerWithBuffer{
+		Layer: Layer{
+			MediaType: "application/vnd.docker.container.image.v1+json",
+			Digest:    digest,
+			Size:      size,
+		},
+		Buffer: content,
+	}, nil
+}
+
+// GetSHA256Digest hashes the unread content of data with SHA-256 and returns
+// the digest formatted as "sha256:<hex>" together with the content length in
+// bytes. The buffer is not consumed.
+func GetSHA256Digest(data *bytes.Buffer) (string, int) {
+	content := data.Bytes()
+	sum := sha256.Sum256(content)
+	digest := "sha256:" + hex.EncodeToString(sum[:])
+	return digest, len(content)
+}
+
+// startUpload opens a blob upload session for repositoryName on registryURL
+// and returns the value of the Location header from the response. Any status
+// other than 202 Accepted, or a missing Location header, is an error.
+func startUpload(registryURL string, repositoryName string, username string, password string) (string, error) {
+	uploadsURL := fmt.Sprintf("%s/v2/%s/blobs/uploads/", registryURL, repositoryName)
+
+	resp, err := makeRequest("POST", uploadsURL, nil, nil, username, password)
+	if err != nil {
+		log.Printf("couldn't start upload: %v", err)
+		return "", err
+	}
+	defer resp.Body.Close()
+
+	// Check for success
+	if resp.StatusCode != http.StatusAccepted {
+		body, _ := io.ReadAll(resp.Body)
+		return "", fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
+	}
+
+	// The upload location comes back in the Location header.
+	if location := resp.Header.Get("Location"); location != "" {
+		return location, nil
+	}
+
+	return "", fmt.Errorf("location header is missing in response")
+}
+
+// checkBlobExistence issues a HEAD request for the blob digest in
+// repositoryName and reports whether the registry answered 200 OK. Every
+// other status counts as "does not exist"; only a failed request is an error.
+func checkBlobExistence(registryURL string, repositoryName string, digest string, username string, password string) (bool, error) {
+	blobURL := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repositoryName, digest)
+
+	resp, err := makeRequest("HEAD", blobURL, nil, nil, username, password)
+	if err != nil {
+		log.Printf("couldn't check for blob: %v", err)
+		return false, err
+	}
+	defer resp.Body.Close()
+
+	found := resp.StatusCode == http.StatusOK
+	return found, nil
+}
+
+// uploadBlob sends the local blob for layer to the upload location in a
+// single PUT request and expects 201 Created in response.
+//
+// The digest is appended to location as a query parameter, using "?" or "&"
+// depending on whether location already carries a query string. The blob
+// file is always closed before returning, including when the request cannot
+// be built or sent.
+func uploadBlob(location string, layer *Layer, username string, password string) error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return err
+	}
+
+	// Create URL
+	sep := "&"
+	if !strings.Contains(location, "?") {
+		sep = "?"
+	}
+	url := fmt.Sprintf("%s%sdigest=%s", location, sep, layer.Digest)
+
+	headers := make(map[string]string)
+	headers["Content-Length"] = fmt.Sprintf("%d", layer.Size)
+	headers["Content-Type"] = "application/octet-stream"
+
+	// TODO change from monolithic uploads to chunked uploads
+	// TODO allow resumability
+	// TODO allow canceling uploads via DELETE
+	// TODO allow cross repo blob mount
+
+	fp := filepath.Join(home, ".ollama/models/blobs", layer.Digest)
+	f, err := os.Open(fp)
+	if err != nil {
+		return err
+	}
+	// The HTTP client closes the request body once it has been sent, but not
+	// when the request fails before that point. A second Close is harmless.
+	defer f.Close()
+
+	resp, err := makeRequest("PUT", url, headers, f, username, password)
+	if err != nil {
+		log.Printf("couldn't upload blob: %v", err)
+		return err
+	}
+	defer resp.Body.Close()
+
+	// Check for success: For a successful upload, the Docker registry will respond with a 201 Created
+	if resp.StatusCode != http.StatusCreated {
+		body, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
+	}
+
+	return nil
+}
+
+// downloadBlob fetches the blob digest of repoName from registryURL into
+// ~/.ollama/models/blobs/<digest>, reporting progress through fn.
+//
+// Nothing is downloaded if the blob file already exists. Data is written to
+// "<digest>-partial" and renamed once complete. An existing partial file is
+// resumed with a Range request; if the registry answers 200 OK (the full
+// body) instead of 206, the partial file is restarted from scratch.
+//
+// Errors are returned (never panicked) for: unexpected stat failures, failed
+// requests, non-200/206 responses, a missing or invalid Content-Length,
+// file-system failures, and a response body that ends before the announced
+// length has been received.
+func downloadBlob(registryURL, repoName, digest string, username, password string, fn func(status, digest string, Total, Completed int, Percent float64)) error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return err
+	}
+
+	fp := filepath.Join(home, ".ollama/models/blobs", digest)
+
+	switch _, err := os.Stat(fp); {
+	case err == nil:
+		// we already have the file, so return
+		log.Printf("already have %s\n", digest)
+		return nil
+	case !errors.Is(err, os.ErrNotExist):
+		// Any other failure says nothing about whether the blob is present.
+		return fmt.Errorf("stat: %w", err)
+	}
+
+	var size int64
+
+	fi, err := os.Stat(fp + "-partial")
+	switch {
+	case errors.Is(err, os.ErrNotExist):
+		// noop, file doesn't exist so create it
+	case err != nil:
+		return fmt.Errorf("stat: %w", err)
+	default:
+		size = fi.Size()
+	}
+
+	url := fmt.Sprintf("%s/v2/%s/blobs/%s", registryURL, repoName, digest)
+	headers := map[string]string{
+		"Range": fmt.Sprintf("bytes=%d-", size),
+	}
+
+	resp, err := makeRequest("GET", url, headers, nil, username, password)
+	if err != nil {
+		log.Printf("couldn't download blob: %v", err)
+		return err
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent {
+		body, _ := ioutil.ReadAll(resp.Body)
+		return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
+	}
+
+	remaining, err := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
+	if err != nil {
+		// Without a length the loop below cannot tell a complete download
+		// from an empty one, so refuse rather than store a bogus blob.
+		return fmt.Errorf("parse content length %q: %w", resp.Header.Get("Content-Length"), err)
+	}
+
+	flags := os.O_CREATE | os.O_APPEND | os.O_WRONLY
+	if resp.StatusCode == http.StatusOK && size > 0 {
+		// The registry ignored the Range header and is sending the whole
+		// blob; appending it to the partial file would corrupt the result.
+		flags = os.O_CREATE | os.O_TRUNC | os.O_WRONLY
+		size = 0
+	}
+
+	// NOTE(review): path.Dir only understands "/" separators; filepath.Dir
+	// would be the portable choice for this OS path.
+	err = os.MkdirAll(path.Dir(fp), 0o700)
+	if err != nil {
+		return fmt.Errorf("make blobs directory: %w", err)
+	}
+
+	out, err := os.OpenFile(fp+"-partial", flags, 0o644)
+	if err != nil {
+		return fmt.Errorf("open partial file: %w", err)
+	}
+	defer out.Close()
+
+	completed := size
+	total := remaining + completed
+
+	fraction := func() float64 {
+		if total == 0 {
+			return 0
+		}
+		return float64(completed) / float64(total)
+	}
+
+	for {
+		fn(fmt.Sprintf("Downloading %s", digest), digest, int(total), int(completed), fraction())
+		if completed >= total {
+			fmt.Printf("finished downloading\n")
+			err = os.Rename(fp+"-partial", fp)
+			if err != nil {
+				fmt.Printf("error: %v\n", err)
+				fn(fmt.Sprintf("error renaming file: %v", err), digest, int(total), int(completed), 1)
+				return err
+			}
+			break
+		}
+
+		n, err := io.CopyN(out, resp.Body, 8192)
+		completed += n
+		if err != nil {
+			if !errors.Is(err, io.EOF) {
+				return err
+			}
+			// EOF before the announced length means the connection ended
+			// early; looping again would never make progress.
+			if completed < total {
+				return fmt.Errorf("download of %s ended after %d of %d bytes: %w", digest, completed, total, io.ErrUnexpectedEOF)
+			}
+		}
+	}
+
+	log.Printf("success getting %s\n", digest)
+	return nil
+}
+
+// makeRequest builds and sends a single HTTP request. Every entry of headers
+// is set on the request, and basic auth is attached only when both username
+// and password are non-empty. Redirects are logged and followed until ten
+// requests have been made. On any error the returned response is nil; on
+// success the caller must close the response body.
+func makeRequest(method, url string, headers map[string]string, body io.Reader, username, password string) (*http.Response, error) {
+	req, err := http.NewRequest(method, url, body)
+	if err != nil {
+		return nil, err
+	}
+
+	for name, value := range headers {
+		req.Header.Set(name, value)
+	}
+
+	// TODO: better auth
+	if username != "" && password != "" {
+		req.SetBasicAuth(username, password)
+	}
+
+	logRedirect := func(next *http.Request, via []*http.Request) error {
+		if len(via) >= 10 {
+			return fmt.Errorf("too many redirects")
+		}
+		log.Printf("redirected to: %s\n", next.URL)
+		return nil
+	}
+
+	client := &http.Client{CheckRedirect: logRedirect}
+
+	resp, err := client.Do(req)
+	if err != nil {
+		// Do can return a non-nil response together with an error when the
+		// redirect check fails; callers of this function always get nil.
+		return nil, err
+	}
+
+	return resp, nil
+}

+ 0 - 128
server/models.go

@@ -1,128 +0,0 @@
-package server
-
-import (
-	"encoding/json"
-	"errors"
-	"fmt"
-	"io"
-	"net/http"
-	"os"
-	"path/filepath"
-	"strconv"
-)
-
-const directoryURL = "https://ollama.ai/api/models"
-
-type Model struct {
-	Name             string `json:"name"`
-	DisplayName      string `json:"display_name"`
-	Parameters       string `json:"parameters"`
-	URL              string `json:"url"`
-	ShortDescription string `json:"short_description"`
-	Description      string `json:"description"`
-	PublishedBy      string `json:"published_by"`
-	OriginalAuthor   string `json:"original_author"`
-	OriginalURL      string `json:"original_url"`
-	License          string `json:"license"`
-}
-
-func (m *Model) FullName() string {
-	home, err := os.UserHomeDir()
-	if err != nil {
-		panic(err)
-	}
-
-	return filepath.Join(home, ".ollama", "models", m.Name+".bin")
-}
-
-func (m *Model) TempFile() string {
-	fullName := m.FullName()
-	return filepath.Join(
-		filepath.Dir(fullName),
-		fmt.Sprintf(".%s.part", filepath.Base(fullName)),
-	)
-}
-
-func getRemote(model string) (*Model, error) {
-	// resolve the model download from our directory
-	resp, err := http.Get(directoryURL)
-	if err != nil {
-		return nil, fmt.Errorf("failed to get directory: %w", err)
-	}
-	defer resp.Body.Close()
-	body, err := io.ReadAll(resp.Body)
-	if err != nil {
-		return nil, fmt.Errorf("failed to read directory: %w", err)
-	}
-	var models []Model
-	err = json.Unmarshal(body, &models)
-	if err != nil {
-		return nil, fmt.Errorf("failed to parse directory: %w", err)
-	}
-	for _, m := range models {
-		if m.Name == model {
-			return &m, nil
-		}
-	}
-	return nil, fmt.Errorf("model not found in directory: %s", model)
-}
-
-func saveModel(model *Model, fn func(total, completed int64)) error {
-	// this models cache directory is created by the server on startup
-
-	client := &http.Client{}
-	req, err := http.NewRequest("GET", model.URL, nil)
-	if err != nil {
-		return fmt.Errorf("failed to download model: %w", err)
-	}
-
-	var size int64
-
-	// completed file doesn't exist, check partial file
-	fi, err := os.Stat(model.TempFile())
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// noop, file doesn't exist so create it
-	case err != nil:
-		return fmt.Errorf("stat: %w", err)
-	default:
-		size = fi.Size()
-	}
-
-	req.Header.Add("Range", fmt.Sprintf("bytes=%d-", size))
-
-	resp, err := client.Do(req)
-	if err != nil {
-		return fmt.Errorf("failed to download model: %w", err)
-	}
-	defer resp.Body.Close()
-
-	if resp.StatusCode >= 400 {
-		return fmt.Errorf("failed to download model: %s", resp.Status)
-	}
-
-	out, err := os.OpenFile(model.TempFile(), os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0o644)
-	if err != nil {
-		panic(err)
-	}
-	defer out.Close()
-
-	remaining, _ := strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
-	completed := size
-
-	total := remaining + completed
-
-	for {
-		fn(total, completed)
-		if completed >= total {
-			return os.Rename(model.TempFile(), model.FullName())
-		}
-
-		n, err := io.CopyN(out, resp.Body, 8192)
-		if err != nil && !errors.Is(err, io.EOF) {
-			return err
-		}
-
-		completed += n
-	}
-}

+ 86 - 64
server/routes.go

@@ -1,12 +1,10 @@
 package server
 
 import (
-	"embed"
 	"encoding/json"
-	"errors"
+	"fmt"
 	"io"
 	"log"
-	"math"
 	"net"
 	"net/http"
 	"os"
@@ -16,16 +14,11 @@ import (
 	"time"
 
 	"github.com/gin-gonic/gin"
-	"github.com/lithammer/fuzzysearch/fuzzy"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/llama"
 )
 
-//go:embed templates/*
-var templatesFS embed.FS
-var templates = template.Must(template.ParseFS(templatesFS, "templates/*.prompt"))
-
 func cacheDir() string {
 	home, err := os.UserHomeDir()
 	if err != nil {
@@ -40,6 +33,7 @@ func generate(c *gin.Context) {
 
 	req := api.GenerateRequest{
 		Options: api.DefaultOptions(),
+		Prompt:  "",
 	}
 
 	if err := c.ShouldBindJSON(&req); err != nil {
@@ -47,34 +41,28 @@ func generate(c *gin.Context) {
 		return
 	}
 
-	if remoteModel, _ := getRemote(req.Model); remoteModel != nil {
-		req.Model = remoteModel.FullName()
-	}
-	if _, err := os.Stat(req.Model); err != nil {
-		if !errors.Is(err, os.ErrNotExist) {
-			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-			return
-		}
-		req.Model = filepath.Join(cacheDir(), "models", req.Model+".bin")
+	model, err := GetModel(req.Model)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+		return
 	}
 
-	templateNames := make([]string, 0, len(templates.Templates()))
-	for _, template := range templates.Templates() {
-		templateNames = append(templateNames, template.Name())
+	templ, err := template.New("").Parse(model.Prompt)
+	if err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
 	}
 
-	match, _ := matchRankOne(filepath.Base(req.Model), templateNames)
-	if template := templates.Lookup(match); template != nil {
-		var sb strings.Builder
-		if err := template.Execute(&sb, req); err != nil {
-			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-			return
-		}
-
-		req.Prompt = sb.String()
+	var sb strings.Builder
+	if err = templ.Execute(&sb, req); err != nil {
+		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+		return
 	}
+	req.Prompt = sb.String()
 
-	llm, err := llama.New(req.Model, req.Options)
+	fmt.Printf("prompt = >>>%s<<<\n", req.Prompt)
+
+	llm, err := llama.New(model.ModelPath, req.Options)
 	if err != nil {
 		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 		return
@@ -105,40 +93,84 @@ func pull(c *gin.Context) {
 		return
 	}
 
-	remote, err := getRemote(req.Model)
-	if err != nil {
-		c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		fn := func(status, digest string, total, completed int, percent float64) {
+			ch <- api.PullProgress{
+				Status:    status,
+				Digest:    digest,
+				Total:     total,
+				Completed: completed,
+				Percent:   percent,
+			}
+		}
+		if err := PullModel(req.Name, req.Username, req.Password, fn); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+	}()
+
+	streamResponse(c, ch)
+}
+
+func push(c *gin.Context) {
+	var req api.PushRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
 		return
 	}
 
-	// check if completed file exists
-	fi, err := os.Stat(remote.FullName())
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// noop, file doesn't exist so create it
-	case err != nil:
-		c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+	ch := make(chan any)
+	go func() {
+		defer close(ch)
+		fn := func(status, digest string, total, completed int, percent float64) {
+			ch <- api.PushProgress{
+				Status:    status,
+				Digest:    digest,
+				Total:     total,
+				Completed: completed,
+				Percent:   percent,
+			}
+		}
+		if err := PushModel(req.Name, req.Username, req.Password, fn); err != nil {
+			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+			return
+		}
+	}()
+
+	streamResponse(c, ch)
+}
+
+func create(c *gin.Context) {
+	var req api.CreateRequest
+	if err := c.ShouldBindJSON(&req); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 		return
-	default:
-		c.JSON(http.StatusOK, api.PullProgress{
-			Total:     fi.Size(),
-			Completed: fi.Size(),
-			Percent:   100,
-		})
+	}
+
+	// NOTE consider passing the entire Modelfile in the json instead of the path to it
 
+	file, err := os.Open(req.Path)
+	if err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
 		return
 	}
+	defer file.Close()
 
 	ch := make(chan any)
 	go func() {
 		defer close(ch)
-		saveModel(remote, func(total, completed int64) {
-			ch <- api.PullProgress{
-				Total:     total,
-				Completed: completed,
-				Percent:   float64(completed) / float64(total) * 100,
+		fn := func(status string) {
+			ch <- api.CreateProgress{
+				Status: status,
 			}
-		})
+		}
+
+		if err := CreateModel(req.Name, file, fn); err != nil {
+			c.JSON(http.StatusBadRequest, gin.H{"message": err.Error()})
+			return
+		}
 	}()
 
 	streamResponse(c, ch)
@@ -153,6 +185,8 @@ func Serve(ln net.Listener) error {
 
 	r.POST("/api/pull", pull)
 	r.POST("/api/generate", generate)
+	r.POST("/api/create", create)
+	r.POST("/api/push", push)
 
 	log.Printf("Listening on %s", ln.Addr())
 	s := &http.Server{
@@ -162,18 +196,6 @@ func Serve(ln net.Listener) error {
 	return s.Serve(ln)
 }
 
-func matchRankOne(source string, targets []string) (bestMatch string, bestRank int) {
-	bestRank = math.MaxInt
-	for _, target := range targets {
-		if rank := fuzzy.LevenshteinDistance(source, target); bestRank > rank {
-			bestRank = rank
-			bestMatch = target
-		}
-	}
-
-	return
-}
-
 func streamResponse(c *gin.Context, ch chan any) {
 	c.Stream(func(w io.Writer) bool {
 		val, ok := <-ch