Merge pull request #1334 from jmorganca/mxyng/load-projectors

load projectors
Michael Yang 1 year ago
parent
commit
32f62fbb8e
5 changed files with 62 additions and 27 deletions
  1. api/types.go (+10, -0)
  2. llm/llama.go (+6, -1)
  3. llm/llm.go (+3, -3)
  4. server/images.go (+32, -17)
  5. server/routes.go (+11, -6)

api/types.go (+10, -0)

@@ -203,12 +203,22 @@ type GenerateResponse struct {
 	CreatedAt time.Time `json:"created_at"`
 	Response  string    `json:"response"`
 
+	ModelConfiguration ModelConfiguration `json:"model_configuration"`
+
 	Done    bool  `json:"done"`
 	Context []int `json:"context,omitempty"`
 
 	Metrics
 }
 
+type ModelConfiguration struct {
+	ModelFormat   string   `json:"model_format"`
+	ModelFamily   string   `json:"model_family"`
+	ModelFamilies []string `json:"model_families"`
+	ModelType     string   `json:"model_type"`
+	FileType      string   `json:"file_type"`
+}
+
 func (m *Metrics) Summary() {
 	if m.TotalDuration > 0 {
 		fmt.Fprintf(os.Stderr, "total duration:       %v\n", m.TotalDuration)
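
For context, the new model_configuration field exposes the metadata from the model's config blob on every generate response. A minimal client-side sketch of decoding it, using a hypothetical response body (all field values are illustrative):

package main

import (
	"encoding/json"
	"fmt"
)

// Mirrors api.ModelConfiguration from the diff above, so the example is self-contained.
type ModelConfiguration struct {
	ModelFormat   string   `json:"model_format"`
	ModelFamily   string   `json:"model_family"`
	ModelFamilies []string `json:"model_families"`
	ModelType     string   `json:"model_type"`
	FileType      string   `json:"file_type"`
}

func main() {
	// Hypothetical response body; field values are examples only.
	body := `{"response":"","done":true,"model_configuration":{"model_format":"gguf","model_family":"llama","model_families":["llama","clip"],"model_type":"7B","file_type":"Q4_0"}}`

	var resp struct {
		ModelConfiguration ModelConfiguration `json:"model_configuration"`
	}
	if err := json.Unmarshal([]byte(body), &resp); err != nil {
		panic(err)
	}
	fmt.Println(resp.ModelConfiguration.ModelFamilies) // [llama clip]
}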

llm/llama.go (+6, -1)

@@ -325,7 +325,7 @@ func (w *StatusWriter) Write(b []byte) (int, error) {
 	return os.Stderr.Write(b)
 }
 
-func newLlama(model string, adapters []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
+func newLlama(model string, adapters, projectors []string, runners []ModelRunner, numLayers int64, opts api.Options) (*llama, error) {
 	fileInfo, err := os.Stat(model)
 	if err != nil {
 		return nil, err
@@ -365,6 +365,11 @@ func newLlama(model string, adapters []string, runners []ModelRunner, numLayers
 		params = append(params, "--lora", adapters[0])
 	}
 
+	if len(projectors) > 0 {
+		// TODO: applying multiple projectors is not supported by the llama.cpp server yet
+		params = append(params, "--mmproj", projectors[0])
+	}
+
 	if opts.NumThread > 0 {
 		params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
 	}
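
For illustration, a self-contained sketch of how the new flag lands in the final argument list, following the same append pattern as the code above (blob paths are hypothetical):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Hypothetical inputs for illustration only.
	model := "/blobs/sha256-model"
	projectors := []string{"/blobs/sha256-projector"}

	params := []string{"--model", model}
	if len(projectors) > 0 {
		// Matches the diff: the llama.cpp server takes a single --mmproj,
		// so only the first projector path is forwarded.
		params = append(params, "--mmproj", projectors[0])
	}

	fmt.Println(strings.Join(params, " "))
	// Output: --model /blobs/sha256-model --mmproj /blobs/sha256-projector
}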

llm/llm.go (+3, -3)

@@ -23,7 +23,7 @@ type LLM interface {
 	Ping(context.Context) error
 }
 
-func New(workDir, model string, adapters []string, opts api.Options) (LLM, error) {
+func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
@@ -82,9 +82,9 @@ func New(workDir, model string, adapters []string, opts api.Options) (LLM, error
 		opts.NumGQA = 0
 		opts.RopeFrequencyBase = 0.0
 		opts.RopeFrequencyScale = 0.0
-		return newLlama(model, adapters, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
+		return newLlama(model, adapters, projectors, chooseRunners(workDir, "gguf"), ggml.NumLayers(), opts)
 	case "ggml", "ggmf", "ggjt", "ggla":
-		return newLlama(model, adapters, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
+		return newLlama(model, adapters, projectors, chooseRunners(workDir, "ggml"), ggml.NumLayers(), opts)
 	default:
 		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())
 	}

server/images.go (+32, -17)

@@ -35,16 +35,18 @@ type RegistryOptions struct {
 }
 
 type Model struct {
-	Name          string `json:"name"`
-	ShortName     string
-	ModelPath     string
-	OriginalModel string
-	AdapterPaths  []string
-	Template      string
-	System        string
-	License       []string
-	Digest        string
-	Options       map[string]interface{}
+	Name           string `json:"name"`
+	Config         ConfigV2
+	ShortName      string
+	ModelPath      string
+	OriginalModel  string
+	AdapterPaths   []string
+	ProjectorPaths []string
+	Template       string
+	System         string
+	License        []string
+	Digest         string
+	Options        map[string]interface{}
 }
 
 type PromptVars struct {
@@ -136,16 +138,12 @@ type ManifestV2 struct {
 }
 
 type ConfigV2 struct {
-	ModelFormat   string   `json:"model_format"`
-	ModelFamily   string   `json:"model_family"`
-	ModelFamilies []string `json:"model_families"`
-	ModelType     string   `json:"model_type"`
-	FileType      string   `json:"file_type"`
-	RootFS        RootFS   `json:"rootfs"`
-
 	// required by spec
 	Architecture string `json:"architecture"`
 	OS           string `json:"os"`
+	RootFS       RootFS `json:"rootfs"`
+
+	api.ModelConfiguration
 }
 
 func (c *ConfigV2) SetModelFormat(format string) {
@@ -234,6 +232,21 @@ func GetModel(name string) (*Model, error) {
 		License:   []string{},
 	}
 
+	filename, err := GetBlobsPath(manifest.Config.Digest)
+	if err != nil {
+		return nil, err
+	}
+
+	configFile, err := os.Open(filename)
+	if err != nil {
+		return nil, err
+	}
+	defer configFile.Close()
+
+	if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
+		return nil, err
+	}
+
 	for _, layer := range manifest.Layers {
 		filename, err := GetBlobsPath(layer.Digest)
 		if err != nil {
@@ -250,6 +263,8 @@ func GetModel(name string) (*Model, error) {
 			log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
 		case "application/vnd.ollama.image.adapter":
 			model.AdapterPaths = append(model.AdapterPaths, filename)
+		case "application/vnd.ollama.image.projector":
+			model.ProjectorPaths = append(model.ProjectorPaths, filename)
 		case "application/vnd.ollama.image.template":
 			bts, err := os.ReadFile(filename)
 			if err != nil {
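
To make the new layer routing concrete, here is a self-contained sketch of the media-type dispatch performed by the loop above, with hypothetical blob paths standing in for real digests:

package main

import "fmt"

// Simplified stand-in for a manifest layer; only the fields the dispatch needs.
type layer struct {
	MediaType string
	Filename  string
}

func main() {
	// Hypothetical layers; paths are placeholders for resolved blob files.
	layers := []layer{
		{"application/vnd.ollama.image.model", "/blobs/sha256-model"},
		{"application/vnd.ollama.image.adapter", "/blobs/sha256-adapter"},
		{"application/vnd.ollama.image.projector", "/blobs/sha256-projector"},
	}

	var adapterPaths, projectorPaths []string
	for _, l := range layers {
		switch l.MediaType {
		case "application/vnd.ollama.image.adapter":
			adapterPaths = append(adapterPaths, l.Filename)
		case "application/vnd.ollama.image.projector":
			// New in this change: projector blobs are collected so they can
			// be passed through to the llama.cpp runner as --mmproj.
			projectorPaths = append(projectorPaths, l.Filename)
		}
	}

	fmt.Println("adapters:", adapterPaths)     // adapters: [/blobs/sha256-adapter]
	fmt.Println("projectors:", projectorPaths) // projectors: [/blobs/sha256-projector]
}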

server/routes.go (+11, -6)

@@ -105,7 +105,7 @@ func load(c *gin.Context, modelName string, reqOpts map[string]interface{}, sess
 			loaded.Options = nil
 		}
 
-		llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, opts)
+		llmRunner, err := llm.New(workDir, model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
 		if err != nil {
 			// some older models are not compatible with newer versions of llama.cpp
 			// show a generalized compatibility error until there is a better way to
@@ -198,7 +198,11 @@ func GenerateHandler(c *gin.Context) {
 
 	// an empty request loads the model
 	if req.Prompt == "" && req.Template == "" && req.System == "" {
-		c.JSON(http.StatusOK, api.GenerateResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true})
+		c.JSON(http.StatusOK, api.GenerateResponse{
+			CreatedAt:          time.Now().UTC(),
+			Model:              req.Model,
+			ModelConfiguration: model.Config.ModelConfiguration,
+			Done:               true})
 		return
 	}
 
@@ -257,10 +261,11 @@ func GenerateHandler(c *gin.Context) {
 			}
 
 			resp := api.GenerateResponse{
-				Model:     r.Model,
-				CreatedAt: r.CreatedAt,
-				Done:      r.Done,
-				Response:  r.Content,
+				Model:              r.Model,
+				ModelConfiguration: model.Config.ModelConfiguration,
+				CreatedAt:          r.CreatedAt,
+				Done:               r.Done,
+				Response:           r.Content,
 				Metrics: api.Metrics{
 					TotalDuration:      r.TotalDuration,
 					LoadDuration:       r.LoadDuration,