Jelajahi Sumber

GGUF support (#441)

Bruce MacDonald 1 tahun lalu
induk
melakukan
09dd2aeff9

+ 8 - 3
.gitmodules

@@ -1,4 +1,9 @@
 [submodule "llm/llama.cpp/ggml"]
-	path = llm/llama.cpp/ggml
-	url = https://github.com/ggerganov/llama.cpp.git
-	ignore = dirty
+    path = llm/llama.cpp/ggml
+    url = https://github.com/ggerganov/llama.cpp.git
+    ignore = dirty
+    shallow = true
+[submodule "llm/llama.cpp/gguf"]
+    path = llm/llama.cpp/gguf
+    url = https://github.com/ggerganov/llama.cpp.git
+    shallow = true

+ 50 - 31
llm/ggml.go

@@ -3,12 +3,15 @@ package llm
 import (
 	"encoding/binary"
 	"errors"
-	"fmt"
 	"io"
+	"path"
+	"sync"
 )
 
 type ModelFamily string
 
+const ModelFamilyUnknown ModelFamily = "unknown"
+
 type ModelType uint32
 
 const (
@@ -57,18 +60,17 @@ type model interface {
 
 type container interface {
 	Name() string
-	Decode(io.Reader) error
+	Decode(io.Reader) (model, error)
 }
 
-type containerGGML struct {
-}
+type containerGGML struct{}
 
 func (c *containerGGML) Name() string {
 	return "ggml"
 }
 
-func (c *containerGGML) Decode(r io.Reader) error {
-	return nil
+func (c *containerGGML) Decode(r io.Reader) (model, error) {
+	return nil, nil
 }
 
 type containerGGMF struct {
@@ -79,18 +81,18 @@ func (c *containerGGMF) Name() string {
 	return "ggmf"
 }
 
-func (c *containerGGMF) Decode(r io.Reader) error {
+func (c *containerGGMF) Decode(r io.Reader) (model, error) {
 	var version uint32
 	binary.Read(r, binary.LittleEndian, &version)
 
 	switch version {
 	case 1:
 	default:
-		return errors.New("invalid version")
+		return nil, errors.New("invalid version")
 	}
 
 	c.version = version
-	return nil
+	return nil, nil
 }
 
 type containerGGJT struct {
@@ -101,18 +103,22 @@ func (c *containerGGJT) Name() string {
 	return "ggjt"
 }
 
-func (c *containerGGJT) Decode(r io.Reader) error {
+func (c *containerGGJT) Decode(r io.Reader) (model, error) {
 	var version uint32
 	binary.Read(r, binary.LittleEndian, &version)
 
 	switch version {
 	case 1, 2, 3:
 	default:
-		return errors.New("invalid version")
+		return nil, errors.New("invalid version")
 	}
 
 	c.version = version
-	return nil
+
+	// different model types may have different layouts for hyperparameters
+	var llama llamaModel
+	binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
+	return &llama, nil
 }
 
 type containerLORA struct {
@@ -123,32 +129,51 @@ func (c *containerLORA) Name() string {
 	return "ggla"
 }
 
-func (c *containerLORA) Decode(r io.Reader) error {
+func (c *containerLORA) Decode(r io.Reader) (model, error) {
 	var version uint32
 	binary.Read(r, binary.LittleEndian, &version)
 
 	switch version {
 	case 1:
 	default:
-		return errors.New("invalid version")
+		return nil, errors.New("invalid version")
 	}
 
 	c.version = version
-	return nil
+	return nil, nil
+}
+
+var (
+	ggmlGPU = path.Join("llama.cpp", "ggml", "build", "gpu", "bin")
+	ggmlCPU = path.Join("llama.cpp", "ggml", "build", "cpu", "bin")
+)
+
+var (
+	ggmlInit       sync.Once
+	ggmlRunnerPath string
+)
+
+func ggmlRunner() ModelRunner {
+	ggmlInit.Do(func() {
+		ggmlRunnerPath = chooseRunner(ggmlGPU, ggmlCPU)
+	})
+	return ModelRunner{Path: ggmlRunnerPath}
 }
 
 const (
-	// / Magic constant for `ggml` files (unversioned).
+	// Magic constant for `ggml` files (unversioned).
 	FILE_MAGIC_GGML = 0x67676d6c
-	// / Magic constant for `ggml` files (versioned, ggmf).
+	// Magic constant for `ggml` files (versioned, ggmf).
 	FILE_MAGIC_GGMF = 0x67676d66
-	// / Magic constant for `ggml` files (versioned, ggjt).
+	// Magic constant for `ggml` files (versioned, ggjt).
 	FILE_MAGIC_GGJT = 0x67676a74
-	// / Magic constant for `ggla` files (LoRA adapter).
+	// Magic constant for `ggla` files (LoRA adapter).
 	FILE_MAGIC_GGLA = 0x67676C61
+	// Magic constant for `gguf` files (versioned, gguf)
+	FILE_MAGIC_GGUF = 0x46554747
 )
 
-func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
+func DecodeGGML(r io.ReadSeeker) (*GGML, error) {
 	var ggml GGML
 	binary.Read(r, binary.LittleEndian, &ggml.magic)
 
@@ -161,24 +186,18 @@ func DecodeGGML(r io.ReadSeeker, hint ModelFamily) (*GGML, error) {
 		ggml.container = &containerGGJT{}
 	case FILE_MAGIC_GGLA:
 		ggml.container = &containerLORA{}
+	case FILE_MAGIC_GGUF:
+		ggml.container = &containerGGUF{}
 	default:
 		return nil, errors.New("invalid file magic")
 	}
 
-	if err := ggml.Decode(r); err != nil {
+	model, err := ggml.Decode(r)
+	if err != nil {
 		return nil, err
 	}
 
-	// different model types may have different layouts for hyperparameters
-	switch hint {
-	case ModelFamilyLlama:
-		var llama llamaModel
-		binary.Read(r, binary.LittleEndian, &llama.hyperparameters)
-		ggml.model = &llama
-		// TODO: sanity check hyperparameters
-	default:
-		return nil, fmt.Errorf("unsupported model type: %s", hint)
-	}
+	ggml.model = model
 
 	// final model type
 	return &ggml, nil

+ 385 - 0
llm/gguf.go

@@ -0,0 +1,385 @@
+package llm
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"fmt"
+	"io"
+	"log"
+	"path"
+	"sync"
+)
+
+type containerGGUF struct {
+	Version uint32
+
+	V1 struct {
+		NumTensor uint32
+		NumKV     uint32
+	}
+
+	V2 struct {
+		NumTensor uint64
+		NumKV     uint64
+	}
+}
+
+func (c *containerGGUF) Name() string {
+	return "gguf"
+}
+
+func (c *containerGGUF) Decode(r io.Reader) (model, error) {
+	binary.Read(r, binary.LittleEndian, &c.Version)
+
+	switch c.Version {
+	case 1:
+		binary.Read(r, binary.LittleEndian, &c.V1)
+	case 2:
+		binary.Read(r, binary.LittleEndian, &c.V2)
+	default:
+		return nil, errors.New("invalid version")
+	}
+
+	model := newGGUFModel(c)
+	if err := model.Decode(r); err != nil {
+		return nil, err
+	}
+
+	return model, nil
+}
+
+const (
+	ggufTypeUint8 uint32 = iota
+	ggufTypeInt8
+	ggufTypeUint16
+	ggufTypeInt16
+	ggufTypeUint32
+	ggufTypeInt32
+	ggufTypeFloat32
+	ggufTypeBool
+	ggufTypeString
+	ggufTypeArray
+	ggufTypeUint64
+	ggufTypeInt64
+	ggufTypeFloat64
+)
+
+type kv map[string]any
+
+type ggufModel struct {
+	*containerGGUF
+	kv
+}
+
+func newGGUFModel(container *containerGGUF) *ggufModel {
+	return &ggufModel{
+		containerGGUF: container,
+		kv:            make(kv),
+	}
+}
+
+func (llm *ggufModel) NumKV() uint64 {
+	if llm.Version == 1 {
+		return uint64(llm.V1.NumKV)
+	}
+
+	return llm.V2.NumKV
+}
+
+func (llm *ggufModel) ModelFamily() ModelFamily {
+	t, ok := llm.kv["general.architecture"].(string)
+	if ok {
+		return ModelFamily(t)
+	}
+
+	log.Printf("unknown model family: %T", t)
+	return ModelFamilyUnknown
+}
+
+func (llm *ggufModel) ModelType() ModelType {
+	switch llm.ModelFamily() {
+	case ModelFamilyLlama:
+		blocks, ok := llm.kv["llama.block_count"].(uint32)
+		if ok {
+			return ModelType(blocks)
+		}
+	}
+
+	return ModelType7B
+}
+
+func (llm *ggufModel) FileType() FileType {
+	switch llm.ModelFamily() {
+	case ModelFamilyLlama:
+		t, ok := llm.kv["general.file_type"].(uint32)
+		if ok {
+			return llamaFileType(t)
+		}
+	}
+
+	return llamaFileTypeF16
+}
+
+func (llm *ggufModel) Decode(r io.Reader) error {
+	read := llm.readString
+	if llm.Version == 1 {
+		read = llm.readStringV1
+	}
+
+	for i := 0; uint64(i) < llm.NumKV(); i++ {
+		k, err := read(r)
+		if err != nil {
+			return err
+		}
+
+		vtype := llm.readU32(r)
+
+		var v any
+		switch vtype {
+		case ggufTypeUint8:
+			v = llm.readU8(r)
+		case ggufTypeInt8:
+			v = llm.readI8(r)
+		case ggufTypeUint16:
+			v = llm.readU16(r)
+		case ggufTypeInt16:
+			v = llm.readI16(r)
+		case ggufTypeUint32:
+			v = llm.readU32(r)
+		case ggufTypeInt32:
+			v = llm.readI32(r)
+		case ggufTypeUint64:
+			v = llm.readU64(r)
+		case ggufTypeInt64:
+			v = llm.readI64(r)
+		case ggufTypeFloat32:
+			v = llm.readF32(r)
+		case ggufTypeFloat64:
+			v = llm.readF64(r)
+		case ggufTypeBool:
+			v = llm.readBool(r)
+		case ggufTypeString:
+			fn := llm.readString
+			if llm.Version == 1 {
+				fn = llm.readStringV1
+			}
+
+			s, err := fn(r)
+			if err != nil {
+				return err
+			}
+
+			v = s
+		case ggufTypeArray:
+			fn := llm.readArray
+			if llm.Version == 1 {
+				fn = llm.readArrayV1
+			}
+
+			a, err := fn(r)
+			if err != nil {
+				return err
+			}
+
+			v = a
+		default:
+			return fmt.Errorf("invalid type: %d", vtype)
+		}
+
+		llm.kv[k] = v
+	}
+
+	return nil
+}
+
+func (ggufModel) readU8(r io.Reader) uint8 {
+	var u8 uint8
+	binary.Read(r, binary.LittleEndian, &u8)
+	return u8
+}
+
+func (ggufModel) readI8(r io.Reader) int8 {
+	var i8 int8
+	binary.Read(r, binary.LittleEndian, &i8)
+	return i8
+}
+
+func (ggufModel) readU16(r io.Reader) uint16 {
+	var u16 uint16
+	binary.Read(r, binary.LittleEndian, &u16)
+	return u16
+}
+
+func (ggufModel) readI16(r io.Reader) int16 {
+	var i16 int16
+	binary.Read(r, binary.LittleEndian, &i16)
+	return i16
+}
+
+func (ggufModel) readU32(r io.Reader) uint32 {
+	var u32 uint32
+	binary.Read(r, binary.LittleEndian, &u32)
+	return u32
+}
+
+func (ggufModel) readI32(r io.Reader) int32 {
+	var i32 int32
+	binary.Read(r, binary.LittleEndian, &i32)
+	return i32
+}
+
+func (ggufModel) readU64(r io.Reader) uint64 {
+	var u64 uint64
+	binary.Read(r, binary.LittleEndian, &u64)
+	return u64
+}
+
+func (ggufModel) readI64(r io.Reader) int64 {
+	var i64 int64
+	binary.Read(r, binary.LittleEndian, &i64)
+	return i64
+}
+
+func (ggufModel) readF32(r io.Reader) float32 {
+	var f32 float32
+	binary.Read(r, binary.LittleEndian, &f32)
+	return f32
+}
+
+func (ggufModel) readF64(r io.Reader) float64 {
+	var f64 float64
+	binary.Read(r, binary.LittleEndian, &f64)
+	return f64
+}
+
+func (ggufModel) readBool(r io.Reader) bool {
+	var b bool
+	binary.Read(r, binary.LittleEndian, &b)
+	return b
+}
+
+func (ggufModel) readStringV1(r io.Reader) (string, error) {
+	var nameLength uint32
+	binary.Read(r, binary.LittleEndian, &nameLength)
+
+	var b bytes.Buffer
+	if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
+		return "", err
+	}
+
+	// gguf v1 strings are null-terminated
+	b.Truncate(b.Len() - 1)
+
+	return b.String(), nil
+}
+
+func (llm ggufModel) readString(r io.Reader) (string, error) {
+	var nameLength uint64
+	binary.Read(r, binary.LittleEndian, &nameLength)
+
+	var b bytes.Buffer
+	if _, err := io.CopyN(&b, r, int64(nameLength)); err != nil {
+		return "", err
+	}
+
+	return b.String(), nil
+}
+
+func (llm *ggufModel) readArrayV1(r io.Reader) (arr []any, err error) {
+	atype := llm.readU32(r)
+	n := llm.readU32(r)
+
+	for i := 0; uint32(i) < n; i++ {
+		switch atype {
+		case ggufTypeUint8:
+			arr = append(arr, llm.readU8(r))
+		case ggufTypeInt8:
+			arr = append(arr, llm.readU8(r))
+		case ggufTypeUint16:
+			arr = append(arr, llm.readU16(r))
+		case ggufTypeInt16:
+			arr = append(arr, llm.readI16(r))
+		case ggufTypeUint32:
+			arr = append(arr, llm.readU32(r))
+		case ggufTypeInt32:
+			arr = append(arr, llm.readI32(r))
+		case ggufTypeFloat32:
+			arr = append(arr, llm.readF32(r))
+		case ggufTypeBool:
+			arr = append(arr, llm.readBool(r))
+		case ggufTypeString:
+			s, err := llm.readStringV1(r)
+			if err != nil {
+				return nil, err
+			}
+
+			arr = append(arr, s)
+		default:
+			return nil, fmt.Errorf("invalid array type: %d", atype)
+		}
+	}
+
+	return
+}
+
+func (llm *ggufModel) readArray(r io.Reader) (arr []any, err error) {
+	atype := llm.readU32(r)
+	n := llm.readU64(r)
+
+	for i := 0; uint64(i) < n; i++ {
+		switch atype {
+		case ggufTypeUint8:
+			arr = append(arr, llm.readU8(r))
+		case ggufTypeInt8:
+			arr = append(arr, llm.readU8(r))
+		case ggufTypeUint16:
+			arr = append(arr, llm.readU16(r))
+		case ggufTypeInt16:
+			arr = append(arr, llm.readI16(r))
+		case ggufTypeUint32:
+			arr = append(arr, llm.readU32(r))
+		case ggufTypeInt32:
+			arr = append(arr, llm.readI32(r))
+		case ggufTypeUint64:
+			arr = append(arr, llm.readU64(r))
+		case ggufTypeInt64:
+			arr = append(arr, llm.readI64(r))
+		case ggufTypeFloat32:
+			arr = append(arr, llm.readF32(r))
+		case ggufTypeFloat64:
+			arr = append(arr, llm.readF64(r))
+		case ggufTypeBool:
+			arr = append(arr, llm.readBool(r))
+		case ggufTypeString:
+			s, err := llm.readString(r)
+			if err != nil {
+				return nil, err
+			}
+
+			arr = append(arr, s)
+		default:
+			return nil, fmt.Errorf("invalid array type: %d", atype)
+		}
+	}
+
+	return
+}
+
+var (
+	ggufGPU = path.Join("llama.cpp", "gguf", "build", "gpu", "bin")
+	ggufCPU = path.Join("llama.cpp", "gguf", "build", "cpu", "bin")
+)
+
+var (
+	ggufInit       sync.Once
+	ggufRunnerPath string
+)
+
+func ggufRunner() ModelRunner {
+	ggufInit.Do(func() {
+		ggufRunnerPath = chooseRunner(ggufGPU, ggufCPU)
+	})
+
+	return ModelRunner{Path: ggufRunnerPath}
+}

+ 3 - 1
llm/llama.cpp/generate.go

@@ -4,10 +4,12 @@
 package llm
 
 //go:generate git submodule init
-//go:generate git submodule update --force ggml
+//go:generate git submodule update --force ggml gguf
 //go:generate git -C ggml apply ../ggml_patch/0001-add-detokenize-endpoint.patch
 //go:generate git -C ggml apply ../ggml_patch/0002-34B-model-support.patch
 //go:generate git -C ggml apply ../ggml_patch/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
 //go:generate git -C ggml apply ../ggml_patch/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
 //go:generate cmake --fresh -S ggml -B ggml/build/cpu -DLLAMA_K_QUANTS=on
 //go:generate cmake --build ggml/build/cpu --target server --config Release
+//go:generate cmake -S gguf -B gguf/build/cpu -DLLAMA_K_QUANTS=on
+//go:generate cmake --build gguf/build/cpu --target server --config Release

+ 3 - 1
llm/llama.cpp/generate_darwin_amd64.go

@@ -1,10 +1,12 @@
 package llm
 
 //go:generate git submodule init
-//go:generate git submodule update --force ggml
+//go:generate git submodule update --force ggml gguf
 //go:generate git -C ggml apply ../ggml_patch/0001-add-detokenize-endpoint.patch
 //go:generate git -C ggml apply ../ggml_patch/0002-34B-model-support.patch
 //go:generate git -C ggml apply ../ggml_patch/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
 //go:generate git -C ggml apply ../ggml_patch/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
 //go:generate cmake --fresh -S ggml -B ggml/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
 //go:generate cmake --build ggml/build/cpu --target server --config Release
+//go:generate cmake --fresh -S gguf -B gguf/build/cpu -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
+//go:generate cmake --build gguf/build/cpu --target server --config Release

+ 3 - 1
llm/llama.cpp/generate_darwin_arm64.go

@@ -1,10 +1,12 @@
 package llm
 
 //go:generate git submodule init
-//go:generate git submodule update --force ggml
+//go:generate git submodule update --force ggml gguf
 //go:generate git -C ggml apply ../ggml_patch/0001-add-detokenize-endpoint.patch
 //go:generate git -C ggml apply ../ggml_patch/0002-34B-model-support.patch
 //go:generate git -C ggml apply ../ggml_patch/0003-metal-fix-synchronization-in-new-matrix-multiplicati.patch
 //go:generate git -C ggml apply ../ggml_patch/0004-metal-add-missing-barriers-for-mul-mat-2699.patch
 //go:generate cmake --fresh -S ggml -B ggml/build/gpu -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
 //go:generate cmake --build ggml/build/gpu --target server --config Release
+//go:generate cmake -S gguf -B gguf/build/gpu -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on -DLLAMA_K_QUANTS=on -DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DCMAKE_OSX_DEPLOYMENT_TARGET=11.0
+//go:generate cmake --build gguf/build/gpu --target server --config Release

+ 1 - 0
llm/llama.cpp/gguf

@@ -0,0 +1 @@
+Subproject commit 53885d7256909ec3e2176cdc2477f3986c15ec69

+ 72 - 94
llm/ggml_llama.go → llm/llama.go

@@ -20,27 +20,14 @@ import (
 	"runtime"
 	"strconv"
 	"strings"
-	"sync"
 	"time"
 
 	"github.com/jmorganca/ollama/api"
 )
 
-const ModelFamilyLlama ModelFamily = "llama"
-
-//go:embed llama.cpp/ggml/build/*/bin/*
+//go:embed llama.cpp/*/build/*/bin/*
 var llamaCppEmbed embed.FS
 
-var (
-	ggmlGPU = path.Join("llama.cpp", "ggml", "build", "gpu", "bin")
-	ggmlCPU = path.Join("llama.cpp", "ggml", "build", "cpu", "bin")
-)
-
-var (
-	ggmlInit       sync.Once
-	ggmlRunnerPath string
-)
-
 func osPath(llamaPath string) string {
 	if runtime.GOOS == "windows" {
 		return path.Join(llamaPath, "Release")
@@ -49,68 +36,61 @@ func osPath(llamaPath string) string {
 	return llamaPath
 }
 
-func initGGML() {
-	ggmlInit.Do(func() {
-		tmpDir, err := os.MkdirTemp("", "llama-*")
-		if err != nil {
-			log.Fatalf("llama.cpp: failed to create temp dir: %v", err)
-		}
+func chooseRunner(gpuPath, cpuPath string) string {
+	tmpDir, err := os.MkdirTemp("", "llama-*")
+	if err != nil {
+		log.Fatalf("llama.cpp: failed to create temp dir: %v", err)
+	}
 
-		llamaPath := osPath(ggmlGPU)
+	llamaPath := osPath(gpuPath)
+	if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
+		llamaPath = osPath(cpuPath)
 		if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
-			llamaPath = osPath(ggmlCPU)
-			if _, err := fs.Stat(llamaCppEmbed, llamaPath); err != nil {
-				log.Fatalf("llama.cpp executable not found")
-			}
+			log.Fatalf("llama.cpp executable not found")
 		}
+	}
 
-		files := []string{"server"}
-		switch runtime.GOOS {
-		case "windows":
-			files = []string{"server.exe"}
-		case "darwin":
-			if llamaPath == osPath(ggmlGPU) {
-				files = append(files, "ggml-metal.metal")
-			}
+	files := []string{"server"}
+	switch runtime.GOOS {
+	case "windows":
+		files = []string{"server.exe"}
+	case "darwin":
+		if llamaPath == osPath(gpuPath) {
+			files = append(files, "ggml-metal.metal")
 		}
+	}
 
-		for _, f := range files {
-			srcPath := path.Join(llamaPath, f)
-			destPath := filepath.Join(tmpDir, f)
-
-			srcFile, err := llamaCppEmbed.Open(srcPath)
-			if err != nil {
-				log.Fatalf("read llama.cpp %s: %v", f, err)
-			}
-			defer srcFile.Close()
+	for _, f := range files {
+		srcPath := path.Join(llamaPath, f)
+		destPath := filepath.Join(tmpDir, f)
 
-			destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
-			if err != nil {
-				log.Fatalf("write llama.cpp %s: %v", f, err)
-			}
-			defer destFile.Close()
+		srcFile, err := llamaCppEmbed.Open(srcPath)
+		if err != nil {
+			log.Fatalf("read llama.cpp %s: %v", f, err)
+		}
+		defer srcFile.Close()
 
-			if _, err := io.Copy(destFile, srcFile); err != nil {
-				log.Fatalf("copy llama.cpp %s: %v", f, err)
-			}
+		destFile, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
+		if err != nil {
+			log.Fatalf("write llama.cpp %s: %v", f, err)
 		}
+		defer destFile.Close()
 
-		ggmlRunnerPath = filepath.Join(tmpDir, "server")
-		if runtime.GOOS == "windows" {
-			ggmlRunnerPath = filepath.Join(tmpDir, "server.exe")
+		if _, err := io.Copy(destFile, srcFile); err != nil {
+			log.Fatalf("copy llama.cpp %s: %v", f, err)
 		}
-	})
-}
+	}
 
-type ModelRunner struct {
-	Path string // path to the model runner executable
-}
+	runPath := filepath.Join(tmpDir, "server")
+	if runtime.GOOS == "windows" {
+		runPath = filepath.Join(tmpDir, "server.exe")
+	}
 
-func ggmlRunner() ModelRunner {
-	initGGML()
-	return ModelRunner{Path: ggmlRunnerPath}
+	return runPath
 }
 
+const ModelFamilyLlama ModelFamily = "llama"
+
 type llamaModel struct {
 	hyperparameters llamaHyperparameters
 }
@@ -229,6 +209,10 @@ type Running struct {
 	Cancel context.CancelFunc
 }
 
+type ModelRunner struct {
+	Path string // path to the model runner executable
+}
+
 type llama struct {
 	api.Options
 	Running
@@ -250,7 +234,6 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 	params := []string{
 		"--model", model,
 		"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
-		"--gqa", fmt.Sprintf("%d", opts.NumGQA),
 		"--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase),
 		"--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale),
 		"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
@@ -258,6 +241,10 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 		"--embedding",
 	}
 
+	if opts.NumGQA > 0 {
+		params = append(params, "--gqa", fmt.Sprintf("%d", opts.NumGQA))
+	}
+
 	if len(adapters) > 0 {
 		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
 		params = append(params, "--lora", adapters[0])
@@ -289,17 +276,25 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 			runner.Path,
 			append(params, "--port", strconv.Itoa(port))...,
 		)
+
 		cmd.Stdout = os.Stderr
 		cmd.Stderr = os.Stderr
 
 		llm := &llama{Options: opts, Running: Running{Port: port, Cmd: cmd, Cancel: cancel}}
 
+		log.Print("starting llama.cpp server")
+		if err := llm.Cmd.Start(); err != nil {
+			log.Printf("error starting the external llama.cpp server: %v", err)
+			continue
+		}
+
 		if err := waitForServer(llm); err != nil {
 			log.Printf("error starting llama.cpp server: %v", err)
 			llm.Close()
 			// try again
 			continue
 		}
+
 		// server started successfully
 		return llm, nil
 	}
@@ -308,48 +303,31 @@ func newLlama(model string, adapters []string, runner ModelRunner, opts api.Opti
 }
 
 func waitForServer(llm *llama) error {
-	log.Print("starting llama.cpp server")
-	var stderr bytes.Buffer
-	llm.Cmd.Stderr = &stderr
-	err := llm.Cmd.Start()
-	if err != nil {
-		return fmt.Errorf("error starting the external llama.cpp server: %w", err)
-	}
-
-	exitChan := make(chan error, 1)
-
-	// the server is a long running process, watch for it exiting to keep track of something going wrong
-	go func() {
-		err := llm.Cmd.Wait()
-		log.Print(stderr.String())
-		exitChan <- err
-	}()
-
 	// wait for the server to start responding
 	start := time.Now()
 	expiresAt := time.Now().Add(30 * time.Second)
-	ticker := time.NewTicker(100 * time.Millisecond)
+	ticker := time.NewTicker(200 * time.Millisecond)
 
 	log.Print("waiting for llama.cpp server to start responding")
+	for range ticker.C {
+		if time.Now().After(expiresAt) {
+			return fmt.Errorf("llama.cpp server did not start within alloted time, retrying")
+		}
 
-	for {
-		select {
-		case <-ticker.C:
-			if time.Now().After(expiresAt) {
-				return fmt.Errorf("llama.cpp server did not start responding within 30 seconds, retrying")
-			}
-			if err := llm.Ping(context.Background()); err == nil {
-				log.Printf("llama.cpp server started in %f seconds", time.Since(start).Seconds())
-				return nil
-			}
-		case err := <-exitChan:
-			return fmt.Errorf("llama.cpp server exited unexpectedly: %w", err)
+		if err := llm.Ping(context.Background()); err == nil {
+			break
 		}
 	}
+
+	log.Printf("llama.cpp server started in %f seconds", time.Since(start).Seconds())
+	return nil
 }
 
 func (llm *llama) Close() {
-	llm.Running.Cmd.Cancel()
+	llm.Cancel()
+	if err := llm.Cmd.Wait(); err != nil {
+		log.Printf("llama.cpp server exited with error: %v", err)
+	}
 }
 
 func (llm *llama) SetOptions(opts api.Options) {
@@ -676,7 +654,7 @@ func (llm *llama) Embedding(ctx context.Context, input string) ([]float64, error
 
 // Ping checks that the server subprocess is still running and responding to requests
 func (llm *llama) Ping(ctx context.Context) error {
-	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Running.Port))
+	resp, err := http.Head(fmt.Sprintf("http://127.0.0.1:%d", llm.Port))
 	if err != nil {
 		return fmt.Errorf("ping resp: %w", err)
 	}

+ 15 - 5
llm/llm.go

@@ -32,15 +32,22 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 	}
 	defer f.Close()
 
-	ggml, err := DecodeGGML(f, ModelFamilyLlama)
+	ggml, err := DecodeGGML(f)
 	if err != nil {
 		return nil, err
 	}
 
 	switch ggml.FileType().String() {
-	case "F32", "Q5_0", "Q5_1", "Q8_0":
+	case "Q8_0":
+		if ggml.Name() != "gguf" && opts.NumGPU != 0 {
+			// GGML Q8_0 do not support Metal API and will
+			// cause the runner to segmentation fault so disable GPU
+			log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
+			opts.NumGPU = 0
+		}
+	case "F32", "Q5_0", "Q5_1":
 		if opts.NumGPU != 0 {
-			// F32, F16, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
+			// F32, Q5_0, Q5_1, and Q8_0 do not support Metal API and will
 			// cause the runner to segmentation fault so disable GPU
 			log.Printf("WARNING: GPU disabled for F32, Q5_0, Q5_1, and Q8_0")
 			opts.NumGPU = 0
@@ -75,8 +82,11 @@ func New(model string, adapters []string, opts api.Options) (LLM, error) {
 		}
 	}
 
-	switch ggml.ModelFamily() {
-	case ModelFamilyLlama:
+	switch ggml.Name() {
+	case "gguf":
+		opts.NumGQA = 0 // TODO: remove this when llama.cpp runners differ enough to need separate newLlama functions
+		return newLlama(model, adapters, ggufRunner(), opts)
+	case "ggml", "ggmf", "ggjt", "ggla":
 		return newLlama(model, adapters, ggmlRunner(), opts)
 	default:
 		return nil, fmt.Errorf("unknown ggml type: %s", ggml.ModelFamily())

+ 1 - 1
server/images.go

@@ -328,7 +328,7 @@ func CreateModel(ctx context.Context, name string, path string, fn func(resp api
 					}
 					defer file.Close()
 
-					ggml, err := llm.DecodeGGML(file, llm.ModelFamilyLlama)
+					ggml, err := llm.DecodeGGML(file)
 					if err != nil {
 						return err
 					}