Browse Source

allow specifying zero values in modelfile

Bruce MacDonald 1 year ago
parent
commit
8b1e791820
5 changed files with 101 additions and 25 deletions
  1. 79 1
      api/types.go
  2. 0 1
      go.mod
  3. 0 2
      go.sum
  4. 18 17
      server/images.go
  5. 4 4
      server/routes.go

+ 79 - 1
api/types.go

@@ -3,9 +3,12 @@ package api
 import (
 	"encoding/json"
 	"fmt"
+	"log"
 	"math"
 	"os"
+	"reflect"
 	"runtime"
+	"strings"
 	"time"
 )
 
@@ -34,7 +37,7 @@ type GenerateRequest struct {
 	Prompt  string `json:"prompt"`
 	Context []int  `json:"context,omitempty"`
 
-	Options `json:"options"`
+	Options map[string]interface{} `json:"options"`
 }
 
 type CreateRequest struct {
@@ -177,6 +180,81 @@ type Options struct {
 	NumThread int `json:"num_thread,omitempty"`
 }
 
+func (opts *Options) FromMap(m map[string]interface{}) error {
+	valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
+	typeOpts := reflect.TypeOf(opts).Elem()   // types of the fields in the options struct
+
+	// build map of json struct tags to their types
+	jsonOpts := make(map[string]reflect.StructField)
+	for _, field := range reflect.VisibleFields(typeOpts) {
+		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
+		if jsonTag != "" {
+			jsonOpts[jsonTag] = field
+		}
+	}
+
+	for key, val := range m {
+		if opt, ok := jsonOpts[key]; ok {
+			field := valueOpts.FieldByName(opt.Name)
+			if field.IsValid() && field.CanSet() {
+				switch field.Kind() {
+				case reflect.Int:
+					// when JSON unmarshals numbers, it uses float64 by default, not int
+					val, ok := val.(float64)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to int, skipped", key)
+						continue
+					}
+					field.SetInt(int64(val))
+				case reflect.Bool:
+					val, ok := val.(bool)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to bool, skipped", key)
+						continue
+					}
+					field.SetBool(val)
+				case reflect.Float32:
+					// JSON unmarshals to float64
+					val, ok := val.(float64)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to float32, skipped", key)
+						continue
+					}
+					field.SetFloat(val)
+				case reflect.String:
+					val, ok := val.(string)
+					if !ok {
+						log.Printf("could not convert model parmeter %v to string, skipped", key)
+						continue
+					}
+					field.SetString(val)
+				case reflect.Slice:
+					// JSON unmarshals to []interface{}, not []string
+					val, ok := val.([]interface{})
+					if !ok {
+						log.Printf("could not convert model parmeter %v to slice, skipped", key)
+						continue
+					}
+					// convert []interface{} to []string
+					slice := make([]string, len(val))
+					for i, item := range val {
+						str, ok := item.(string)
+						if !ok {
+							log.Printf("could not convert model parmeter %v to slice of strings, skipped", key)
+							continue
+						}
+						slice[i] = str
+					}
+					field.Set(reflect.ValueOf(slice))
+				default:
+					return fmt.Errorf("unknown type loading config params: %v", field.Kind())
+				}
+			}
+		}
+	}
+	return nil
+}
+
 func DefaultOptions() Options {
 	return Options{
 		Seed: -1,

+ 0 - 1
go.mod

@@ -14,7 +14,6 @@ require (
 require github.com/rivo/uniseg v0.2.0 // indirect
 
 require (
-	dario.cat/mergo v1.0.0
 	github.com/bytedance/sonic v1.9.1 // indirect
 	github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
 	github.com/chzyer/readline v1.5.1

+ 0 - 2
go.sum

@@ -1,5 +1,3 @@
-dario.cat/mergo v1.0.0 h1:AGCNq9Evsj31mOgNPcLyXc+4PNABt905YmuqPYYpBWk=
-dario.cat/mergo v1.0.0/go.mod h1:uNxQE+84aUszobStD9th8a29P2fMDhsBdgRYvZOxGmk=
 github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=

+ 18 - 17
server/images.go

@@ -32,8 +32,8 @@ type Model struct {
 	ModelPath string
 	Template  string
 	System    string
-	Digest string
-	Options   api.Options
+	Digest    string
+	Options   map[string]interface{}
 }
 
 func (m *Model) Prompt(request api.GenerateRequest) (string, error) {
@@ -135,7 +135,7 @@ func GetModel(name string) (*Model, error) {
 	}
 
 	model := &Model{
-		Name: mp.GetFullTagname(),
+		Name:   mp.GetFullTagname(),
 		Digest: manifest.Config.Digest,
 	}
 
@@ -176,12 +176,10 @@ func GetModel(name string) (*Model, error) {
 			}
 			defer params.Close()
 
-			var opts api.Options
-			if err = json.NewDecoder(params).Decode(&opts); err != nil {
+			// parse model options parameters into a map so that we can see which fields have been specified explicitly
+			if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
 				return nil, err
 			}
-
-			model.Options = opts
 		}
 	}
 
@@ -442,11 +440,13 @@ func GetLayerWithBufferFromLayer(layer *Layer) (*LayerReader, error) {
 	return newLayer, nil
 }
 
+// paramsToReader converts specified parameter options to their correct types, and returns a reader for the json
 func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
-	opts := api.DefaultOptions()
-	typeOpts := reflect.TypeOf(opts)
+	opts := api.Options{}
+	valueOpts := reflect.ValueOf(&opts).Elem() // names of the fields in the options struct
+	typeOpts := reflect.TypeOf(opts)           // types of the fields in the options struct
 
-	// build map of json struct tags
+	// build map of json struct tags to their types
 	jsonOpts := make(map[string]reflect.StructField)
 	for _, field := range reflect.VisibleFields(typeOpts) {
 		jsonTag := strings.Split(field.Tag.Get("json"), ",")[0]
@@ -455,7 +455,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 		}
 	}
 
-	valueOpts := reflect.ValueOf(&opts).Elem()
+	out := make(map[string]interface{})
 	// iterate params and set values based on json struct tags
 	for key, vals := range params {
 		if opt, ok := jsonOpts[key]; ok {
@@ -468,25 +468,26 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 						return nil, fmt.Errorf("invalid float value %s", vals)
 					}
 
-					field.SetFloat(floatVal)
+					out[key] = floatVal
 				case reflect.Int:
 					intVal, err := strconv.ParseInt(vals[0], 10, 0)
 					if err != nil {
 						return nil, fmt.Errorf("invalid int value %s", vals)
 					}
 
-					field.SetInt(intVal)
+					out[key] = intVal
 				case reflect.Bool:
 					boolVal, err := strconv.ParseBool(vals[0])
 					if err != nil {
 						return nil, fmt.Errorf("invalid bool value %s", vals)
 					}
 
-					field.SetBool(boolVal)
+					out[key] = boolVal
 				case reflect.String:
-					field.SetString(vals[0])
+					out[key] = vals[0]
 				case reflect.Slice:
-					field.Set(reflect.ValueOf(vals))
+					// TODO: only string slices are supported right now
+					out[key] = vals
 				default:
 					return nil, fmt.Errorf("unknown type %s for %s", field.Kind(), key)
 				}
@@ -494,7 +495,7 @@ func paramsToReader(params map[string][]string) (io.ReadSeeker, error) {
 		}
 	}
 
-	bts, err := json.Marshal(opts)
+	bts, err := json.Marshal(out)
 	if err != nil {
 		return nil, err
 	}

+ 4 - 4
server/routes.go

@@ -15,7 +15,6 @@ import (
 	"sync"
 	"time"
 
-	"dario.cat/mergo"
 	"github.com/gin-contrib/cors"
 	"github.com/gin-gonic/gin"
 
@@ -61,12 +60,13 @@ func GenerateHandler(c *gin.Context) {
 		}
 
 		opts := api.DefaultOptions()
-		if err := mergo.Merge(&opts, model.Options, mergo.WithOverride); err != nil {
+		if err := opts.FromMap(model.Options); err != nil {
+			log.Printf("could not load model options: %v", err)
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}
-
-		if err := mergo.Merge(&opts, req.Options, mergo.WithOverride); err != nil {
+		if err := opts.FromMap(req.Options); err != nil {
+			log.Printf("could not merge model options: %v", err)
 			c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
 			return
 		}