|
@@ -9,14 +9,18 @@ package llm
|
|
|
// #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
|
|
|
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
|
|
|
// #include <stdlib.h>
|
|
|
+// #include <stdatomic.h>
|
|
|
// #include "llama.h"
|
|
|
// bool update_quantize_progress(float progress, void* data) {
|
|
|
-// *((float*)data) = progress;
|
|
|
-// return true;
|
|
|
+// atomic_int* atomicData = (atomic_int*)data;
|
|
|
+// int intProgress = *((int*)&progress);
|
|
|
+// atomic_store(atomicData, intProgress);
|
|
|
+// return true;
|
|
|
// }
|
|
|
import "C"
|
|
|
import (
|
|
|
"fmt"
|
|
|
+ "sync/atomic"
|
|
|
"time"
|
|
|
"unsafe"
|
|
|
|
|
@@ -39,21 +43,17 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
|
|
|
params.nthread = -1
|
|
|
params.ftype = ftype.Value()
|
|
|
|
|
|
- // race condition with `store`
|
|
|
- // use atomicint/float idk yet
|
|
|
- // use set in the C.
|
|
|
-
|
|
|
// Initialize "global" to store progress
|
|
|
- store := C.malloc(C.sizeof_float)
|
|
|
- defer C.free(store)
|
|
|
+ store := (*int32)(C.malloc(C.sizeof_int))
|
|
|
+ defer C.free(unsafe.Pointer(store))
|
|
|
|
|
|
// Initialize store value, e.g., setting initial progress to 0
|
|
|
- *(*C.float)(store) = 0.0
|
|
|
+ atomic.StoreInt32(store, 0)
|
|
|
|
|
|
- params.quantize_callback_data = store
|
|
|
+ params.quantize_callback_data = unsafe.Pointer(store)
|
|
|
params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
|
|
|
|
|
|
- ticker := time.NewTicker(60 * time.Millisecond)
|
|
|
+ ticker := time.NewTicker(30 * time.Millisecond)
|
|
|
done := make(chan struct{})
|
|
|
defer close(done)
|
|
|
|
|
@@ -62,11 +62,13 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
|
|
|
for {
|
|
|
select {
|
|
|
case <-ticker.C:
|
|
|
+ progressInt := atomic.LoadInt32(store)
|
|
|
+ progress := *(*float32)(unsafe.Pointer(&progressInt))
|
|
|
fn(api.ProgressResponse{
|
|
|
- Status: fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount),
|
|
|
+ Status: fmt.Sprintf("quantizing model tensors %d/%d", int(progress), tensorCount),
|
|
|
Quantize: "quant",
|
|
|
})
|
|
|
- fmt.Println("Progress: ", *((*C.float)(store)))
|
|
|
+ fmt.Println("Progress: ", progress)
|
|
|
case <-done:
|
|
|
fn(api.ProgressResponse{
|
|
|
Status: fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount),
|