Bladeren bron

atomic for race

Josh Yan 9 maanden geleden
bovenliggende
commit
8476ef2bd8
1 gewijzigde bestanden met toevoegingen van 15 en 13 verwijderingen
  1. 15 13
      llm/llm.go

+ 15 - 13
llm/llm.go

@@ -9,14 +9,18 @@ package llm
 // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
 // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
 // #include <stdlib.h>
+// #include <stdatomic.h>
 // #include "llama.h"
 // bool update_quantize_progress(float progress, void* data) {
-//  *((float*)data) = progress;
-// 	return true;
+//	atomic_int* atomicData = (atomic_int*)data;
+//  int intProgress = *((int*)&progress);
+//  atomic_store(atomicData, intProgress);
+//  return true;
 // }
 import "C"
 import (
 	"fmt"
+	"sync/atomic"
 	"time"
 	"unsafe"
 
@@ -39,21 +43,17 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
 	params.nthread = -1
 	params.ftype = ftype.Value()
 
-	// race condition with `store`
-	// use atomicint/float idk yet
-	// use set in the C.
-
 	// Initialize "global" to store progress
-	store := C.malloc(C.sizeof_float)
-	defer C.free(store)
+	store := (*int32)(C.malloc(C.sizeof_int))
+	defer C.free(unsafe.Pointer(store))
 
 	// Initialize store value, e.g., setting initial progress to 0
-	*(*C.float)(store) = 0.0
+	atomic.StoreInt32(store, 0)
 
-	params.quantize_callback_data = store
+	params.quantize_callback_data = unsafe.Pointer(store)
 	params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
 
-	ticker := time.NewTicker(60 * time.Millisecond)
+	ticker := time.NewTicker(30 * time.Millisecond)
 	done := make(chan struct{})
 	defer close(done)
 
@@ -62,11 +62,13 @@ func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressR
 		for {
 			select {
 			case <-ticker.C:
+				progressInt := atomic.LoadInt32(store)
+				progress := *(*float32)(unsafe.Pointer(&progressInt))
 				fn(api.ProgressResponse{
-					Status:   fmt.Sprintf("quantizing model tensors %d/%d", int(*((*C.float)(store))), tensorCount),
+					Status:   fmt.Sprintf("quantizing model tensors %d/%d", int(progress), tensorCount),
 					Quantize: "quant",
 				})
-				fmt.Println("Progress: ", *((*C.float)(store)))
+				fmt.Println("Progress: ", progress)
 			case <-done:
 				fn(api.ProgressResponse{
 					Status:   fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount),