llm.go 2.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. package llm
  2. // #cgo CFLAGS: -Illama.cpp -Illama.cpp/include -Illama.cpp/ggml/include
  3. // #cgo LDFLAGS: -lllama -lggml -lstdc++ -lpthread
  4. // #cgo darwin,arm64 LDFLAGS: -L${SRCDIR}/build/darwin/arm64_static -L${SRCDIR}/build/darwin/arm64_static/src -L${SRCDIR}/build/darwin/arm64_static/ggml/src -framework Accelerate -framework Metal
  5. // #cgo darwin,amd64 LDFLAGS: -L${SRCDIR}/build/darwin/x86_64_static -L${SRCDIR}/build/darwin/x86_64_static/src -L${SRCDIR}/build/darwin/x86_64_static/ggml/src
  6. // #cgo windows,amd64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/amd64_static -L${SRCDIR}/build/windows/amd64_static/src -L${SRCDIR}/build/windows/amd64_static/ggml/src
  7. // #cgo windows,arm64 LDFLAGS: -static-libstdc++ -static-libgcc -static -L${SRCDIR}/build/windows/arm64_static -L${SRCDIR}/build/windows/arm64_static/src -L${SRCDIR}/build/windows/arm64_static/ggml/src
  8. // #cgo linux,amd64 LDFLAGS: -L${SRCDIR}/build/linux/x86_64_static -L${SRCDIR}/build/linux/x86_64_static/src -L${SRCDIR}/build/linux/x86_64_static/ggml/src
  9. // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
  10. // #include <stdlib.h>
  11. // #include "llama.h"
  12. // bool update_quantize_progress(float progress, void* data) {
  13. // *((float*)data) = progress;
  14. // return true;
  15. // }
  16. import "C"
  17. import (
  18. "fmt"
  19. "unsafe"
  20. "time"
  21. "github.com/ollama/ollama/api"
  22. )
  23. // SystemInfo is an unused example of calling llama.cpp functions using CGo
  24. func SystemInfo() string {
  25. return C.GoString(C.llama_print_system_info())
  26. }
  27. func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error { cinfile := C.CString(infile)
  28. defer C.free(unsafe.Pointer(cinfile))
  29. coutfile := C.CString(outfile)
  30. defer C.free(unsafe.Pointer(coutfile))
  31. params := C.llama_model_quantize_default_params()
  32. params.nthread = -1
  33. params.ftype = ftype.Value()
  34. // Initialize "global" to store progress
  35. store := C.malloc(C.sizeof_float)
  36. defer C.free(unsafe.Pointer(store))
  37. // Initialize store value, e.g., setting initial progress to 0
  38. *(*C.float)(store) = 0.0
  39. params.quantize_callback_data = store
  40. params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
  41. ticker := time.NewTicker(60 * time.Millisecond)
  42. done := make(chan struct{})
  43. defer close(done)
  44. go func() {
  45. defer ticker.Stop()
  46. for {
  47. select {
  48. case <-ticker.C:
  49. fn(api.ProgressResponse{
  50. Status: fmt.Sprintf("quantizing model %d/%d", int(*((*C.float)(store))), tensorCount),
  51. Quantize: "quant",
  52. })
  53. fmt.Println("Progress: ", *((*C.float)(store)))
  54. case <-done:
  55. fn(api.ProgressResponse{
  56. Status: fmt.Sprintf("quantizing model %d/%d", tensorCount, tensorCount),
  57. Quantize: "quant",
  58. })
  59. return
  60. }
  61. }
  62. }()
  63. if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
  64. return fmt.Errorf("llama_model_quantize: %d", rc)
  65. }
  66. return nil
  67. }