Browse Source

testing and FROM local version copy

Josh Yan 9 months ago
parent
commit
aaa1c08a5d
3 changed files with 97 additions and 13 deletions
  1. 4 2
      server/images.go
  2. 10 9
      server/model.go
  3. 83 2
      server/routes_create_test.go

+ 4 - 2
server/images.go

@@ -385,7 +385,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 		case "model", "adapter":
 			var baseLayers []*layerGGML
 			if name := model.ParseName(c.Args); name.IsValid() {
-				baseLayers, err = parseFromModel(ctx, name, fn)
+				baseLayers, version, err = parseFromModel(ctx, name, fn)
 				if err != nil {
 					return err
 				}
@@ -531,7 +531,9 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 
 			messages = append(messages, &api.Message{Role: role, Content: content})
 		case "ollama":
-			version = c.Args
+			if version == "" {
+				version = c.Args
+			}
 		default:
 			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
 			if err != nil {

+ 10 - 9
server/model.go

@@ -30,26 +30,27 @@ type layerGGML struct {
 	*llm.GGML
 }
 
-func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
+func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, version string, err error) {
 	m, err := ParseNamedManifest(name)
 	switch {
 	case errors.Is(err, os.ErrNotExist):
 		if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
-			return nil, err
+			return nil, version, err
 		}
 
 		m, err = ParseNamedManifest(name)
 		if err != nil {
-			return nil, err
+			return nil, version, err
 		}
 	case err != nil:
-		return nil, err
+		return nil, version, err
 	}
 
+	version = m.Ollama
 	for _, layer := range m.Layers {
 		layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
 		if err != nil {
-			return nil, err
+			return nil, version, err
 		}
 
 		switch layer.MediaType {
@@ -58,18 +59,18 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 			"application/vnd.ollama.image.adapter":
 			blobpath, err := GetBlobsPath(layer.Digest)
 			if err != nil {
-				return nil, err
+				return nil, version, err
 			}
 
 			blob, err := os.Open(blobpath)
 			if err != nil {
-				return nil, err
+				return nil, version, err
 			}
 			defer blob.Close()
 
 			ggml, _, err := llm.DecodeGGML(blob, 0)
 			if err != nil {
-				return nil, err
+				return nil, version, err
 			}
 
 			layers = append(layers, &layerGGML{layer, ggml})
@@ -78,7 +79,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
 		}
 	}
 
-	return layers, nil
+	return layers, version, nil
 }
 
 func extractFromZipFile(p string, file *os.File, fn func(api.ProgressResponse)) error {

+ 83 - 2
server/routes_create_test.go

@@ -632,7 +632,7 @@ func TestCreateVersion(t *testing.T){
 	envconfig.LoadConfig()
 	var s Server
 
-	 w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+	w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
 		Name:      "test",
 		Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0.2.3\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
 		Stream:    &stream,
@@ -642,9 +642,59 @@ func TestCreateVersion(t *testing.T){
 		t.Fatalf("expected status code 200, actual %d", w.Code)
 	} 
 
+	checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
+		filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
+	})
+	
+	f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	bts := json.NewDecoder(f)
+
+	var m Manifest
+	if err := bts.Decode(&m); err != nil {
+		t.Fatal(err)
+	}
+
+	if m.Ollama != "0.2.3" {
+		t.Errorf("got %s != want 0.2.3", m.Ollama)
+	}
+
+	t.Run("no version", func(t *testing.T) {
+		w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "noversion",
+			Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+		
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		}
+		
+		checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "noversion", "*"), []string{
+			filepath.Join(p, "manifests", "registry.ollama.ai", "library", "noversion", "latest"),
+		})
+
+		f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "noversion", "latest"))
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		bts := json.NewDecoder(f)
+		var m Manifest
+		if err := bts.Decode(&m); err != nil {
+			t.Fatal(err)
+		}
+		
+		if m.Ollama != "" {
+			t.Errorf("got %s != want \"\"", m.Ollama)
+		}
+	})
+
 	t.Run("invalid version", func(t *testing.T) {
 		w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
-			Name:      "test",
+			Name:      "invalid",
 			Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0..400", createBinFile(t, nil, nil)),
 			Stream:    &stream,
 		})
@@ -653,4 +703,35 @@ func TestCreateVersion(t *testing.T){
 			t.Fatalf("expected status code 400, actual %d", w.Code)
 		}
 	})
+
+	t.Run("from valid version", func(t *testing.T) {
+		w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "fromvalid",
+			Modelfile: "FROM test",
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusOK {
+			t.Fatalf("expected status code 200, actual %d", w.Code)
+		} 
+	
+		checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "fromvalid", "*"), []string{
+			filepath.Join(p, "manifests", "registry.ollama.ai", "library", "fromvalid", "latest"),
+		})
+		
+		f, err := os.Open(filepath.Join(p, "manifests", "registry.ollama.ai", "library", "fromvalid", "latest"))
+		if err != nil {
+			t.Fatal(err)
+		}
+		bts := json.NewDecoder(f)
+	
+		var m Manifest
+		if err := bts.Decode(&m); err != nil {
+			t.Fatal(err)
+		}
+	
+		if m.Ollama != "0.2.3" {
+			t.Errorf("got %s != want 0.2.3", m.Ollama)
+		}
+	})
 }