Bläddra i källkod

changes to the parser, FROM line, and fix commands

Patrick Devine 1 år sedan
förälder
incheckning
0573eae4b4
3 ändrade filer med 174 tillägg och 74 borttagningar
  1. 0 7
      cmd/cmd.go
  2. 9 4
      parser/parser.go
  3. 165 63
      server/images.go

+ 0 - 7
cmd/cmd.go

@@ -292,12 +292,5 @@ func NewCLI() *cobra.Command {
 		pushCmd,
 	)
 
-	rootCmd.AddCommand(
-		serveCmd,
-		createCmd,
-		runCmd,
-		pullCmd,
-	)
-
 	return rootCmd
 }

+ 9 - 4
parser/parser.go

@@ -14,6 +14,7 @@ type Command struct {
 
 func Parse(reader io.Reader) ([]Command, error) {
 	var commands []Command
+	var foundModel bool
 
 	scanner := bufio.NewScanner(reader)
 	multiline := false
@@ -39,12 +40,12 @@ func Parse(reader io.Reader) ([]Command, error) {
 		command := Command{}
 		switch fields[0] {
 		case "FROM":
-			// TODO - support only one of FROM or MODELFILE
-			command.Name = "image"
-			command.Arg = fields[1]
-		case "MODELFILE":
 			command.Name = "model"
 			command.Arg = fields[1]
+			if command.Arg == "" {
+				return nil, fmt.Errorf("no model specified in FROM line")
+			}
+			foundModel = true
 		case "PROMPT":
 			command.Name = "prompt"
 			if fields[1] == `"""` {
@@ -65,6 +66,10 @@ func Parse(reader io.Reader) ([]Command, error) {
 		}
 	}
 
+	if !foundModel {
+		return nil, fmt.Errorf("no FROM line for the model was specified")
+	}
+
 	if multiline {
 		return nil, fmt.Errorf("unclosed multiline string")
 	}

+ 165 - 63
server/images.go

@@ -12,14 +12,16 @@ import (
 	"net/http"
 	"os"
 	"path"
+	"path/filepath"
 	"strings"
 
 	"github.com/jmorganca/ollama/api"
 	"github.com/jmorganca/ollama/parser"
 )
 
-// var DefaultRegistry string = "https://registry.ollama.ai"
-var DefaultRegistry string = "http://localhost:6000"
+var DefaultRegistry string = "https://registry.ollama.ai"
+
+//var DefaultRegistry string = "http://localhost:6000"
 
 type ManifestV2 struct {
 	SchemaVersion int      `json:"schemaVersion"`
@@ -57,17 +59,17 @@ func GetManifest(name string) (*ManifestV2, error) {
 		return nil, err
 	}
 
-	filepath := path.Join(home, ".ollama/models/manifests", name)
-	_, err = os.Stat(filepath)
+	fp := path.Join(home, ".ollama/models/manifests", name)
+	_, err = os.Stat(fp)
 	if os.IsNotExist(err) {
 		return nil, fmt.Errorf("couldn't find model '%s'", name)
 	}
 
 	var manifest *ManifestV2
 
-	f, err := os.Open(filepath)
+	f, err := os.Open(fp)
 	if err != nil {
-		return nil, fmt.Errorf("couldn't open file '%s'", filepath)
+		return nil, fmt.Errorf("couldn't open file '%s'", fp)
 	}
 
 	decoder := json.NewDecoder(f)
@@ -132,10 +134,24 @@ func GetModel(name string) (*Model, error) {
 	return model, nil
 }
 
+func getAbsPath(fn string) (string, error) {
+	if strings.HasPrefix(fn, "~/") {
+		home, err := os.UserHomeDir()
+		if err != nil {
+			log.Printf("error getting home directory: %v", err)
+			return "", err
+		}
+		fn = strings.Replace(fn, "~", home, 1)
+	}
+
+	return filepath.Abs(fn)
+}
+
 func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 	fn("parsing modelfile")
 	commands, err := parser.Parse(mf)
 	if err != nil {
+		fn(fmt.Sprintf("error: %v", err))
 		return err
 	}
 
@@ -147,24 +163,51 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 		log.Printf("[%s] - %s\n", c.Name, c.Arg)
 		switch c.Name {
 		case "model":
-			fn("creating model layer")
-			file, err := os.Open(c.Arg)
+			fn("looking for model")
+			mf, err := GetManifest(c.Arg)
 			if err != nil {
-				return fmt.Errorf("failed to open file: %v", err)
-			}
-			defer file.Close()
+				// if we couldn't read the manifest, try getting the bin file
+				fp, err := getAbsPath(c.Arg)
+				if err != nil {
+					fn("error determing path. exiting.")
+					return err
+				}
 
-			l, err := CreateLayer(file)
-			if err != nil {
-				return fmt.Errorf("failed to create layer: %v", err)
+				fn("creating model layer")
+				file, err := os.Open(fp)
+				if err != nil {
+					fn(fmt.Sprintf("couldn't find model '%s'", c.Arg))
+					return fmt.Errorf("failed to open file: %v", err)
+				}
+				defer file.Close()
+
+				l, err := CreateLayer(file)
+				if err != nil {
+					fn(fmt.Sprintf("couldn't create model layer: %v", err))
+					return fmt.Errorf("failed to create layer: %v", err)
+				}
+				l.MediaType = "application/vnd.ollama.image.model"
+				layers = append(layers, l)
+			} else {
+				log.Printf("manifest = %#v", mf)
+				for _, l := range mf.Layers {
+					newLayer, err := GetLayerWithBufferFromLayer(l)
+					if err != nil {
+						fn(fmt.Sprintf("couldn't read layer: %v", err))
+						return err
+					}
+					layers = append(layers, newLayer)
+				}
 			}
-			l.MediaType = "application/vnd.ollama.image.model"
-			layers = append(layers, l)
 		case "prompt":
 			fn("creating prompt layer")
+			// remove the prompt layer if one exists
+			layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.prompt")
+
 			prompt := strings.NewReader(c.Arg)
 			l, err := CreateLayer(prompt)
 			if err != nil {
+				fn(fmt.Sprintf("couldn't create prompt layer: %v", err))
 				return fmt.Errorf("failed to create layer: %v", err)
 			}
 			l.MediaType = "application/vnd.ollama.image.prompt"
@@ -176,22 +219,30 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 
 	// Create a single layer for the parameters
 	fn("creating parameter layer")
-	paramData, err := paramsToReader(param)
-	if err != nil {
-		return fmt.Errorf("couldn't create params json: %v", err)
-	}
-	l, err := CreateLayer(paramData)
-	if err != nil {
-		return fmt.Errorf("failed to create layer: %v", err)
+	if len(param) > 0 {
+		layers = removeLayerFromLayers(layers, "application/vnd.ollama.image.params")
+		paramData, err := paramsToReader(param)
+		if err != nil {
+			return fmt.Errorf("couldn't create params json: %v", err)
+		}
+		l, err := CreateLayer(paramData)
+		if err != nil {
+			return fmt.Errorf("failed to create layer: %v", err)
+		}
+		l.MediaType = "application/vnd.ollama.image.params"
+		layers = append(layers, l)
 	}
-	l.MediaType = "application/vnd.ollama.image.params"
-	layers = append(layers, l)
 
 	digests, err := getLayerDigests(layers)
 	if err != nil {
 		return err
 	}
 
+	var manifestLayers []*Layer
+	for _, l := range layers {
+		manifestLayers = append(manifestLayers, &l.Layer)
+	}
+
 	// Create a layer for the config object
 	fn("creating config layer")
 	cfg, err := createConfigLayer(digests)
@@ -200,25 +251,52 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 	}
 	layers = append(layers, cfg)
 
-	home, err := os.UserHomeDir()
+	err = SaveLayers(layers, fn, false)
 	if err != nil {
+		fn(fmt.Sprintf("error saving layers: %v", err))
 		return err
 	}
 
-	var manifestLayers []*Layer
+	// Create the manifest
+	fn("writing manifest")
+	err = CreateManifest(name, cfg, manifestLayers)
+	if err != nil {
+		fn(fmt.Sprintf("error creating manifest: %v", err))
+		return err
+	}
+
+	fn("success")
+	return nil
+}
+
+func removeLayerFromLayers(layers []*LayerWithBuffer, mediaType string) []*LayerWithBuffer {
+	j := 0
+	for _, l := range layers {
+		if l.MediaType != mediaType {
+			layers[j] = l
+			j++
+		}
+	}
+	return layers[:j]
+}
+
+func SaveLayers(layers []*LayerWithBuffer, fn func(status string), force bool) error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		log.Printf("error getting home directory: %v", err)
+		return err
+	}
 
 	// Write each of the layers to disk
 	for _, layer := range layers {
-		filepath := path.Join(home, ".ollama/models/blobs", layer.Digest)
+		fp := path.Join(home, ".ollama/models/blobs", layer.Digest)
 
-		// TODO add a force flag to always write out the layers
-
-		_, err = os.Stat(filepath)
-		if os.IsNotExist(err) {
+		_, err = os.Stat(fp)
+		if os.IsNotExist(err) || force {
 			fn(fmt.Sprintf("writing layer %s", layer.Digest))
-			out, err := os.Create(filepath)
+			out, err := os.Create(fp)
 			if err != nil {
-				log.Printf("couldn't create %s", filepath)
+				log.Printf("couldn't create %s", fp)
 				return err
 			}
 			defer out.Close()
@@ -230,22 +308,18 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 		} else {
 			fn(fmt.Sprintf("using already created layer %s", layer.Digest))
 		}
+	}
 
-		if layer.MediaType == "application/vnd.docker.container.image.v1+json" {
-			continue
-		}
-
-		manifestLayer := &Layer{
-			MediaType: layer.MediaType,
-			Size:      layer.Size,
-			Digest:    layer.Digest,
-		}
+	return nil
+}
 
-		manifestLayers = append(manifestLayers, manifestLayer)
+func CreateManifest(name string, cfg *LayerWithBuffer, layers []*Layer) error {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		log.Printf("error getting home directory: %v", err)
+		return err
 	}
 
-	// Create the manifest
-	fn("writing manifest")
 	manifest := ManifestV2{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
@@ -254,7 +328,7 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 			Size:      cfg.Size,
 			Digest:    cfg.Digest,
 		},
-		Layers: manifestLayers,
+		Layers: layers,
 	}
 
 	manifestJSON, err := json.Marshal(manifest)
@@ -262,17 +336,36 @@ func CreateModel(name string, mf io.Reader, fn func(status string)) error {
 		return err
 	}
 
-	filepath := path.Join(home, ".ollama/models/manifests", name)
-	err = os.WriteFile(filepath, manifestJSON, 0644)
+	fp := path.Join(home, ".ollama/models/manifests", name)
+	err = os.WriteFile(fp, manifestJSON, 0644)
 	if err != nil {
-		log.Printf("couldn't write to %s", filepath)
+		log.Printf("couldn't write to %s", fp)
 		return err
 	}
-
-	fn("success")
 	return nil
 }
 
+func GetLayerWithBufferFromLayer(layer *Layer) (*LayerWithBuffer, error) {
+	home, err := os.UserHomeDir()
+	if err != nil {
+		return nil, err
+	}
+
+	fp := path.Join(home, ".ollama/models/blobs", layer.Digest)
+	file, err := os.Open(fp)
+	defer file.Close()
+	if err != nil {
+		return nil, err
+	}
+
+	newLayer, err := CreateLayer(file)
+	if err != nil {
+		return nil, err
+	}
+	newLayer.MediaType = layer.MediaType
+	return newLayer, nil
+}
+
 func paramsToReader(m map[string]string) (io.Reader, error) {
 	data, err := json.MarshalIndent(m, "", "  ")
 	if err != nil {
@@ -429,6 +522,8 @@ func PullModel(name, username, password string, fn func(status, digest string, T
 		return err
 	}
 
+	log.Printf("manifest = %#v", manifest)
+
 	var layers []*Layer
 	var total int
 	var completed int
@@ -460,10 +555,10 @@ func PullModel(name, username, password string, fn func(status, digest string, T
 		return err
 	}
 
-	filepath := path.Join(home, ".ollama/models/manifests", name)
-	err = os.WriteFile(filepath, manifestJSON, 0644)
+	fp := path.Join(home, ".ollama/models/manifests", name)
+	err = os.WriteFile(fp, manifestJSON, 0644)
 	if err != nil {
-		log.Printf("couldn't write to %s", filepath)
+		log.Printf("couldn't write to %s", fp)
 		return err
 	}
 
@@ -594,8 +689,8 @@ func uploadBlob(location string, layer *Layer, username string, password string)
 	// TODO allow canceling uploads via DELETE
 	// TODO allow cross repo blob mount
 
-	filepath := path.Join(home, ".ollama/models/blobs", layer.Digest)
-	f, err := os.Open(filepath)
+	fp := path.Join(home, ".ollama/models/blobs", layer.Digest)
+	f, err := os.Open(fp)
 	if err != nil {
 		return err
 	}
@@ -622,9 +717,9 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
 		return err
 	}
 
-	filepath := path.Join(home, ".ollama/models/blobs", digest)
+	fp := path.Join(home, ".ollama/models/blobs", digest)
 
-	_, err = os.Stat(filepath)
+	_, err = os.Stat(fp)
 	if !os.IsNotExist(err) {
 		// we already have the file, so return
 		log.Printf("already have %s\n", digest)
@@ -641,7 +736,6 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
 	}
 	defer resp.Body.Close()
 
-	// TODO: handle 307 redirects
 	// TODO: handle range requests to make this resumable
 
 	if resp.StatusCode != http.StatusOK {
@@ -649,9 +743,9 @@ func downloadBlob(registryURL, repoName, digest, username, password string) erro
 		return fmt.Errorf("registry responded with code %d: %v", resp.StatusCode, string(body))
 	}
 
-	out, err := os.Create(filepath)
+	out, err := os.Create(fp)
 	if err != nil {
-		log.Printf("couldn't create %s", filepath)
+		log.Printf("couldn't create %s", fp)
 		return err
 	}
 	defer out.Close()
@@ -680,7 +774,15 @@ func makeRequest(method, url string, headers map[string]string, body io.Reader,
 		req.SetBasicAuth(username, password)
 	}
 
-	client := &http.Client{}
+	client := &http.Client{
+		CheckRedirect: func(req *http.Request, via []*http.Request) error {
+			if len(via) >= 10 {
+				return fmt.Errorf("too many redirects")
+			}
+			log.Printf("redirected to: %s", req.URL)
+			return nil
+		},
+	}
 	resp, err := client.Do(req)
 	if err != nil {
 		return nil, err