Browse Source

partial decode ggml bin for more info

Michael Yang 1 year ago
parent
commit
fccf8d179f

+ 0 - 0
llama/ggml-alloc.c → llm/ggml-alloc.c


+ 0 - 0
llama/ggml-alloc.h → llm/ggml-alloc.h


+ 0 - 0
llama/ggml-cuda.cu → llm/ggml-cuda.cu


+ 0 - 0
llama/ggml-cuda.h → llm/ggml-cuda.h


+ 0 - 0
llama/ggml-metal.h → llm/ggml-metal.h


+ 0 - 0
llama/ggml-metal.m → llm/ggml-metal.m


+ 0 - 0
llama/ggml-metal.metal → llm/ggml-metal.metal


+ 0 - 0
llama/ggml-mpi.c → llm/ggml-mpi.c


+ 0 - 0
llama/ggml-mpi.h → llm/ggml-mpi.h


+ 0 - 0
llama/ggml-opencl.cpp → llm/ggml-opencl.cpp


+ 0 - 0
llama/ggml-opencl.h → llm/ggml-opencl.h


+ 0 - 0
llama/ggml.c → llm/ggml.c


+ 180 - 0
llm/ggml.go

@@ -0,0 +1,180 @@
+package llm
+
+import (
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+)
+
+type ModelFamily string
+
+const ModelFamilyLlama ModelFamily = "llama"
+
+type ModelType uint32
+
+const (
+	ModelType3B  ModelType = 26
+	ModelType7B  ModelType = 32
+	ModelType13B ModelType = 40
+	ModelType30B ModelType = 60
+	ModelType65B ModelType = 80
+)
+
+type FileType uint32
+
+const (
+	FileTypeF32 FileType = iota
+	FileTypeF16
+	FileTypeQ4_0
+	FileTypeQ4_1
+	FileTypeQ4_1_F16
+	FileTypeQ8_0 = iota + 3
+	FileTypeQ5_0
+	FileTypeQ5_1
+	FileTypeQ2_K
+	FileTypeQ3_K
+	FileTypeQ4_K
+	FileTypeQ5_K
+	FileTypeQ6_K
+	FileTypeUnknown = -1
+)
+
+type GGML struct {
+	ModelFamily
+	ModelType
+
+	magic uint32
+	container
+
+	llamaHyperparameters
+}
+
+type container interface {
+	Name() string
+	Decode(io.Reader) error
+}
+
+type containerGGML struct {
+}
+
+func (c *containerGGML) Name() string {
+	return "ggml"
+}
+
+func (c *containerGGML) Decode(r io.Reader) error {
+	return nil
+}
+
+type containerGGMF struct {
+	version uint32
+}
+
+func (c *containerGGMF) Name() string {
+	return "ggmf"
+}
+
+func (c *containerGGMF) Decode(r io.Reader) error {
+	var version uint32
+	binary.Read(r, binary.LittleEndian, &version)
+
+	switch version {
+	case 1:
+	default:
+		return errors.New("invalid version")
+	}
+
+	c.version = version
+	return nil
+}
+
+type containerGGJT struct {
+	version uint32
+}
+
+func (c *containerGGJT) Name() string {
+	return "ggjt"
+}
+
+func (c *containerGGJT) Decode(r io.Reader) error {
+	var version uint32
+	binary.Read(r, binary.LittleEndian, &version)
+
+	switch version {
+	case 1, 2, 3:
+	default:
+		return errors.New("invalid version")
+	}
+
+	c.version = version
+	return nil
+}
+
+type containerLORA struct {
+	version uint32
+}
+
+func (c *containerLORA) Name() string {
+	return "ggla"
+}
+
+func (c *containerLORA) Decode(r io.Reader) error {
+	var version uint32
+	binary.Read(r, binary.LittleEndian, &version)
+
+	switch version {
+	case 1:
+	default:
+		return errors.New("invalid version")
+	}
+
+	c.version = version
+	return nil
+}
+
+const (
+	// / Magic constant for `ggml` files (unversioned).
+	FILE_MAGIC_GGML = 0x67676d6c
+	// / Magic constant for `ggml` files (versioned, ggmf).
+	FILE_MAGIC_GGMF = 0x67676d66
+	// / Magic constant for `ggml` files (versioned, ggjt).
+	FILE_MAGIC_GGJT = 0x67676a74
+	// / Magic constant for `ggla` files (LoRA adapter).
+	FILE_MAGIC_GGLA = 0x67676C61
+)
+
+func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
+	var ggml GGML
+	binary.Read(r, binary.LittleEndian, &ggml.magic)
+
+	switch ggml.magic {
+	case FILE_MAGIC_GGML:
+		ggml.container = &containerGGML{}
+	case FILE_MAGIC_GGMF:
+		ggml.container = &containerGGMF{}
+	case FILE_MAGIC_GGJT:
+		ggml.container = &containerGGJT{}
+	case FILE_MAGIC_GGLA:
+		ggml.container = &containerLORA{}
+	default:
+		return nil, errors.New("invalid file magic")
+	}
+
+	if err := ggml.Decode(r); err != nil {
+		return nil, err
+	}
+
+	// different model types may have different layouts for hyperparameters
+	switch hint {
+	case ModelFamilyLlama:
+		binary.Read(r, binary.LittleEndian, &ggml.llamaHyperparameters)
+		// TODO: sanity check hyperparameters
+	default:
+		return nil, fmt.Errorf("unsupported model type: %s", hint)
+	}
+
+	// final model type
+	ggml.ModelFamily = hint
+	ggml.ModelType = ModelType(ggml.NumLayer)
+	return &ggml, nil
+}

+ 0 - 0
llama/ggml.h → llm/ggml.h


+ 0 - 0
llama/k_quants.c → llm/k_quants.c


+ 0 - 0
llama/k_quants.h → llm/k_quants.h


+ 0 - 0
llama/llama-util.h → llm/llama-util.h


+ 0 - 0
llama/llama.cpp → llm/llama.cpp


+ 65 - 35
llama/llama.go → llm/llama.go

@@ -1,4 +1,4 @@
-package llama
+package llm
 
 /*
 #cgo CPPFLAGS: -O3 -Wall -Wextra -Wno-unused-function -Wno-unused-variable -DNDEBUG -DGGML_USE_K_QUANTS
@@ -105,7 +105,7 @@ import (
 //go:embed ggml-metal.metal
 var fs embed.FS
 
-type LLM struct {
+type llama struct {
 	params *C.struct_llama_context_params
 	model  *C.struct_llama_model
 	ctx    *C.struct_llama_context
@@ -120,12 +120,28 @@ type LLM struct {
 	api.Options
 }
 
-func New(model string, opts api.Options) (*LLM, error) {
+type llamaHyperparameters struct {
+	// NumVocab is the size of the model's vocabulary.
+	NumVocab uint32
+
+	// NumEmbd is the size of the model's embedding layer.
+	NumEmbd uint32
+	NumMult uint32
+	NumHead uint32
+
+	// NumLayer is the number of layers in the model.
+	NumLayer uint32
+	NumRot   uint32
+	// FileType describes the quantization level of the model, e.g. Q4_0, Q5_K, etc.
+	FileType
+}
+
+func newLlama(model string, opts api.Options) (*llama, error) {
 	if _, err := os.Stat(model); err != nil {
 		return nil, err
 	}
 
-	llm := LLM{Options: opts}
+	llm := llama{Options: opts}
 
 	C.llama_backend_init(C.bool(llm.UseNUMA))
 
@@ -168,7 +184,7 @@ func New(model string, opts api.Options) (*LLM, error) {
 	return &llm, nil
 }
 
-func (llm *LLM) Close() {
+func (llm *llama) Close() {
 	llm.gc = true
 
 	llm.mu.Lock()
@@ -180,17 +196,16 @@ func (llm *LLM) Close() {
 	C.llama_print_timings(llm.ctx)
 }
 
+func (llm *llama) SetOptions(opts api.Options) {
+	llm.Options = opts
+}
+
 var errNeedMoreData = errors.New("need more data")
 
-func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
+func (llm *llama) Predict(ctx []int, prompt string, fn func(api.GenerateResponse)) error {
 	C.llama_reset_timings(llm.ctx)
 
-	tokens := make([]C.llama_token, len(ctx))
-	for i := range tokens {
-		tokens[i] = C.llama_token(ctx[i])
-	}
-
-	llm.marshalPrompt(tokens, prompt)
+	llm.marshalPrompt(ctx, prompt)
 
 	C.llama_set_rng_seed(llm.ctx, C.uint(llm.Seed))
 
@@ -205,7 +220,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 			return err
 		}
 
-		b.WriteString(llm.Decode(token))
+		b.WriteString(llm.Decode(int(token)))
 
 		if err := llm.checkStopConditions(b); err != nil {
 			if errors.Is(err, io.EOF) {
@@ -243,7 +258,7 @@ func (llm *LLM) Predict(ctx []int, prompt string, fn func(api.GenerateResponse))
 	return nil
 }
 
-func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
+func (llm *llama) checkStopConditions(b bytes.Buffer) error {
 	for _, stopCondition := range llm.Stop {
 		if stopCondition == strings.TrimSpace(b.String()) {
 			return io.EOF
@@ -255,12 +270,17 @@ func (llm *LLM) checkStopConditions(b bytes.Buffer) error {
 	return nil
 }
 
-func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_token {
+func (llm *llama) marshalPrompt(ctx []int, prompt string) []C.llama_token {
 	tokens := append(ctx, llm.Encode(prompt)...)
 	if llm.NumKeep < 0 {
 		llm.NumKeep = len(tokens)
 	}
 
+	cTokens := make([]C.llama_token, len(tokens))
+	for i := range tokens {
+		cTokens[i] = C.llama_token(tokens[i])
+	}
+
 	// min(llm.NumCtx - 4, llm.NumKeep)
 	if llm.NumCtx-4 < llm.NumKeep {
 		llm.NumKeep = llm.NumCtx - 4
@@ -269,25 +289,25 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	if len(tokens) >= llm.NumCtx {
 		// truncate input
 		numLeft := (llm.NumCtx - llm.NumKeep) / 2
-		truncated := tokens[:llm.NumKeep]
-		erasedBlocks := (len(tokens) - llm.NumKeep - numLeft - 1) / numLeft
-		truncated = append(truncated, tokens[llm.NumKeep+erasedBlocks*numLeft:]...)
-		copy(llm.last, tokens[len(tokens)-llm.NumCtx:])
+		truncated := cTokens[:llm.NumKeep]
+		erasedBlocks := (len(cTokens) - llm.NumKeep - numLeft - 1) / numLeft
+		truncated = append(truncated, cTokens[llm.NumKeep+erasedBlocks*numLeft:]...)
+		copy(llm.last, cTokens[len(cTokens)-llm.NumCtx:])
 
-		tokens = truncated
+		cTokens = truncated
 		log.Printf("input truncated: num_ctx=%d num_keep=%d num_left=%d num_tokens=%d", llm.NumCtx, llm.NumKeep, numLeft, len(truncated))
 	} else {
-		llm.last = make([]C.llama_token, llm.NumCtx-len(tokens))
-		llm.last = append(llm.last, tokens...)
+		llm.last = make([]C.llama_token, llm.NumCtx-len(cTokens))
+		llm.last = append(llm.last, cTokens...)
 	}
 
 	var i int
-	for i = 0; i < len(llm.embd) && i < len(tokens) && llm.embd[i] == tokens[i]; i++ {
+	for i = 0; i < len(llm.embd) && i < len(cTokens) && llm.embd[i] == cTokens[i]; i++ {
 		// noop
 	}
 
-	llm.embd = tokens
-	if i == len(tokens) {
+	llm.embd = cTokens
+	if i == len(cTokens) {
 		// evaluate at least one token to generate logits
 		i--
 	}
@@ -295,31 +315,36 @@ func (llm *LLM) marshalPrompt(ctx []C.llama_token, prompt string) []C.llama_toke
 	llm.cursor = i
 
 	log.Printf("prompt: num_past=%d cached=%v eval=%v", i, len(llm.embd[:i]), len(llm.embd[i:]))
-	return tokens
+	return cTokens
 }
 
-func (llm *LLM) Encode(prompt string) []C.llama_token {
+func (llm *llama) Encode(prompt string) []int {
 	cPrompt := C.CString(prompt)
 	defer C.free(unsafe.Pointer(cPrompt))
 
-	tokens := make([]C.llama_token, len(prompt)+1)
-	if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(tokens), C.int(len(tokens)), true); n > 0 {
-		return tokens[:n]
+	cTokens := make([]C.llama_token, len(prompt)+1)
+	if n := C.llama_tokenize(llm.ctx, cPrompt, unsafe.SliceData(cTokens), C.int(len(cTokens)), true); n > 0 {
+		tokens := make([]int, n)
+		for i := range cTokens[:n] {
+			tokens[i] = int(cTokens[i])
+		}
+
+		return tokens
 	}
 
 	return nil
 }
 
-func (llm *LLM) Decode(tokens ...C.llama_token) string {
+func (llm *llama) Decode(tokens ...int) string {
 	var sb strings.Builder
 	for _, token := range tokens {
-		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, token)))
+		sb.WriteString(C.GoString(C.llama_token_to_str(llm.ctx, C.llama_token(token))))
 	}
 
 	return sb.String()
 }
 
-func (llm *LLM) next() (C.llama_token, error) {
+func (llm *llama) next() (C.llama_token, error) {
 	llm.mu.Lock()
 	defer llm.mu.Unlock()
 
@@ -410,7 +435,7 @@ func (llm *LLM) next() (C.llama_token, error) {
 	return token, nil
 }
 
-func (llm *LLM) Embedding(input string) ([]float64, error) {
+func (llm *llama) Embedding(input string) ([]float64, error) {
 	if !llm.EmbeddingOnly {
 		return nil, errors.New("llama: embedding not enabled")
 	}
@@ -420,7 +445,12 @@ func (llm *LLM) Embedding(input string) ([]float64, error) {
 		return nil, errors.New("llama: tokenize embedding")
 	}
 
-	retval := C.llama_eval(llm.ctx, unsafe.SliceData(tokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
+	cTokens := make([]C.llama_token, len(tokens))
+	for i := range tokens {
+		cTokens[i] = C.llama_token(tokens[i])
+	}
+
+	retval := C.llama_eval(llm.ctx, unsafe.SliceData(cTokens), C.int(len(tokens)), 0, C.int(llm.NumThread))
 	if retval != 0 {
 		return nil, errors.New("llama: eval")
 	}

+ 0 - 0
llama/llama.h → llm/llama.h


+ 1 - 1
llama/llama_darwin.go → llm/llama_darwin.go

@@ -1,4 +1,4 @@
-package llama
+package llm
 
 import (
 	"bytes"

+ 40 - 0
llm/llm.go

@@ -0,0 +1,40 @@
+package llm
+
+import (
+	"fmt"
+	"os"
+
+	"github.com/jmorganca/ollama/api"
+)
+
+type LLM interface {
+	Predict([]int, string, func(api.GenerateResponse)) error
+	Embedding(string) ([]float64, error)
+	Encode(string) []int
+	Decode(...int) string
+	SetOptions(api.Options)
+	Close()
+}
+
+func New(model string, opts api.Options) (LLM, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
+	f, err := os.Open(model)
+	if err != nil {
+		return nil, err
+	}
+
+	ggml, err := DecodeGGML(f, ModelFamilyLlama)
+	if err != nil {
+		return nil, err
+	}
+
+	switch ggml.ModelFamily {
+	case ModelFamilyLlama:
+		return newLlama(model, opts)
+	default:
+		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily)
+	}
+}

+ 0 - 0
llama/update-llama-cpp.sh → llm/update-llama-cpp.sh


+ 1 - 1
llama/utils.go → llm/utils.go

@@ -1,4 +1,4 @@
-package llama
+package llm
 
 import (
 	"fmt"

+ 36 - 22
server/images.go

@@ -19,7 +19,7 @@ import (
 	"strings"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/llama"
+	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/parser"
 	"github.com/jmorganca/ollama/vector"
 )
@@ -98,9 +98,14 @@ type LayerReader struct {
 }
 
 type ConfigV2 struct {
+	ModelFamily llm.ModelFamily `json:"model_family"`
+	ModelType   llm.ModelType   `json:"model_type"`
+	FileType    llm.FileType    `json:"file_type"`
+	RootFS      RootFS          `json:"rootfs"`
+
+	// required by spec
 	Architecture string `json:"architecture"`
 	OS           string `json:"os"`
-	RootFS       RootFS `json:"rootfs"`
 }
 
 type RootFS struct {
@@ -245,6 +250,11 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 		return err
 	}
 
+	config := ConfigV2{
+		Architecture: "amd64",
+		OS:           "linux",
+	}
+
 	var layers []*LayerReader
 	params := make(map[string][]string)
 	embed := EmbeddingParams{fn: fn, opts: api.DefaultOptions()}
@@ -283,6 +293,18 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 					}
 					defer file.Close()
 
+					ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
+					if err != nil {
+						return err
+					}
+
+					config.ModelFamily = ggml.ModelFamily
+					config.ModelType = ggml.ModelType
+					config.FileType = ggml.FileType
+
+					// reset the file
+					file.Seek(0, io.SeekStart)
+
 					l, err := CreateLayer(file)
 					if err != nil {
 						return fmt.Errorf("failed to create layer: %v", err)
@@ -291,6 +313,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 					layers = append(layers, l)
 				}
 			}
+
 			if mf != nil {
 				log.Printf("manifest = %#v", mf)
 				for _, l := range mf.Layers {
@@ -320,7 +343,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 			layers = append(layers, layer)
 		case "template", "system", "prompt":
 			fn(api.ProgressResponse{Status: fmt.Sprintf("creating model %s layer", c.Name)})
-			// remove the prompt layer if one exists
+			// remove the layer if one exists
 			mediaType := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
 			layers = removeLayerFromLayers(layers, mediaType)
 
@@ -382,7 +405,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 
 	// Create a layer for the config object
 	fn(api.ProgressResponse{Status: "creating config layer"})
-	cfg, err := createConfigLayer(digests)
+	cfg, err := createConfigLayer(config, digests)
 	if err != nil {
 		return err
 	}
@@ -429,13 +452,13 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 		}
 
 		e.opts.EmbeddingOnly = true
-		llm, err := llama.New(e.model, e.opts)
+		llmModel, err := llm.New(e.model, e.opts)
 		if err != nil {
 			return nil, fmt.Errorf("load model to generate embeddings: %v", err)
 		}
 		defer func() {
-			if llm != nil {
-				llm.Close()
+			if llmModel != nil {
+				llmModel.Close()
 			}
 		}()
 
@@ -479,7 +502,7 @@ func embeddingLayers(e EmbeddingParams) ([]*LayerReader, error) {
 						Total:     len(data) - 1,
 						Completed: i,
 					})
-					embed, err := llm.Embedding(d)
+					embed, err := llmModel.Embedding(d)
 					if err != nil {
 						log.Printf("failed to generate embedding for '%s' line %d: %v", filePath, i+1, err)
 						continue
@@ -675,7 +698,7 @@ func getLayerDigests(layers []*LayerReader) ([]string, error) {
 // CreateLayer creates a Layer object from a given file
 func CreateLayer(f io.ReadSeeker) (*LayerReader, error) {
 	digest, size := GetSHA256Digest(f)
-	f.Seek(0, 0)
+	f.Seek(0, io.SeekStart)
 
 	layer := &LayerReader{
 		Layer: Layer{
@@ -767,10 +790,6 @@ func DeleteModel(name string) error {
 		return err
 	}
 
-	if err != nil {
-		return err
-	}
-
 	// only delete the files which are still in the deleteMap
 	for k, v := range deleteMap {
 		if v {
@@ -969,15 +988,10 @@ func pullModelManifest(mp ModelPath, regOpts *RegistryOptions) (*ManifestV2, err
 	return m, err
 }
 
-func createConfigLayer(layers []string) (*LayerReader, error) {
-	// TODO change architecture and OS
-	config := ConfigV2{
-		Architecture: "arm64",
-		OS:           "linux",
-		RootFS: RootFS{
-			Type:    "layers",
-			DiffIDs: layers,
-		},
+func createConfigLayer(config ConfigV2, layers []string) (*LayerReader, error) {
+	config.RootFS = RootFS{
+		Type:    "layers",
+		DiffIDs: layers,
 	}
 
 	configJSON, err := json.Marshal(config)

+ 13 - 10
server/routes.go

@@ -21,14 +21,14 @@ import (
 	"gonum.org/v1/gonum/mat"
 
 	"github.com/jmorganca/ollama/api"
-	"github.com/jmorganca/ollama/llama"
+	"github.com/jmorganca/ollama/llm"
 	"github.com/jmorganca/ollama/vector"
 )
 
 var loaded struct {
 	mu sync.Mutex
 
-	llm        *llama.LLM
+	llm        llm.LLM
 	Embeddings []vector.Embedding
 
 	expireAt    time.Time
@@ -63,11 +63,16 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
 			loaded.Embeddings = model.Embeddings
 		}
 
-		llm, err := llama.New(model.ModelPath, opts)
+		llmModel, err := llm.New(model.ModelPath, opts)
 		if err != nil {
 			return err
 		}
 
+		// set cache values before modifying opts
+		loaded.llm = llmModel
+		loaded.digest = model.Digest
+		loaded.options = opts
+
 		if opts.NumKeep < 0 {
 			promptWithSystem, err := model.Prompt(api.GenerateRequest{}, "")
 			if err != nil {
@@ -79,15 +84,13 @@ func load(model *Model, reqOpts map[string]interface{}, sessionDuration time.Dur
 				return err
 			}
 
-			tokensWithSystem := llm.Encode(promptWithSystem)
-			tokensNoSystem := llm.Encode(promptNoSystem)
+			tokensWithSystem := llmModel.Encode(promptWithSystem)
+			tokensNoSystem := llmModel.Encode(promptNoSystem)
 
-			llm.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
-		}
+			opts.NumKeep = len(tokensWithSystem) - len(tokensNoSystem) + 1
 
-		loaded.llm = llm
-		loaded.digest = model.Digest
-		loaded.options = opts
+			llmModel.SetOptions(opts)
+		}
 	}
 	loaded.expireAt = time.Now().Add(sessionDuration)