llm.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package llm
  2. // #cgo CPPFLAGS: -Illama.cpp/ggml/include
  3. // #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
  4. // #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
  5. // #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
  6. // #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
  7. // #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
  8. // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
  9. // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
  10. // #include <stdlib.h>
  11. // #include <stdatomic.h>
  12. // #include "llama.h"
  13. // bool update_quantize_progress(float progress, void* data) {
  14. // atomic_int* atomicData = (atomic_int*)data;
  15. // int intProgress = *((int*)&progress);
  16. // atomic_store(atomicData, intProgress);
  17. // return true;
  18. // }
  19. import "C"
  20. import (
  21. "fmt"
  22. "sync/atomic"
  23. "time"
  24. "unsafe"
  25. "github.com/ollama/ollama/api"
  26. )
  27. // SystemInfo is an unused example of calling llama.cpp functions using CGo
  28. func SystemInfo() string {
  29. return C.GoString(C.llama_print_system_info())
  30. }
  31. func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
  32. cinfile := C.CString(infile)
  33. defer C.free(unsafe.Pointer(cinfile))
  34. coutfile := C.CString(outfile)
  35. defer C.free(unsafe.Pointer(coutfile))
  36. params := C.llama_model_quantize_default_params()
  37. params.nthread = -1
  38. params.ftype = ftype.Value()
  39. // Initialize "global" to store progress
  40. store := (*int32)(C.malloc(C.sizeof_int))
  41. defer C.free(unsafe.Pointer(store))
  42. // Initialize store value, e.g., setting initial progress to 0
  43. atomic.StoreInt32(store, 0)
  44. params.quantize_callback_data = unsafe.Pointer(store)
  45. params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
  46. ticker := time.NewTicker(30 * time.Millisecond)
  47. done := make(chan struct{})
  48. defer close(done)
  49. go func() {
  50. defer ticker.Stop()
  51. for {
  52. select {
  53. case <-ticker.C:
  54. progressInt := atomic.LoadInt32(store)
  55. progress := *(*float32)(unsafe.Pointer(&progressInt))
  56. fn(api.ProgressResponse{
  57. Status: fmt.Sprintf("quantizing model tensors %d/%d", int(progress), tensorCount),
  58. Quantize: "quant",
  59. })
  60. fmt.Println("Progress: ", progress)
  61. case <-done:
  62. fn(api.ProgressResponse{
  63. Status: fmt.Sprintf("quantizing model tensors %d/%d", tensorCount, tensorCount),
  64. Quantize: "quant",
  65. })
  66. return
  67. }
  68. }
  69. }()
  70. if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
  71. return fmt.Errorf("llama_model_quantize: %d", rc)
  72. }
  73. return nil
  74. }