Browse Source

OLLAMA version

Josh Yan 9 months ago
parent
commit
ab9dfbddea
7 changed files with 88 additions and 5 deletions
  1. 1 0
      go.mod
  2. 2 0
      go.sum
  3. 45 1
      parser/parser.go
  4. 5 2
      server/images.go
  5. 3 1
      server/manifest.go
  6. 31 0
      server/routes_create_test.go
  7. 1 1
      server/routes_delete_test.go

+ 1 - 0
go.mod

@@ -25,6 +25,7 @@ require (
 )
 
 require (
+	github.com/Masterminds/semver/v3 v3.2.1 // indirect
 	github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 // indirect
 	github.com/bytedance/sonic/loader v0.1.1 // indirect
 	github.com/chewxy/hm v1.0.0 // indirect

+ 2 - 0
go.sum

@@ -4,6 +4,8 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7
 gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
 github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
+github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0=
+github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ=
 github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
 github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
 github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=

+ 45 - 1
parser/parser.go

@@ -9,6 +9,7 @@ import (
 	"strconv"
 	"strings"
 
+	"github.com/Masterminds/semver/v3"
 	"golang.org/x/text/encoding/unicode"
 	"golang.org/x/text/transform"
 )
@@ -41,6 +42,8 @@ func (c Command) String() string {
 	case "message":
 		role, message, _ := strings.Cut(c.Args, ": ")
 		fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
+	case "ollama":
+		fmt.Fprintf(&sb, "OLLAMA %s", quote(c.Args))
 	default:
 		fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
 	}
@@ -57,12 +60,14 @@ const (
 	stateParameter
 	stateMessage
 	stateComment
+	stateVersion
 )
 
 var (
 	errMissingFrom        = errors.New("no FROM line")
 	errInvalidMessageRole = errors.New("message role must be one of \"system\", \"user\", or \"assistant\"")
 	errInvalidCommand     = errors.New("command must be one of \"from\", \"license\", \"template\", \"system\", \"adapter\", \"parameter\", or \"message\"")
+	errInvalidVersion     = errors.New("invalid version")
 )
 
 func ParseFile(r io.Reader) (*File, error) {
@@ -109,6 +114,9 @@ func ParseFile(r io.Reader) (*File, error) {
 				case "message":
 					// transition to stateMessage which validates the message role
 					next = stateMessage
+					cmd.Name = s
+				case "ollama":
+					next = stateVersion
 					fallthrough
 				default:
 					cmd.Name = s
@@ -123,6 +131,23 @@ func ParseFile(r io.Reader) (*File, error) {
 				role = b.String()
 			case stateComment, stateNil:
 				// pass
+			case stateVersion:
+				s, ok := unquote(strings.TrimSpace(b.String()))
+				if !ok {
+					if _, err := b.WriteRune(r); err != nil {
+						return nil, err
+					}
+
+					continue
+				} else if isSpace(r){
+					return nil, errInvalidVersion
+				} else if _, err := semver.NewVersion(s); err != nil {
+					return nil, errInvalidVersion
+				}
+
+				cmd.Args = s
+				f.Commands = append(f.Commands, cmd)
+
 			case stateValue:
 				s, ok := unquote(strings.TrimSpace(b.String()))
 				if !ok || isSpace(r) {
@@ -157,6 +182,16 @@ func ParseFile(r io.Reader) (*File, error) {
 	switch curr {
 	case stateComment, stateNil:
 		// pass; nothing to flush
+	case stateVersion:
+		s, ok := unquote(strings.TrimSpace(b.String()))
+		if !ok {
+			return nil, io.ErrUnexpectedEOF
+		} else if _, err := semver.NewVersion(s); err != nil {
+			return nil, errInvalidVersion
+		}
+
+		cmd.Args = s
+		f.Commands = append(f.Commands, cmd)
 	case stateValue:
 		s, ok := unquote(strings.TrimSpace(b.String()))
 		if !ok {
@@ -236,6 +271,15 @@ func parseRuneForState(r rune, cs state) (state, rune, error) {
 		default:
 			return stateComment, 0, nil
 		}
+	case stateVersion:
+		switch {
+		case isNewline(r), isSpace(r):
+			return stateNil, 0, nil
+		case isAlpha(r), isNumber(r), r == '.':
+			return stateVersion, r, nil
+		default:
+			return stateNil, r, nil
+		}
 	default:
 		return stateNil, 0, errors.New("")
 	}
@@ -296,7 +340,7 @@ func isValidMessageRole(role string) bool {
 
 func isValidCommand(cmd string) bool {
 	switch strings.ToLower(cmd) {
-	case "from", "license", "template", "system", "adapter", "parameter", "message":
+	case "from", "license", "template", "system", "adapter", "parameter", "message", "ollama":
 		return true
 	default:
 		return false

+ 5 - 2
server/images.go

@@ -374,6 +374,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 	}
 
 	var messages []*api.Message
+	var version string
 	parameters := make(map[string]any)
 
 	var layers []*Layer
@@ -529,6 +530,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 			}
 
 			messages = append(messages, &api.Message{Role: role, Content: content})
+		case "ollama":
+			version = c.Args
 		default:
 			ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
 			if err != nil {
@@ -545,7 +548,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 				}
 			}
 		}
-	}
+	}		
 
 	var err2 error
 	layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
@@ -642,7 +645,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
 	old, _ := ParseNamedManifest(name)
 
 	fn(api.ProgressResponse{Status: "writing manifest"})
-	if err := WriteManifest(name, layer, layers); err != nil {
+	if err := WriteManifest(name, layer, layers, version); err != nil {
 		return err
 	}
 

+ 3 - 1
server/manifest.go

@@ -18,6 +18,7 @@ type Manifest struct {
 	MediaType     string   `json:"mediaType"`
 	Config        *Layer   `json:"config"`
 	Layers        []*Layer `json:"layers"`
+	Ollama		  string   `json:"ollama"`
 
 	filepath string
 	fi       os.FileInfo
@@ -93,7 +94,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 	return &m, nil
 }
 
-func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
+func WriteManifest(name model.Name, config *Layer, layers []*Layer, ollama string) error {
 	manifests, err := GetManifestPath()
 	if err != nil {
 		return err
@@ -115,6 +116,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
 		Config:        config,
 		Layers:        layers,
+		Ollama:		   ollama,
 	}
 
 	return json.NewEncoder(f).Encode(m)

+ 31 - 0
server/routes_create_test.go

@@ -623,3 +623,34 @@ func TestCreateDetectTemplate(t *testing.T) {
 		})
 	})
 }
+
+func TestCreateVersion(t *testing.T){
+	gin.SetMode(gin.TestMode)
+
+	p := t.TempDir()
+	t.Setenv("OLLAMA_MODELS", p)
+	envconfig.LoadConfig()
+	var s Server
+
+	 w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
+		Name:      "test",
+		Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0.2.3\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
+		Stream:    &stream,
+	})
+
+	if w.Code != http.StatusOK {
+		t.Fatalf("expected status code 200, actual %d", w.Code)
+	} 
+
+	t.Run("invalid version", func(t *testing.T) {
+		w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
+			Name:      "test",
+			Modelfile: fmt.Sprintf("FROM %s\nOLLAMA 0..400", createBinFile(t, nil, nil)),
+			Stream:    &stream,
+		})
+
+		if w.Code != http.StatusBadRequest {
+			t.Fatalf("expected status code 400, actual %d", w.Code)
+		}
+	})
+}

+ 1 - 1
server/routes_delete_test.go

@@ -99,7 +99,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
 	}
 
 	// create a manifest with duplicate layers
-	if err := WriteManifest(n, config, []*Layer{config}); err != nil {
+	if err := WriteManifest(n, config, []*Layer{config}, ""); err != nil {
 		t.Fatal(err)
 	}