浏览代码

link metal

jmorganca 10 月之前
父节点
当前提交
295c202b2f
共有 3 个文件被更改,包括 44 次插入39 次删除
  1. 0 3
      llama/ggml-metal.h
  2. 40 36
      llama/llama.go
  3. 4 0
      llama/metal-darwin-arm64.m

+ 0 - 3
llama/ggml-metal.h

@@ -61,9 +61,6 @@ struct ggml_cgraph;
 extern "C" {
 extern "C" {
 #endif
 #endif
 
 
-const char* ggml_metallib_start;
-const char* ggml_metallib_end;
-
 //
 //
 // backend API
 // backend API
 // user-code should use only these functions
 // user-code should use only these functions

+ 40 - 36
llama/llama.go

@@ -1,40 +1,45 @@
 package llama
 package llama
 
 
-// #cgo CFLAGS: -std=c11 -DNDEBUG -DLOG_DISABLE_LOGS
-// #cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
-// #cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
-// #cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
-// #cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
-// #cgo darwin,amd64 LDFLAGS: -framework Foundation
-// #cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
-// #cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
-// #cgo linux CFLAGS: -D_GNU_SOURCE
-// #cgo linux CXXFLAGS: -D_GNU_SOURCE
-// #cgo windows CFLAGS: -Wno-discarded-qualifiers
-// #cgo windows LDFLAGS: -lmsvcrt
-// #cgo avx CFLAGS: -mavx
-// #cgo avx CXXFLAGS: -mavx
-// #cgo avx2 CFLAGS: -mavx2 -mfma
-// #cgo avx2 CXXFLAGS: -mavx2 -mfma
-// #cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
-// #cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
-// #cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
-// #cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
-// #cgo rocm LDFLAGS: -L${SRCDIR} -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
-// #cgo windows,cuda LDFLAGS: -L${SRCDIR} -L"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.3/lib/x64" -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt
-// #cgo windows,rocm LDFLAGS: -L${SRCDIR} -L"C:/Program Files/AMD/ROCm/5.7/lib" -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
-// #cgo linux,cuda LDFLAGS: -L${SRCDIR} -L/usr/local/cuda/lib64 -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt
-// #cgo linux,rocm LDFLAGS: -L/opt/rocm/lib
-// #include <stdlib.h>
-// #include "llama.h"
-// #include "clip.h"
-// #include "llava.h"
-// #include "sampling_ext.h"
-//
-// bool llamaProgressCallback(float progress, void *user_data);
+/*
+#cgo CFLAGS: -std=c11 -DNDEBUG -DLOG_DISABLE_LOGS
+#cgo CXXFLAGS: -std=c++11 -DNDEBUG -DLOG_DISABLE_LOGS
+#cgo darwin,arm64 CFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
+#cgo darwin,arm64 CXXFLAGS: -DGGML_USE_METAL -DGGML_USE_ACCELERATE -DGGML_METAL_EMBED_LIBRARY -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
+#cgo darwin,arm64 LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework Accelerate
+#cgo darwin,amd64 CFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
+#cgo darwin,amd64 CXXFLAGS: -Wno-incompatible-pointer-types-discards-qualifiers
+#cgo darwin,amd64 LDFLAGS: -framework Foundation
+#cgo darwin,amd64,avx2 CFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
+#cgo darwin,amd64,avx2 CXXFLAGS: -DGGML_USE_ACCELERATE -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
+#cgo darwin,amd64,avx2 LDFLAGS: -framework Accelerate
+#cgo linux CFLAGS: -D_GNU_SOURCE
+#cgo linux CXXFLAGS: -D_GNU_SOURCE
+#cgo windows CFLAGS: -Wno-discarded-qualifiers
+#cgo windows LDFLAGS: -lmsvcrt
+#cgo avx CFLAGS: -mavx
+#cgo avx CXXFLAGS: -mavx
+#cgo avx2 CFLAGS: -mavx2 -mfma
+#cgo avx2 CXXFLAGS: -mavx2 -mfma
+#cgo cuda CFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
+#cgo cuda CXXFLAGS: -DGGML_USE_CUDA -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
+#cgo rocm CFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
+#cgo rocm CXXFLAGS: -DGGML_USE_CUDA -DGGML_USE_HIPBLAS -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_CUDA_MMV_Y=1 -DGGML_BUILD=1
+#cgo rocm LDFLAGS: -L${SRCDIR} -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
+#cgo windows,cuda LDFLAGS: -L${SRCDIR} -L"C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.3/lib/x64" -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt
+#cgo windows,rocm LDFLAGS: -L${SRCDIR} -L"C:/Program Files/AMD/ROCm/5.7/lib" -lggml_hipblas -lhipblas -lamdhip64 -lrocblas
+#cgo linux,cuda LDFLAGS: -L${SRCDIR} -L/usr/local/cuda/lib64 -lggml_cuda -lcuda -lcudart -lcublas -lcublasLt -lpthread -ldl -lrt
+#cgo linux,rocm LDFLAGS: -L/opt/rocm/lib
+
+#include <stdlib.h>
+#include "llama.h"
+#include "clip.h"
+#include "llava.h"
+#include "sampling_ext.h"
+
+bool llamaProgressCallback(float progress, void *user_data);
+extern const char* ggml_metallib_start;
+extern const char* ggml_metallib_end;
+*/
 import "C"
 import "C"
 import (
 import (
 	_ "embed"
 	_ "embed"
@@ -52,7 +57,6 @@ var ggmlCommon string
 //go:embed ggml-metal.metal
 //go:embed ggml-metal.metal
 var ggmlMetal string
 var ggmlMetal string
 
 
-// TODO: write me somewhere else
 func init() {
 func init() {
 	metal := strings.ReplaceAll(ggmlMetal, `#include "ggml-common.h"`, ggmlCommon)
 	metal := strings.ReplaceAll(ggmlMetal, `#include "ggml-common.h"`, ggmlCommon)
 	fmt.Println(metal)
 	fmt.Println(metal)

+ 4 - 0
llama/metal-darwin-arm64.m

@@ -0,0 +1,4 @@
+#import <Foundation/Foundation.h>
+
+const char* ggml_metallib_start = NULL;
+const char* ggml_metallib_end = NULL;