sampling_ext.h 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. // TODO: this is a temporary wrapper to allow calling C++ code from CGo
  2. #ifndef LLAMA_SAMPLING_EXT_H
  3. #define LLAMA_SAMPLING_EXT_H
  4. #include "llama.h"
  5. #ifdef __cplusplus
  6. extern "C"
  7. {
  8. #endif
  9. struct llama_sampling_cparams
  10. {
  11. int32_t top_k;
  12. float top_p;
  13. float tfs_z;
  14. float typical_p;
  15. float temp;
  16. int32_t penalty_last_n;
  17. float penalty_repeat;
  18. float penalty_freq;
  19. float penalty_present;
  20. int32_t mirostat;
  21. float mirostat_tau;
  22. float mirostat_eta;
  23. bool penalize_nl;
  24. uint32_t seed;
  25. char *grammar;
  26. };
  27. struct llama_sampling_context *llama_sampling_cinit(struct llama_sampling_cparams *params);
  28. void llama_sampling_cfree(struct llama_sampling_context *ctx);
  29. void llama_sampling_creset(struct llama_sampling_context *ctx);
  30. llama_token llama_sampling_csample(
  31. struct llama_sampling_context *ctx_sampling,
  32. struct llama_context *ctx_main,
  33. struct llama_context *ctx_cfg,
  34. int idx);
  35. void llama_sampling_caccept(
  36. struct llama_sampling_context *ctx_sampling,
  37. struct llama_context *ctx_main,
  38. llama_token id,
  39. bool apply_grammar);
  40. #ifdef __cplusplus
  41. }
  42. #endif
  43. #endif // LLAMA_SAMPLING_EXT_H