llama-sampling.h 3.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. /**
  2. * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #pragma once
  27. #include "llama-impl.h"
  28. struct llama_sampling {
  29. llama_sampling(int32_t n_vocab) : n_vocab(n_vocab) {}
  30. std::mt19937 rng;
  31. int32_t n_vocab = 0;
  32. mutable int64_t t_sample_us = 0;
  33. mutable int32_t n_sample = 0;
  34. void reset_timings() const {
  35. t_sample_us = 0;
  36. n_sample = 0;
  37. }
  38. };
  39. //
  40. // internal API
  41. //
  42. void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed);
  43. void llama_sample_softmax_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
  44. void llama_sample_top_k_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep);
  45. void llama_sample_top_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
  46. void llama_sample_min_p_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
  47. void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep);
  48. void llama_sample_typical_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep);
  49. void llama_sample_entropy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val);
  50. void llama_sample_temp_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float temp);
  51. void llama_sample_repetition_penalties_impl(
  52. struct llama_sampling * smpl,
  53. llama_token_data_array * candidates,
  54. const llama_token * last_tokens,
  55. size_t penalty_last_n,
  56. float penalty_repeat,
  57. float penalty_freq,
  58. float penalty_present);
  59. void llama_sample_apply_guidance_impl(
  60. struct llama_sampling * smpl,
  61. float * logits,
  62. float * logits_guidance,
  63. float scale);
  64. llama_token llama_sample_token_mirostat_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu);
  65. llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu);
  66. llama_token llama_sample_token_greedy_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);
  67. llama_token llama_sample_token_with_rng_impl (struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng);
  68. llama_token llama_sample_token_impl (struct llama_sampling * smpl, llama_token_data_array * candidates);