llama-sampling.cpp

/**
 * llama.cpp - commit 8962422b1c6f9b8b15f5aeaea42600bcc2d44177 - do not edit this file
 *
 * MIT License
 *
 * Copyright (c) 2023-2024 The ggml authors
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "llama-sampling.h"

#include <algorithm>
#include <cmath>
#include <cstring>
#include <ctime>
#include <cfloat>
#include <numeric>
#include <unordered_map>
#include <vector>
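
// compute log-softmax in place: subtracting the max logit before expf() avoids
// overflow, and the final logf(p / sum) yields normalized log-probabilities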
static void llama_log_softmax(float * array, size_t size) {
    float max_l = *std::max_element(array, array + size);
    float sum = 0.f;
    for (size_t i = 0; i < size; ++i) {
        float p = expf(array[i] - max_l);
        sum += p;
        array[i] = p;
    }

    for (size_t i = 0; i < size; ++i) {
        array[i] = logf(array[i] / sum);
    }
}

void llama_set_rng_seed_impl(struct llama_sampling * smpl, uint32_t seed) {
    if (seed == LLAMA_DEFAULT_SEED) {
        seed = time(NULL);
    }

    smpl->rng.seed(seed);
}
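
// convert the candidates' logits into normalized probabilities, sorting them in
// descending order first so that data[0] holds the max logit used for stable
// exponentiation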
void llama_sample_softmax_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
    GGML_ASSERT(candidates->size > 0);

    const int64_t t_start_sample_us = ggml_time_us();

    // Sort the logits in descending order
    if (!candidates->sorted) {
        std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
            return a.logit > b.logit;
        });
        candidates->sorted = true;
    }

    float max_l = candidates->data[0].logit;
    float cum_sum = 0.0f;

    for (size_t i = 0; i < candidates->size; ++i) {
        float p = expf(candidates->data[i].logit - max_l);
        candidates->data[i].p = p;
        cum_sum += p;
    }

    for (size_t i = 0; i < candidates->size; ++i) {
        candidates->data[i].p /= cum_sum;
    }

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
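
// top-k: keep only the k highest-logit candidates; for small k a partial sort
// suffices, while for large k a single-pass histogram ("bucket sort") over the
// logit range [-10, 10] avoids sorting the full vocabulary (logits outside that
// range are clamped into the boundary buckets)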
void llama_sample_top_k_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, int32_t k, size_t min_keep) {
    // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
    // if (k >= (int32_t)candidates->size) {
    //     return;
    // }

    const int64_t t_start_sample_us = ggml_time_us();

    if (k <= 0) {
        k = candidates->size;
    }

    k = std::max(k, (int) min_keep);
    k = std::min(k, (int) candidates->size);

    // Sort scores in descending order
    if (!candidates->sorted) {
        auto comp = [](const llama_token_data & a, const llama_token_data & b) {
            return a.logit > b.logit;
        };
        if (k <= 128) {
            std::partial_sort(candidates->data, candidates->data + k, candidates->data + candidates->size, comp);
        } else {
            constexpr int   nbuckets     = 128;
            constexpr float bucket_low   = -10.0f;
            constexpr float bucket_high  =  10.0f;
            constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
            constexpr float bucket_inter = -bucket_low * bucket_scale;

            std::vector<int> bucket_idx(candidates->size);
            std::vector<int> histo(nbuckets, 0);

            for (int i = 0; i < (int)candidates->size; ++i) {
                const float val = candidates->data[i].logit;
                int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
                ib = std::max(0, std::min(nbuckets-1, ib));
                bucket_idx[i] = ib;
                ++histo[ib];
            }
            int nhave = 0;
            int ib = nbuckets - 1;
            for ( ; ib >= 0; --ib) {
                nhave += histo[ib];
                if (nhave >= k) break;
            }
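            // buckets ib..nbuckets-1 together hold nhave >= k tokens;
            // gather them into a scratch buffer, best bucket first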
            std::vector<llama_token_data> tmp_tokens(nhave);
            auto ptr = tmp_tokens.data();
            std::vector<llama_token_data*> bucket_ptrs;
            bucket_ptrs.reserve(nbuckets - ib);
            for (int j = nbuckets - 1; j >= ib; --j) {
                bucket_ptrs.push_back(ptr);
                ptr += histo[j];
            }
            for (int i = 0; i < (int)candidates->size; ++i) {
                int j = bucket_idx[i];
                if (j >= ib) {
                    *bucket_ptrs[nbuckets-1-j]++ = candidates->data[i];
                }
            }

            ptr = tmp_tokens.data();
            int ndone = 0;
            for (int j = nbuckets-1; j > ib; --j) {
                std::sort(ptr, ptr + histo[j], comp);
                ptr += histo[j];
                ndone += histo[j];
            }
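            // the boundary bucket only needs its best k - ndone entries sorted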
            std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);

            std::memcpy(candidates->data, tmp_tokens.data(), k*sizeof(llama_token_data));
        }
        candidates->sorted = true;
    }
    candidates->size = k;

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
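
// top-p (nucleus) sampling: keep the smallest set of highest-probability tokens
// whose cumulative probability reaches p, but never fewer than min_keep tokens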
void llama_sample_top_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
    if (p >= 1.0f) {
        return;
    }

    llama_sample_softmax_impl(smpl, candidates);

    const int64_t t_start_sample_us = ggml_time_us();

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = candidates->size;

    for (size_t i = 0; i < candidates->size; ++i) {
        cum_sum += candidates->data[i].p;

        // Check if the running sum is at least p and we have kept at least min_keep tokens;
        // set the last index to i + 1 so that the current token is included in the kept set
        if (cum_sum >= p && i + 1 >= min_keep) {
            last_idx = i + 1;
            break;
        }
    }

    // Resize the output vector to keep only the top-p tokens
    candidates->size = last_idx;

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
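
// min-p: discard tokens whose probability falls below p * p_max; equivalently,
// since p_i >= p * p_max  <=>  logit_i >= logit_max + logf(p), the cut can be
// made directly in logit space without computing a softmax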
void llama_sample_min_p_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
    if (p <= 0.0f || !candidates->size) {
        return;
    }

    const int64_t t_start_sample_us = ggml_time_us();

    bool min_p_applied = false;

    // if the candidates aren't sorted, try the unsorted implementation first
    if (!candidates->sorted) {
        std::vector<llama_token_data> filtered_tokens;

        float max_logit = -FLT_MAX;
        for (size_t i = 0; i < candidates->size; ++i) {
            max_logit = std::max(max_logit, candidates->data[i].logit);
        }
        const float min_logit = max_logit + logf(p); // min logit for p_i >= p * p_max

        for (size_t i = 0; i < candidates->size; ++i) {
            if (candidates->data[i].logit >= min_logit) {
                filtered_tokens.push_back(candidates->data[i]);
            }
        }

        // if we have enough values the operation was a success
        if (filtered_tokens.size() >= min_keep) {
            memcpy(candidates->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
            candidates->size = filtered_tokens.size();
            min_p_applied = true;
        }
    }

    // if the candidates are sorted or the unsorted implementation failed, use this implementation
    if (!min_p_applied) {
        // Sort the logits in descending order
        if (!candidates->sorted) {
            std::sort(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
                return a.logit > b.logit;
            });
            candidates->sorted = true;
        }

        const float min_logit = candidates->data[0].logit + logf(p); // min logit for p_i >= p * p_max
        size_t i = 1; // first token always matches

        for (; i < candidates->size; ++i) {
            if (candidates->data[i].logit < min_logit && i >= min_keep) {
                break; // prob too small
            }
        }

        // Resize the output vector to keep only the matching tokens
        candidates->size = i;
    }

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
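
// tail-free sampling (TFS): take the absolute, normalized second derivative of
// the sorted probability curve and cut where its cumulative mass exceeds z; the
// "tail" is taken to start where the curve flattens out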
void llama_sample_tail_free_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float z, size_t min_keep) {
    if (z >= 1.0f || candidates->size <= 2) {
        return;
    }

    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);

    const int64_t t_start_sample_us = ggml_time_us();

    // Compute the first and second derivatives
    std::vector<float> first_derivatives(candidates->size - 1);
    std::vector<float> second_derivatives(candidates->size - 2);

    for (size_t i = 0; i < first_derivatives.size(); ++i) {
        first_derivatives[i] = candidates->data[i].p - candidates->data[i + 1].p;
    }
    for (size_t i = 0; i < second_derivatives.size(); ++i) {
        second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
    }

    // Calculate absolute value of second derivatives
    for (size_t i = 0; i < second_derivatives.size(); ++i) {
        second_derivatives[i] = std::abs(second_derivatives[i]);
    }

    // Normalize the second derivatives
    {
        const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);

        if (second_derivatives_sum > 1e-6f) {
            for (float & value : second_derivatives) {
                value /= second_derivatives_sum;
            }
        } else {
            for (float & value : second_derivatives) {
                value = 1.0f / second_derivatives.size();
            }
        }
    }

    float cum_sum = 0.0f;
    size_t last_idx = candidates->size;
    for (size_t i = 0; i < second_derivatives.size(); ++i) {
        cum_sum += second_derivatives[i];

        // Check if the running sum is greater than z and we have kept at least min_keep tokens
        if (cum_sum > z && i >= min_keep) {
            last_idx = i;
            break;
        }
    }

    // Resize the output vector to keep only the tokens above the tail location
    candidates->size = last_idx;

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
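
// locally typical sampling: rank tokens by how close their surprisal (-log p_i)
// is to the entropy of the whole distribution, then keep the most "typical"
// tokens until their cumulative probability exceeds p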
void llama_sample_typical_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float p, size_t min_keep) {
    // Reference implementation:
    // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
    if (p >= 1.0f) {
        return;
    }

    // Compute the softmax of logits and calculate entropy
    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);

    const int64_t t_start_sample_us = ggml_time_us();

    float entropy = 0.0f;
    for (size_t i = 0; i < candidates->size; ++i) {
        entropy += -candidates->data[i].p * logf(candidates->data[i].p);
    }

    // Compute the absolute difference between negative log probability and entropy for each candidate
    std::vector<float> shifted_scores;
    for (size_t i = 0; i < candidates->size; ++i) {
        float shifted_score = fabsf(-logf(candidates->data[i].p) - entropy);
        shifted_scores.push_back(shifted_score);
    }

    // Sort tokens based on the shifted_scores and their corresponding indices
    std::vector<size_t> indices(candidates->size);
    std::iota(indices.begin(), indices.end(), 0);

    std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
        return shifted_scores[a] < shifted_scores[b];
    });

    // Compute the cumulative probabilities
    float cum_sum = 0.0f;
    size_t last_idx = indices.size();

    for (size_t i = 0; i < indices.size(); ++i) {
        size_t idx = indices[i];
        cum_sum += candidates->data[idx].p;

        // Check if the running sum is greater than p and we have kept at least min_keep tokens
        if (cum_sum > p && i >= min_keep - 1) {
            last_idx = i + 1;
            break;
        }
    }

    // Resize the output vector to keep only the locally typical tokens
    std::vector<llama_token_data> new_candidates;
    for (size_t i = 0; i < last_idx; ++i) {
        size_t idx = indices[i];
        new_candidates.push_back(candidates->data[idx]);
    }

    // Replace the data in candidates with the new_candidates data
    std::copy(new_candidates.begin(), new_candidates.end(), candidates->data);
    candidates->size = new_candidates.size();
    candidates->sorted = false;

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
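
// entropy-based dynamic temperature: map the normalized entropy of the
// distribution into [min_temp, max_temp] via
//   dyn_temp = min_temp + (max_temp - min_temp) * (entropy / max_entropy)^exponent_val
// so that flat (uncertain) distributions get a higher temperature than peaked ones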
void llama_sample_entropy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float min_temp, float max_temp, float exponent_val) {
    const int64_t t_start_sample_us = ggml_time_us();

    // no need to do anything if there is only one (or zero) candidate
    if (candidates->size <= 1) {
        return;
    }

    // Calculate maximum possible entropy
    float max_entropy = -logf(1.0f / candidates->size);

    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);

    // Calculate entropy of the softmax probabilities
    float entropy = 0.0f;
    for (size_t i = 0; i < candidates->size; ++i) {
        float prob = candidates->data[i].p;
        if (prob > 0.0f) { // Ensure no log(0)
            entropy -= prob * logf(prob);
        }
    }

    // Normalize the entropy (max_entropy cannot be 0 here because we checked candidates->size != 1 above)
    float normalized_entropy = entropy / max_entropy;

    // Map the normalized entropy to the desired temperature range using the power function
    float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);

#ifdef DEBUG
    LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
    LLAMA_LOG_INFO("Entropy: %f\n", entropy);
    LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
    LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
    LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
    LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
#endif

    // Apply the dynamically calculated temperature scaling
    for (size_t i = 0; i < candidates->size; ++i) {
        candidates->data[i].logit /= dyn_temp;
    }

    // Re-compute softmax probabilities after scaling logits with dynamic temperature
    double max_l_double = candidates->data[0].logit;
    double cum_sum_double = 0.0;
    for (size_t i = 0; i < candidates->size; ++i) {
        double p = exp(candidates->data[i].logit - max_l_double);
        candidates->data[i].p = p; // Store the scaled probability
        cum_sum_double += p;
    }
    for (size_t i = 0; i < candidates->size; ++i) {
        candidates->data[i].p /= cum_sum_double; // Re-normalize the probabilities
    }

#ifdef DEBUG
    // Print the updated top 25 probabilities after temperature scaling
    LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
    for (size_t i = 0; i < 25 && i < candidates->size; ++i) {
        LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, candidates->data[i].p * 100.0f);
    }
#endif

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
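
// plain temperature scaling: logits are divided by temp, so temp < 1 sharpens the
// distribution and temp > 1 flattens it; note that temp == 0 would divide by zero,
// so callers are assumed to special-case non-positive temperatures before this point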
void llama_sample_temp_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float temp) {
    const int64_t t_start_sample_us = ggml_time_us();

    for (size_t i = 0; i < candidates->size; ++i) {
        candidates->data[i].logit /= temp;
    }

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
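
// repetition penalties, applied to each candidate with count = number of times
// its token occurs in the last penalty_last_n tokens:
//   logit <= 0 ? logit *= penalty_repeat : logit /= penalty_repeat
//   logit -= count * penalty_freq + (count > 0) * penalty_present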
void llama_sample_repetition_penalties_impl(
        struct llama_sampling * smpl,
        llama_token_data_array * candidates,
        const llama_token * last_tokens,
        size_t penalty_last_n,
        float penalty_repeat,
        float penalty_freq,
        float penalty_present) {
    if (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)) {
        return;
    }

    const int64_t t_start_sample_us = ggml_time_us();

    // Create a frequency map to count occurrences of each token in last_tokens
    std::unordered_map<llama_token, int> token_count;
    for (size_t i = 0; i < penalty_last_n; ++i) {
        token_count[last_tokens[i]]++;
    }

    // Apply frequency and presence penalties to the candidates
    for (size_t i = 0; i < candidates->size; ++i) {
        const auto token_iter = token_count.find(candidates->data[i].id);
        if (token_iter == token_count.end()) {
            continue;
        }

        const int count = token_iter->second;

        // The academic publication that described this technique only divided by the penalty,
        // but that would make tokens with negative logits more likely, which is obviously wrong.
        // The common fix is to multiply negative logits by the penalty instead of dividing them.
        if (candidates->data[i].logit <= 0) {
            candidates->data[i].logit *= penalty_repeat;
        } else {
            candidates->data[i].logit /= penalty_repeat;
        }

        candidates->data[i].logit -= float(count) * penalty_freq + float(count > 0) * penalty_present;
    }

    candidates->sorted = false;

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }
}
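
// classifier-free guidance in log-probability space: with l = main logits and
// g = guidance logits, each entry becomes g + scale * (l - g), so scale = 1
// leaves l unchanged and larger scales amplify the difference from the guidance
// distribution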
void llama_sample_apply_guidance_impl(
        struct llama_sampling * smpl,
        float * logits,
        float * logits_guidance,
        float scale) {
    GGML_ASSERT(smpl);

    const auto t_start_sample_us = ggml_time_us();
    const auto n_vocab = smpl->n_vocab;

    llama_log_softmax(logits, n_vocab);
    llama_log_softmax(logits_guidance, n_vocab);

    for (int i = 0; i < n_vocab; ++i) {
        auto & l = logits[i];
        const auto & g = logits_guidance[i];

        l = scale * (l - g) + g;
    }

    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
}
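
// Mirostat v1 (Basu et al., 2020, https://arxiv.org/abs/2007.14966): estimate
// the Zipf exponent s_hat of the token distribution by least squares over the
// top m tokens, derive a top-k cutoff that targets surprise tau, sample, then
// update mu with learning rate eta based on the observed surprise error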
llama_token llama_sample_token_mirostat_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, int32_t m, float * mu) {
    GGML_ASSERT(smpl);

    const int32_t n_vocab = smpl->n_vocab;

    int64_t t_start_sample_us = ggml_time_us();

    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);

    // Estimate s_hat using the most probable m tokens
    float s_hat = 0.0;
    float sum_ti_bi = 0.0;
    float sum_ti_sq = 0.0;
    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
        float t_i = logf(float(i + 2) / float(i + 1));
        float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
        sum_ti_bi += t_i * b_i;
        sum_ti_sq += t_i * t_i;
    }
    s_hat = sum_ti_bi / sum_ti_sq;

    // Compute k from the estimated s_hat and target surprise value
    float epsilon_hat = s_hat - 1;
    float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(n_vocab, -epsilon_hat)), 1 / s_hat);

    // Sample the next word X using top-k sampling
    llama_sample_top_k_impl((struct llama_sampling *) nullptr, candidates, int(k), 1);
    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    llama_token X = llama_sample_token_impl(smpl, candidates);
    t_start_sample_us = ggml_time_us();

    // Compute error as the difference between observed surprise and target surprise value
    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return candidate.id == X;
    }));
    float observed_surprise = -log2f(candidates->data[X_idx].p);
    float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    *mu = *mu - eta * e;

    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;

    return X;
}
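
// Mirostat v2: a simpler variant that truncates every token whose surprise
// (-log2 p) exceeds the current mu, renormalizes, samples, and then nudges mu
// toward the target surprise tau by eta times the observed error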
llama_token llama_sample_token_mirostat_v2_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, float tau, float eta, float * mu) {
    int64_t t_start_sample_us = ggml_time_us();

    llama_sample_softmax_impl(smpl, candidates);

    // Truncate the words with surprise values greater than mu
    candidates->size = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return -log2f(candidate.p) > *mu;
    }));

    if (candidates->size == 0) {
        candidates->size = 1;
    }

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }

    // Normalize the probabilities of the remaining words
    llama_sample_softmax_impl(smpl, candidates);

    // Sample the next word X from the remaining words
    llama_token X = llama_sample_token_impl(smpl, candidates);
    t_start_sample_us = ggml_time_us();

    // Compute error as the difference between observed surprise and target surprise value
    size_t X_idx = std::distance(candidates->data, std::find_if(candidates->data, candidates->data + candidates->size, [&](const llama_token_data & candidate) {
        return candidate.id == X;
    }));
    float observed_surprise = -log2f(candidates->data[X_idx].p);
    float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    *mu = *mu - eta * e;

    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    }

    return X;
}

llama_token llama_sample_token_greedy_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
    const int64_t t_start_sample_us = ggml_time_us();

    // Find max element
    auto * max_iter = std::max_element(candidates->data, candidates->data + candidates->size, [](const llama_token_data & a, const llama_token_data & b) {
        return a.logit < b.logit;
    });

    llama_token result = max_iter->id;
    if (smpl) {
        smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
        smpl->n_sample++;
    }
    return result;
}
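
// stochastic sampling: softmax the candidates, then draw an index from
// std::discrete_distribution, which selects i with probability proportional
// to probs[i]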
llama_token llama_sample_token_with_rng_impl(struct llama_sampling * smpl, llama_token_data_array * candidates, std::mt19937 & rng) {
    GGML_ASSERT(smpl);

    const int64_t t_start_sample_us = ggml_time_us();
    llama_sample_softmax_impl((struct llama_sampling *) nullptr, candidates);

    std::vector<float> probs;
    probs.reserve(candidates->size);
    for (size_t i = 0; i < candidates->size; ++i) {
        probs.push_back(candidates->data[i].p);
    }

    std::discrete_distribution<> dist(probs.begin(), probs.end());
    int idx = dist(rng);

    llama_token result = candidates->data[idx].id;

    smpl->t_sample_us += ggml_time_us() - t_start_sample_us;
    smpl->n_sample++;

    return result;
}

llama_token llama_sample_token_impl(struct llama_sampling * smpl, llama_token_data_array * candidates) {
    return llama_sample_token_with_rng_impl(smpl, candidates, smpl->rng);
}