llama-sampling.cpp 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700170117021703170417051706170717081709171017111712171317141715171617171718171917201721172217231724172517261727172817291730173117321733
  1. /**
  2. * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #include "llama-sampling.h"
  27. #include "llama-vocab.h"
  28. #include "llama-grammar.h"
  29. #include <algorithm>
  30. #include <cassert>
  31. #include <cfloat>
  32. #include <chrono>
  33. #include <cmath>
  34. #include <cstdlib>
  35. #include <cstring>
  36. #include <ctime>
  37. #include <numeric>
  38. #include <random>
  39. #include <unordered_map>
  40. static int llama_sample_dist(llama_token_data_array * cur_p, std::mt19937 & rng) {
  41. // iterator for the probabilities
  42. #ifdef __GNUC__
  43. #pragma GCC diagnostic push
  44. #pragma GCC diagnostic ignored "-Wunused-local-typedefs"
  45. #endif
  46. struct probs_iterator {
  47. typedef std::input_iterator_tag iterator_category;
  48. typedef float value_type;
  49. typedef float * pointer;
  50. typedef float & reference;
  51. typedef ptrdiff_t difference_type;
  52. const llama_token_data * data;
  53. bool operator==(const probs_iterator & other) const { return data == other.data; }
  54. bool operator!=(const probs_iterator & other) const { return data != other.data; }
  55. const float & operator*() const { return data->p; }
  56. probs_iterator & operator++() { ++data; return *this; }
  57. probs_iterator operator++(int) { probs_iterator tmp = *this; ++data; return tmp; }
  58. };
  59. #ifdef __GNUC__
  60. #pragma GCC diagnostic pop
  61. #endif
  62. std::discrete_distribution<int> dist(probs_iterator{cur_p->data}, probs_iterator{cur_p->data + cur_p->size});
  63. return dist(rng);
  64. }
  65. /*
  66. static void llama_log_softmax(float * array, size_t size) {
  67. float max_l = *std::max_element(array, array + size);
  68. float sum = 0.f;
  69. for (size_t i = 0; i < size; ++i) {
  70. float p = expf(array[i] - max_l);
  71. sum += p;
  72. array[i] = p;
  73. }
  74. for (size_t i = 0; i < size; ++i) {
  75. array[i] = logf(array[i] / sum);
  76. }
  77. }
  78. */
  79. static void llama_sampler_softmax_impl(llama_token_data_array * cur_p) {
  80. GGML_ASSERT(cur_p->size > 0);
  81. // Sort the logits in descending order
  82. if (!cur_p->sorted) {
  83. std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
  84. return a.logit > b.logit;
  85. });
  86. cur_p->sorted = true;
  87. }
  88. float max_l = cur_p->data[0].logit;
  89. float cum_sum = 0.0f;
  90. for (size_t i = 0; i < cur_p->size; ++i) {
  91. float p = expf(cur_p->data[i].logit - max_l);
  92. cur_p->data[i].p = p;
  93. cum_sum += p;
  94. }
  95. for (size_t i = 0; i < cur_p->size; ++i) {
  96. cur_p->data[i].p /= cum_sum;
  97. }
  98. }
  99. static void llama_sampler_top_k_impl(llama_token_data_array * cur_p, int32_t k) {
  100. // TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
  101. // if (k >= (int32_t)cur_p->size) {
  102. // return;
  103. // }
  104. if (k <= 0) {
  105. k = cur_p->size;
  106. }
  107. k = std::min(k, (int) cur_p->size);
  108. // Sort scores in descending order
  109. if (!cur_p->sorted) {
  110. auto comp = [](const llama_token_data & a, const llama_token_data & b) {
  111. return a.logit > b.logit;
  112. };
  113. if (k <= 128) {
  114. std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
  115. } else {
  116. constexpr int nbuckets = 128;
  117. constexpr float bucket_low = -10.0f;
  118. constexpr float bucket_high = 10.0f;
  119. constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
  120. constexpr float bucket_inter = -bucket_low * bucket_scale;
  121. std::vector<int> bucket_idx(cur_p->size);
  122. std::vector<int> histo(nbuckets, 0);
  123. for (int i = 0; i < (int)cur_p->size; ++i) {
  124. const float val = cur_p->data[i].logit;
  125. int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
  126. ib = std::max(0, std::min(nbuckets-1, ib));
  127. bucket_idx[i] = ib;
  128. ++histo[ib];
  129. }
  130. int nhave = 0;
  131. int ib = nbuckets - 1;
  132. for ( ; ib >= 0; --ib) {
  133. nhave += histo[ib];
  134. if (nhave >= k) {
  135. break;
  136. }
  137. }
  138. std::vector<llama_token_data> tmp_tokens(nhave);
  139. auto * ptr = tmp_tokens.data();
  140. std::vector<llama_token_data*> bucket_ptrs;
  141. bucket_ptrs.reserve(nbuckets - ib);
  142. for (int j = nbuckets - 1; j >= ib; --j) {
  143. bucket_ptrs.push_back(ptr);
  144. ptr += histo[j];
  145. }
  146. for (int i = 0; i < (int)cur_p->size; ++i) {
  147. int j = bucket_idx[i];
  148. if (j >= ib) {
  149. *bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
  150. }
  151. }
  152. ptr = tmp_tokens.data();
  153. int ndone = 0;
  154. for (int j = nbuckets-1; j > ib; --j) {
  155. std::sort(ptr, ptr + histo[j], comp);
  156. ptr += histo[j];
  157. ndone += histo[j];
  158. }
  159. std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
  160. std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
  161. }
  162. cur_p->sorted = true;
  163. }
  164. cur_p->size = k;
  165. }
  166. static uint32_t get_rng_seed(uint32_t seed) {
  167. if (seed == LLAMA_DEFAULT_SEED) {
  168. // use system clock if std::random_device is not a true RNG
  169. static bool is_rd_prng = std::random_device().entropy() == 0;
  170. if (is_rd_prng) {
  171. return (uint32_t) std::chrono::system_clock::now().time_since_epoch().count();
  172. }
  173. std::random_device rd;
  174. return rd();
  175. }
  176. return seed;
  177. }
  178. // llama_sampler API
  179. const char * llama_sampler_name(const struct llama_sampler * smpl) {
  180. if (!smpl->iface) {
  181. return "(null)";
  182. }
  183. return smpl->iface->name(smpl);
  184. }
  185. void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
  186. if (smpl->iface->accept) {
  187. smpl->iface->accept(smpl, token);
  188. }
  189. }
  190. void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
  191. GGML_ASSERT(smpl->iface->apply);
  192. smpl->iface->apply(smpl, cur_p);
  193. }
  194. void llama_sampler_reset(struct llama_sampler * smpl) {
  195. if (smpl->iface->reset) {
  196. smpl->iface->reset(smpl);
  197. }
  198. }
  199. struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
  200. if (smpl->iface->clone) {
  201. return smpl->iface->clone(smpl);
  202. }
  203. if (smpl->ctx == nullptr) {
  204. return new llama_sampler {
  205. /* .iface = */ smpl->iface,
  206. /* .ctx = */ nullptr,
  207. };
  208. }
  209. GGML_ABORT("the sampler does not support cloning");
  210. }
  211. void llama_sampler_free(struct llama_sampler * smpl) {
  212. if (smpl == nullptr) {
  213. return;
  214. }
  215. if (smpl->iface->free) {
  216. smpl->iface->free(smpl);
  217. }
  218. delete smpl;
  219. }
  220. llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) {
  221. const auto * logits = llama_get_logits_ith(ctx, idx);
  222. const int n_vocab = llama_n_vocab(llama_get_model(ctx));
  223. // TODO: do not allocate each time
  224. std::vector<llama_token_data> cur;
  225. cur.reserve(n_vocab);
  226. for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
  227. cur.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
  228. }
  229. llama_token_data_array cur_p = {
  230. /* .data = */ cur.data(),
  231. /* .size = */ cur.size(),
  232. /* .selected = */ -1,
  233. /* .sorted = */ false,
  234. };
  235. llama_sampler_apply(smpl, &cur_p);
  236. GGML_ASSERT(cur_p.selected >= 0 && cur_p.selected < (int32_t) cur_p.size);
  237. auto token = cur_p.data[cur_p.selected].id;
  238. llama_sampler_accept(smpl, token);
  239. return token;
  240. }
  241. // sampler chain
  // Name callback of the chain sampler; the sampler argument is not needed.
  static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
      static const char * const chain_label = "chain";
      return chain_label;
  }
  245. static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token token) {
  246. auto * chain = (llama_sampler_chain *) smpl->ctx;
  247. time_meas tm(chain->t_sample_us, chain->params.no_perf);
  248. for (auto * smpl : chain->samplers) {
  249. llama_sampler_accept(smpl, token);
  250. }
  251. chain->n_sample++;
  252. }
  253. static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  254. auto * chain = (llama_sampler_chain *) smpl->ctx;
  255. time_meas tm(chain->t_sample_us, chain->params.no_perf);
  256. for (auto * smpl : chain->samplers) {
  257. llama_sampler_apply(smpl, cur_p);
  258. }
  259. }
  260. static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
  261. auto * chain = (llama_sampler_chain *) smpl->ctx;
  262. for (auto * smpl : chain->samplers) {
  263. llama_sampler_reset(smpl);
  264. }
  265. chain->t_sample_us = 0;
  266. chain->n_sample = 0;
  267. }
  268. static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampler * smpl) {
  269. const auto * chain_src = (const llama_sampler_chain *) smpl->ctx;
  270. auto * result = llama_sampler_chain_init(chain_src->params);
  271. for (auto * smpl : chain_src->samplers) {
  272. llama_sampler_chain_add(result, llama_sampler_clone(smpl));
  273. }
  274. return result;
  275. }
  276. static void llama_sampler_chain_free(struct llama_sampler * smpl) {
  277. auto * chain = (llama_sampler_chain *) smpl->ctx;
  278. for (auto * smpl : chain->samplers) {
  279. llama_sampler_free(smpl);
  280. }
  281. delete chain;
  282. }
  283. static struct llama_sampler_i llama_sampler_chain_i = {
  284. /* .name = */ llama_sampler_chain_name,
  285. /* .accept = */ llama_sampler_chain_accept,
  286. /* .apply = */ llama_sampler_chain_apply,
  287. /* .reset = */ llama_sampler_chain_reset,
  288. /* .clone = */ llama_sampler_chain_clone,
  289. /* .free = */ llama_sampler_chain_free,
  290. };
  291. struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) {
  292. return new llama_sampler {
  293. /* .iface = */ &llama_sampler_chain_i,
  294. /* .ctx = */ new llama_sampler_chain {
  295. /* .params = */ params,
  296. /* .samplers = */ {},
  297. /* .t_sample_us = */ 0,
  298. /* .n_sample = */ 0,
  299. },
  300. };
  301. }
  302. void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
  303. auto * p = (llama_sampler_chain *) chain->ctx;
  304. p->samplers.push_back(smpl);
  305. }
  306. struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
  307. const auto * p = (const llama_sampler_chain *) chain->ctx;
  308. if (i < 0 || (size_t) i >= p->samplers.size()) {
  309. return nullptr;
  310. }
  311. return p->samplers[i];
  312. }
  313. struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
  314. auto * p = (llama_sampler_chain *) chain->ctx;
  315. if (i < 0 || (size_t) i >= p->samplers.size()) {
  316. return nullptr;
  317. }
  318. auto * result = p->samplers[i];
  319. p->samplers.erase(p->samplers.begin() + i);
  320. return result;
  321. }
  322. int llama_sampler_chain_n(const struct llama_sampler * chain) {
  323. const auto * p = (const llama_sampler_chain *) chain->ctx;
  324. return p->samplers.size();
  325. }
  326. //
  327. // samplers
  328. //
  329. // greedy
  330. static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) {
  331. return "greedy";
  332. }
  333. static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
  334. cur_p->selected = 0;
  335. for (size_t i = 1; i < cur_p->size; ++i) {
  336. if (cur_p->data[i].logit > cur_p->data[cur_p->selected].logit) {
  337. cur_p->selected = i;
  338. }
  339. }
  340. }
  341. static struct llama_sampler_i llama_sampler_greedy_i = {
  342. /* .name = */ llama_sampler_greedy_name,
  343. /* .accept = */ nullptr,
  344. /* .apply = */ llama_sampler_greedy_apply,
  345. /* .reset = */ nullptr,
  346. /* .clone = */ nullptr,
  347. /* .free = */ nullptr,
  348. };
  349. struct llama_sampler * llama_sampler_init_greedy() {
  350. return new llama_sampler {
  351. /* .iface = */ &llama_sampler_greedy_i,
  352. /* .ctx = */ nullptr,
  353. };
  354. }
  355. // dist
  356. struct llama_sampler_dist {
  357. const uint32_t seed;
  358. uint32_t seed_cur;
  359. std::mt19937 rng;
  360. };
  361. static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
  362. return "dist";
  363. }
  364. static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  365. auto * ctx = (llama_sampler_dist *) smpl->ctx;
  366. cur_p->selected = llama_sample_dist(cur_p, ctx->rng);
  367. }
  368. static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) {
  369. const auto * ctx = (const llama_sampler_dist *) smpl->ctx;
  370. auto * result = llama_sampler_init_dist(ctx->seed);
  371. // copy the state
  372. {
  373. auto * result_ctx = (llama_sampler_dist *) result->ctx;
  374. result_ctx->rng = ctx->rng;
  375. }
  376. return result;
  377. }
  378. static void llama_sampler_dist_reset(struct llama_sampler * smpl) {
  379. auto * ctx = (llama_sampler_dist *) smpl->ctx;
  380. ctx->seed_cur = get_rng_seed(ctx->seed);
  381. ctx->rng.seed(ctx->seed_cur);
  382. }
  383. static void llama_sampler_dist_free(struct llama_sampler * smpl) {
  384. delete (llama_sampler_dist *) smpl->ctx;
  385. }
  386. static struct llama_sampler_i llama_sampler_dist_i = {
  387. /* .name = */ llama_sampler_dist_name,
  388. /* .accept = */ nullptr,
  389. /* .apply = */ llama_sampler_dist_apply,
  390. /* .reset = */ llama_sampler_dist_reset,
  391. /* .clone = */ llama_sampler_dist_clone,
  392. /* .free = */ llama_sampler_dist_free,
  393. };
  394. struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
  395. auto seed_cur = get_rng_seed(seed);
  396. return new llama_sampler {
  397. /* .iface = */ &llama_sampler_dist_i,
  398. /* .ctx = */ new llama_sampler_dist {
  399. /* .seed = */ seed,
  400. /* .seed_cur = */ seed_cur,
  401. /* .rng = */ std::mt19937(seed_cur),
  402. },
  403. };
  404. }
  405. // softmax
  406. static const char * llama_sampler_softmax_name(const struct llama_sampler * /*smpl*/) {
  407. return "softmax";
  408. }
  409. static void llama_sampler_softmax_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) {
  410. llama_sampler_softmax_impl(cur_p);
  411. }
  412. static struct llama_sampler_i llama_sampler_softmax_i = {
  413. /* .name = */ llama_sampler_softmax_name,
  414. /* .accept = */ nullptr,
  415. /* .apply = */ llama_sampler_softmax_apply,
  416. /* .reset = */ nullptr,
  417. /* .clone = */ nullptr,
  418. /* .free = */ nullptr,
  419. };
  420. struct llama_sampler * llama_sampler_init_softmax() {
  421. return new llama_sampler {
  422. /* .iface = */ &llama_sampler_softmax_i,
  423. /* .ctx = */ nullptr,
  424. };
  425. }
  426. // top-k
  427. struct llama_sampler_top_k {
  428. const int32_t k;
  429. };
  430. static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
  431. return "top-k";
  432. }
  433. static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  434. const auto * ctx = (llama_sampler_top_k *) smpl->ctx;
  435. llama_sampler_top_k_impl(cur_p, ctx->k);
  436. }
  437. static struct llama_sampler * llama_sampler_top_k_clone(const struct llama_sampler * smpl) {
  438. const auto * ctx = (const llama_sampler_top_k *) smpl->ctx;
  439. return llama_sampler_init_top_k(ctx->k);
  440. }
  441. static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
  442. delete (llama_sampler_top_k *) smpl->ctx;
  443. }
  444. static struct llama_sampler_i llama_sampler_top_k_i = {
  445. /* .name = */ llama_sampler_top_k_name,
  446. /* .accept = */ nullptr,
  447. /* .apply = */ llama_sampler_top_k_apply,
  448. /* .reset = */ nullptr,
  449. /* .clone = */ llama_sampler_top_k_clone,
  450. /* .free = */ llama_sampler_top_k_free,
  451. };
  452. struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
  453. return new llama_sampler {
  454. /* .iface = */ &llama_sampler_top_k_i,
  455. /* .ctx = */ new llama_sampler_top_k {
  456. /* .k = */ k,
  457. },
  458. };
  459. }
  460. // top-p
  461. struct llama_sampler_top_p {
  462. const float p;
  463. const size_t min_keep;
  464. };
  465. static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
  466. return "top-p";
  467. }
  468. static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  469. const auto * ctx = (llama_sampler_top_p *) smpl->ctx;
  470. if (ctx->p >= 1.0f) {
  471. return;
  472. }
  473. llama_sampler_softmax_impl(cur_p);
  474. // Compute the cumulative probabilities
  475. float cum_sum = 0.0f;
  476. size_t last_idx = cur_p->size;
  477. for (size_t i = 0; i < cur_p->size; ++i) {
  478. cum_sum += cur_p->data[i].p;
  479. // Check if the running sum is at least p or if we have kept at least min_keep tokens
  480. // we set the last index to i+1 to indicate that the current iterate should be included in the set
  481. if (cum_sum >= ctx->p && i + 1 >= ctx->min_keep) {
  482. last_idx = i + 1;
  483. break;
  484. }
  485. }
  486. // Resize the output vector to keep only the top-p tokens
  487. cur_p->size = last_idx;
  488. }
  489. static struct llama_sampler * llama_sampler_top_p_clone(const struct llama_sampler * smpl) {
  490. const auto * ctx = (const llama_sampler_top_p *) smpl->ctx;
  491. return llama_sampler_init_top_p(ctx->p, ctx->min_keep);
  492. }
  493. static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
  494. delete (llama_sampler_top_p *) smpl->ctx;
  495. }
  496. static struct llama_sampler_i llama_sampler_top_p_i = {
  497. /* .name = */ llama_sampler_top_p_name,
  498. /* .accept = */ nullptr,
  499. /* .apply = */ llama_sampler_top_p_apply,
  500. /* .reset = */ nullptr,
  501. /* .clone = */ llama_sampler_top_p_clone,
  502. /* .free = */ llama_sampler_top_p_free,
  503. };
  504. struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
  505. return new llama_sampler {
  506. /* .iface = */ &llama_sampler_top_p_i,
  507. /* .ctx = */ new llama_sampler_top_p {
  508. /* .p = */ p,
  509. /* .min_keep = */ min_keep,
  510. },
  511. };
  512. }
  513. // min-p
  514. struct llama_sampler_min_p {
  515. const float p;
  516. const size_t min_keep;
  517. };
  518. static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
  519. return "min-p";
  520. }
  521. static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  522. const auto * ctx = (llama_sampler_min_p *) smpl->ctx;
  523. if (ctx->p <= 0.0f || !cur_p->size) {
  524. return;
  525. }
  526. bool min_p_applied = false;
  527. // if the cur_p aren't sorted, try the unsorted implementation first
  528. if (!cur_p->sorted) {
  529. std::vector<llama_token_data> filtered_tokens;
  530. float max_logit = -FLT_MAX;
  531. for (size_t i = 0; i < cur_p->size; ++i) {
  532. max_logit = std::max(max_logit, cur_p->data[i].logit);
  533. }
  534. const float min_logit = max_logit + logf(ctx->p); // min logit for p_i >= p * p_max
  535. for (size_t i = 0; i < cur_p->size; ++i) {
  536. if (cur_p->data[i].logit >= min_logit) {
  537. filtered_tokens.push_back(cur_p->data[i]);
  538. }
  539. }
  540. // if we have enough values the operation was a success
  541. if (filtered_tokens.size() >= ctx->min_keep) {
  542. memcpy(cur_p->data, filtered_tokens.data(), filtered_tokens.size()*sizeof(llama_token_data));
  543. cur_p->size = filtered_tokens.size();
  544. min_p_applied = true;
  545. }
  546. }
  547. // if the cur_p are sorted or the unsorted implementation failed, use this implementation
  548. if (!min_p_applied) {
  549. // Sort the logits in descending order
  550. if (!cur_p->sorted) {
  551. std::sort(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & a, const llama_token_data & b) {
  552. return a.logit > b.logit;
  553. });
  554. cur_p->sorted = true;
  555. }
  556. const float min_logit = cur_p->data[0].logit + logf(ctx->p); // min logit for p_i >= p * p_max
  557. size_t i = 1; // first token always matches
  558. for (; i < cur_p->size; ++i) {
  559. if (cur_p->data[i].logit < min_logit && i >= ctx->min_keep) {
  560. break; // prob too small
  561. }
  562. }
  563. // Resize the output vector to keep only the matching tokens
  564. cur_p->size = i;
  565. }
  566. }
  567. static struct llama_sampler * llama_sampler_min_p_clone(const struct llama_sampler * smpl) {
  568. const auto * ctx = (const llama_sampler_min_p *) smpl->ctx;
  569. return llama_sampler_init_min_p(ctx->p, ctx->min_keep);
  570. }
  571. static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
  572. delete (llama_sampler_min_p *) smpl->ctx;
  573. }
  574. static struct llama_sampler_i llama_sampler_min_p_i = {
  575. /* .name = */ llama_sampler_min_p_name,
  576. /* .accept = */ nullptr,
  577. /* .apply = */ llama_sampler_min_p_apply,
  578. /* .reset = */ nullptr,
  579. /* .clone = */ llama_sampler_min_p_clone,
  580. /* .free = */ llama_sampler_min_p_free,
  581. };
  582. struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
  583. return new llama_sampler {
  584. /* .iface = */ &llama_sampler_min_p_i,
  585. /* .ctx = */ new llama_sampler_min_p {
  586. /* .p = */ p,
  587. /* .min_keep = */ min_keep,
  588. },
  589. };
  590. }
  591. // tail-free
  592. struct llama_sampler_tail_free {
  593. const float z;
  594. const size_t min_keep;
  595. };
  596. static const char * llama_sampler_tail_free_name(const struct llama_sampler * /*smpl*/) {
  597. return "tail-free";
  598. }
  599. static void llama_sampler_tail_free_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  600. const auto * ctx = (llama_sampler_tail_free *) smpl->ctx;
  601. if (ctx->z >= 1.0f || cur_p->size <= 2) {
  602. return;
  603. }
  604. llama_sampler_softmax_impl(cur_p);
  605. // Compute the first and second derivatives
  606. std::vector<float> first_derivatives(cur_p->size - 1);
  607. std::vector<float> second_derivatives(cur_p->size - 2);
  608. for (size_t i = 0; i < first_derivatives.size(); ++i) {
  609. first_derivatives[i] = cur_p->data[i].p - cur_p->data[i + 1].p;
  610. }
  611. for (size_t i = 0; i < second_derivatives.size(); ++i) {
  612. second_derivatives[i] = first_derivatives[i] - first_derivatives[i + 1];
  613. }
  614. // Calculate absolute value of second derivatives
  615. for (size_t i = 0; i < second_derivatives.size(); ++i) {
  616. second_derivatives[i] = std::abs(second_derivatives[i]);
  617. }
  618. // Normalize the second derivatives
  619. {
  620. const float second_derivatives_sum = std::accumulate(second_derivatives.begin(), second_derivatives.end(), 0.0f);
  621. if (second_derivatives_sum > 1e-6f) {
  622. for (float & value : second_derivatives) {
  623. value /= second_derivatives_sum;
  624. }
  625. } else {
  626. for (float & value : second_derivatives) {
  627. value = 1.0f / second_derivatives.size();
  628. }
  629. }
  630. }
  631. float cum_sum = 0.0f;
  632. size_t last_idx = cur_p->size;
  633. for (size_t i = 0; i < second_derivatives.size(); ++i) {
  634. cum_sum += second_derivatives[i];
  635. // Check if the running sum is greater than z or if we have kept at least min_keep tokens
  636. if (cum_sum > ctx->z && i >= ctx->min_keep) {
  637. last_idx = i;
  638. break;
  639. }
  640. }
  641. // Resize the output vector to keep only the tokens above the tail location
  642. cur_p->size = last_idx;
  643. }
  644. static struct llama_sampler * llama_sampler_tail_free_clone(const struct llama_sampler * smpl) {
  645. const auto * ctx = (const llama_sampler_tail_free *) smpl->ctx;
  646. return llama_sampler_init_tail_free(ctx->z, ctx->min_keep);
  647. }
  648. static void llama_sampler_tail_free_free(struct llama_sampler * smpl) {
  649. delete (llama_sampler_tail_free *) smpl->ctx;
  650. }
  651. static struct llama_sampler_i llama_sampler_tail_free_i = {
  652. /* .name = */ llama_sampler_tail_free_name,
  653. /* .accept = */ nullptr,
  654. /* .apply = */ llama_sampler_tail_free_apply,
  655. /* .reset = */ nullptr,
  656. /* .clone = */ llama_sampler_tail_free_clone,
  657. /* .free = */ llama_sampler_tail_free_free,
  658. };
  659. struct llama_sampler * llama_sampler_init_tail_free(float z, size_t min_keep) {
  660. return new llama_sampler {
  661. /* .iface = */ &llama_sampler_tail_free_i,
  662. /* .ctx = */ new llama_sampler_tail_free {
  663. /* .z = */ z,
  664. /*. min_keep = */ min_keep,
  665. },
  666. };
  667. }
  // typical

  // Parameters of the locally typical sampler; it keeps no other state.
  struct llama_sampler_typical {
      const float  p;        // cumulative probability to retain; p >= 1.0f disables the sampler
      const size_t min_keep; // minimum number of candidates to keep
  };

  static const char * llama_sampler_typical_name(const struct llama_sampler * /*smpl*/) {
      return "typical";
  }
  676. static void llama_sampler_typical_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  677. const auto * ctx = (llama_sampler_typical *) smpl->ctx;
  678. // Reference implementation:
  679. // https://github.com/huggingface/transformers/compare/main...cimeister:typical-sampling:typical-pr
  680. if (ctx->p >= 1.0f) {
  681. return;
  682. }
  683. // Compute the softmax of logits and calculate entropy
  684. llama_sampler_softmax_impl(cur_p);
  685. float entropy = 0.0f;
  686. for (size_t i = 0; i < cur_p->size; ++i) {
  687. entropy += -cur_p->data[i].p * logf(cur_p->data[i].p);
  688. }
  689. // Compute the absolute difference between negative log probability and entropy for each candidate
  690. std::vector<float> shifted_scores;
  691. for (size_t i = 0; i < cur_p->size; ++i) {
  692. float shifted_score = fabsf(-logf(cur_p->data[i].p) - entropy);
  693. shifted_scores.push_back(shifted_score);
  694. }
  695. // Sort tokens based on the shifted_scores and their corresponding indices
  696. std::vector<size_t> indices(cur_p->size);
  697. std::iota(indices.begin(), indices.end(), 0);
  698. std::sort(indices.begin(), indices.end(), [&](size_t a, size_t b) {
  699. return shifted_scores[a] < shifted_scores[b];
  700. });
  701. // Compute the cumulative probabilities
  702. float cum_sum = 0.0f;
  703. size_t last_idx = indices.size();
  704. for (size_t i = 0; i < indices.size(); ++i) {
  705. size_t idx = indices[i];
  706. cum_sum += cur_p->data[idx].p;
  707. // Check if the running sum is greater than typical or if we have kept at least min_keep tokens
  708. if (cum_sum > ctx->p && i >= ctx->min_keep - 1) {
  709. last_idx = i + 1;
  710. break;
  711. }
  712. }
  713. // Resize the output vector to keep only the locally typical tokens
  714. std::vector<llama_token_data> cur_p_new;
  715. for (size_t i = 0; i < last_idx; ++i) {
  716. size_t idx = indices[i];
  717. cur_p_new.push_back(cur_p->data[idx]);
  718. }
  719. // Replace the data in cur_p with the cur_p_new data
  720. std::copy(cur_p_new.begin(), cur_p_new.end(), cur_p->data);
  721. cur_p->size = cur_p_new.size();
  722. cur_p->sorted = false;
  723. }
  724. static struct llama_sampler * llama_sampler_typical_clone(const struct llama_sampler * smpl) {
  725. const auto * ctx = (const llama_sampler_typical *) smpl->ctx;
  726. return llama_sampler_init_typical(ctx->p, ctx->min_keep);
  727. }
  728. static void llama_sampler_typical_free(struct llama_sampler * smpl) {
  729. delete (llama_sampler_typical *) smpl->ctx;
  730. }
  731. static struct llama_sampler_i llama_sampler_typical_i = {
  732. /* .name = */ llama_sampler_typical_name,
  733. /* .accept = */ nullptr,
  734. /* .apply = */ llama_sampler_typical_apply,
  735. /* .reset = */ nullptr,
  736. /* .clone = */ llama_sampler_typical_clone,
  737. /* .free = */ llama_sampler_typical_free,
  738. };
  739. struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
  740. return new llama_sampler {
  741. /* .iface = */ &llama_sampler_typical_i,
  742. /* .ctx = */ new llama_sampler_typical {
  743. /* .p = */ p,
  744. /* .min_keep = */ min_keep,
  745. },
  746. };
  747. }
  // temp

  // Parameters of the fixed-temperature sampler.
  struct llama_sampler_temp {
      const float temp; // temperature applied to every logit
  };

  static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
      return "temp";
  }
  755. static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  756. const auto * ctx = (llama_sampler_temp *) smpl->ctx;
  757. for (size_t i = 0; i < cur_p->size; ++i) {
  758. cur_p->data[i].logit /= ctx->temp;
  759. }
  760. }
  761. static struct llama_sampler * llama_sampler_temp_clone(const struct llama_sampler * smpl) {
  762. const auto * ctx = (const llama_sampler_temp *) smpl->ctx;
  763. return llama_sampler_init_temp(ctx->temp);
  764. }
  765. static void llama_sampler_temp_free(struct llama_sampler * smpl) {
  766. delete (llama_sampler_temp *) smpl->ctx;
  767. }
  768. static struct llama_sampler_i llama_sampler_temp_i = {
  769. /* .name = */ llama_sampler_temp_name,
  770. /* .accept = */ nullptr,
  771. /* .apply = */ llama_sampler_temp_apply,
  772. /* .reset = */ nullptr,
  773. /* .clone = */ llama_sampler_temp_clone,
  774. /* .free = */ llama_sampler_temp_free,
  775. };
  776. struct llama_sampler * llama_sampler_init_temp(float temp) {
  777. return new llama_sampler {
  778. /* .iface = */ &llama_sampler_temp_i,
  779. /* .ctx = */ new llama_sampler_temp {
  780. /*.temp = */ temp,
  781. },
  782. };
  783. }
  // temp-ext

  // Parameters of the dynamic-temperature sampler. With delta > 0 the effective
  // temperature is chosen per call inside [max(0, temp - delta), temp + delta].
  struct llama_sampler_temp_ext {
      const float temp;
      const float delta;
      const float exponent; // power applied to the normalized entropy
  };

  static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
      return "temp-ext";
  }
  793. static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  794. const auto * ctx = (llama_sampler_temp_ext *) smpl->ctx;
  795. if (ctx->delta > 0) {
  796. const float min_temp = std::max(0.0f, ctx->temp - ctx->delta);
  797. const float max_temp = ctx->temp + ctx->delta;
  798. float exponent_val = ctx->exponent;
  799. // no need to do anything if there is only one (or zero) candidates
  800. if (cur_p->size <= 1) {
  801. return;
  802. }
  803. // Calculate maximum possible entropy
  804. float max_entropy = -logf(1.0f / cur_p->size);
  805. llama_sampler_softmax_impl(cur_p);
  806. // Calculate entropy of the softmax probabilities
  807. float entropy = 0.0f;
  808. for (size_t i = 0; i < cur_p->size; ++i) {
  809. float prob = cur_p->data[i].p;
  810. if (prob > 0.0f) { // Ensure no log(0)
  811. entropy -= prob * logf(prob);
  812. }
  813. }
  814. // Normalize the entropy (max_entropy cannot be 0 here because we checked cur_p->size != 1 above)
  815. float normalized_entropy = entropy / max_entropy;
  816. // Map the normalized entropy to the desired temperature range using the power function
  817. float dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent_val);
  818. #ifdef DEBUG
  819. LLAMA_LOG_INFO("Your text maxtemp value is: %f\n", max_temp);
  820. LLAMA_LOG_INFO("Entropy: %f\n", entropy);
  821. LLAMA_LOG_INFO("Max Possible Entropy: %f\n", max_entropy);
  822. LLAMA_LOG_INFO("Normalized Entropy: %f\n", normalized_entropy);
  823. LLAMA_LOG_INFO("Exponent: %f\n", exponent_val);
  824. LLAMA_LOG_INFO("Dynamic Temperature (dyn_temp): %f\n", dyn_temp);
  825. #endif
  826. // Apply the dynamically calculated temperature scaling
  827. for (size_t i = 0; i < cur_p->size; ++i) {
  828. cur_p->data[i].logit /= dyn_temp;
  829. }
  830. // Re-compute softmax probabilities after scaling logits with dynamic temperature
  831. const double max_l_double = cur_p->data[0].logit;
  832. double cum_sum_double = 0.0;
  833. for (size_t i = 0; i < cur_p->size; ++i) {
  834. double p = exp(cur_p->data[i].logit - max_l_double);
  835. cur_p->data[i].p = p; // Store the scaled probability
  836. cum_sum_double += p;
  837. }
  838. for (size_t i = 0; i < cur_p->size; ++i) {
  839. cur_p->data[i].p /= cum_sum_double; // Re-normalize the probabilities
  840. }
  841. #ifdef DEBUG
  842. // Print the updated top 25 probabilities after temperature scaling
  843. LLAMA_LOG_INFO("\nUpdated Top 25 Probabilities After Dynamic Temperature Scaling (in percentages):\n");
  844. for (size_t i = 0; i < 25 && i < cur_p->size; ++i) {
  845. LLAMA_LOG_INFO("Token %zu: %f%%\n", i + 1, cur_p->data[i].p * 100.0f);
  846. }
  847. #endif
  848. } else {
  849. for (size_t i = 0; i < cur_p->size; ++i) {
  850. cur_p->data[i].logit /= ctx->temp;
  851. }
  852. }
  853. }
  854. static struct llama_sampler * llama_sampler_temp_ext_clone(const struct llama_sampler * smpl) {
  855. const auto * ctx = (const llama_sampler_temp_ext *) smpl->ctx;
  856. return llama_sampler_init_temp_ext(ctx->temp, ctx->delta, ctx->exponent);
  857. }
  858. static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
  859. delete (llama_sampler_temp_ext *) smpl->ctx;
  860. }
  861. static struct llama_sampler_i llama_sampler_temp_ext_i = {
  862. /* .name = */ llama_sampler_temp_ext_name,
  863. /* .accept = */ nullptr,
  864. /* .apply = */ llama_sampler_temp_ext_apply,
  865. /* .reset = */ nullptr,
  866. /* .clone = */ llama_sampler_temp_ext_clone,
  867. /* .free = */ llama_sampler_temp_ext_free,
  868. };
  869. struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) {
  870. return new llama_sampler {
  871. /* .iface = */ &llama_sampler_temp_ext_i,
  872. /* .ctx = */ new llama_sampler_temp_ext {
  873. /* .temp = */ temp,
  874. /* .delta = */ delta,
  875. /* .exponent = */ exponent,
  876. },
  877. };
  878. }
  // mirostat

  // Configuration and adaptive state of the Mirostat (v1) sampler.
  struct llama_sampler_mirostat {
      const int32_t  n_vocab;  // vocabulary size, used when estimating k
      const uint32_t seed;     // seed as passed by the caller
      uint32_t       seed_cur; // seed actually used: get_rng_seed(seed)

      const float   tau; // target surprise
      const float   eta; // learning rate for mu
      const int32_t m;   // number of top candidates used to estimate s_hat

      float mu; // current surprise bound, starts at 2 * tau

      std::mt19937 rng;
  };

  static const char * llama_sampler_mirostat_name(const struct llama_sampler * /*smpl*/) {
      return "mirostat";
  }
  893. static void llama_sampler_mirostat_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  894. auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
  895. llama_sampler_softmax_impl(cur_p);
  896. // Estimate s_hat using the most probable m tokens
  897. float s_hat = 0.0;
  898. float sum_ti_bi = 0.0;
  899. float sum_ti_sq = 0.0;
  900. for (size_t i = 0; i < size_t(ctx->m - 1) && i < cur_p->size - 1; ++i) {
  901. float t_i = logf(float(i + 2) / float(i + 1));
  902. float b_i = logf(cur_p->data[i].p / cur_p->data[i + 1].p);
  903. sum_ti_bi += t_i * b_i;
  904. sum_ti_sq += t_i * t_i;
  905. }
  906. s_hat = sum_ti_bi / sum_ti_sq;
  907. // Compute k from the estimated s_hat and target surprise value
  908. float epsilon_hat = s_hat - 1;
  909. float k = powf((epsilon_hat * powf(2, ctx->mu)) / (1 - powf(ctx->n_vocab, -epsilon_hat)), 1 / s_hat);
  910. llama_sampler_top_k_impl(cur_p, std::max(int(k), 1));
  911. llama_sampler_softmax_impl(cur_p);
  912. const int idx = llama_sample_dist(cur_p, ctx->rng);
  913. cur_p->selected = idx;
  914. float observed_surprise = -log2f(cur_p->data[idx].p);
  915. float e = observed_surprise - ctx->tau;
  916. // Update mu using the learning rate and error
  917. ctx->mu = ctx->mu - ctx->eta * e;
  918. }
  919. static struct llama_sampler * llama_sampler_mirostat_clone(const struct llama_sampler * smpl) {
  920. const auto * ctx = (const llama_sampler_mirostat *) smpl->ctx;
  921. auto * result = llama_sampler_init_mirostat(ctx->n_vocab, ctx->seed, ctx->tau, ctx->eta, ctx->m);
  922. // copy the state
  923. {
  924. auto * result_ctx = (llama_sampler_mirostat *) smpl->ctx;
  925. result_ctx->mu = ctx->mu;
  926. result_ctx->rng = ctx->rng;
  927. }
  928. return result;
  929. }
  930. static void llama_sampler_mirostat_reset(struct llama_sampler * smpl) {
  931. auto * ctx = (llama_sampler_mirostat *) smpl->ctx;
  932. ctx->mu = 2.0f*ctx->tau;
  933. ctx->seed_cur = get_rng_seed(ctx->seed);
  934. ctx->rng.seed(ctx->seed_cur);
  935. }
  936. static void llama_sampler_mirostat_free(struct llama_sampler * smpl) {
  937. delete (llama_sampler_mirostat *) smpl->ctx;
  938. }
  939. static struct llama_sampler_i llama_sampler_mirostat_i = {
  940. /* .name = */ llama_sampler_mirostat_name,
  941. /* .accept = */ nullptr,
  942. /* .apply = */ llama_sampler_mirostat_apply,
  943. /* .reset = */ llama_sampler_mirostat_reset,
  944. /* .clone = */ llama_sampler_mirostat_clone,
  945. /* .free = */ llama_sampler_mirostat_free,
  946. };
  947. struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) {
  948. auto seed_cur = get_rng_seed(seed);
  949. return new llama_sampler {
  950. /* .iface = */ &llama_sampler_mirostat_i,
  951. /* .ctx = */ new llama_sampler_mirostat {
  952. /* .n_vocab = */ n_vocab,
  953. /* .seed = */ seed,
  954. /* .seed_cur = */ seed_cur,
  955. /* .tau = */ tau,
  956. /* .eta = */ eta,
  957. /* .m = */ m,
  958. /* .mu = */ 2.0f*tau,
  959. /* .rng = */ std::mt19937(seed_cur),
  960. },
  961. };
  962. }
  // mirostat v2

  // Configuration and adaptive state of the Mirostat v2 sampler.
  struct llama_sampler_mirostat_v2 {
      const uint32_t seed;     // seed as passed by the caller
      uint32_t       seed_cur; // seed actually used: get_rng_seed(seed)

      const float tau; // target surprise
      const float eta; // learning rate for mu

      float mu; // surprise cutoff, starts at 2 * tau

      std::mt19937 rng;
  };

  static const char * llama_sampler_mirostat_v2_name(const struct llama_sampler * /*smpl*/) {
      return "mirostat-v2";
  }
  975. static void llama_sampler_mirostat_v2_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  976. auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
  977. llama_sampler_softmax_impl(cur_p);
  978. // Truncate the words with surprise values greater than mu
  979. cur_p->size = std::distance(cur_p->data, std::find_if(cur_p->data, cur_p->data + cur_p->size, [&](const llama_token_data & candidate) {
  980. return -log2f(candidate.p) > ctx->mu;
  981. }));
  982. if (cur_p->size == 0) {
  983. cur_p->size = 1;
  984. }
  985. // Normalize the probabilities of the remaining words
  986. llama_sampler_softmax_impl(cur_p);
  987. const int idx = llama_sample_dist(cur_p, ctx->rng);
  988. cur_p->selected = idx;
  989. float observed_surprise = -log2f(cur_p->data[idx].p);
  990. float e = observed_surprise - ctx->tau;
  991. // Update mu using the learning rate and error
  992. ctx->mu = ctx->mu - ctx->eta * e;
  993. }
  994. static void llama_sampler_mirostat_v2_reset(struct llama_sampler * smpl) {
  995. auto * ctx = (llama_sampler_mirostat_v2 *) smpl->ctx;
  996. ctx->mu = 2.0f*ctx->tau;
  997. ctx->seed_cur = get_rng_seed(ctx->seed);
  998. ctx->rng.seed(ctx->seed_cur);
  999. }
  1000. static struct llama_sampler * llama_sampler_mirostat_v2_clone(const struct llama_sampler * smpl) {
  1001. const auto * ctx = (const llama_sampler_mirostat_v2 *) smpl->ctx;
  1002. auto * result = llama_sampler_init_mirostat_v2(ctx->seed, ctx->tau, ctx->eta);
  1003. // copy the state
  1004. {
  1005. auto * result_ctx = (llama_sampler_mirostat_v2 *) result->ctx;
  1006. result_ctx->mu = ctx->mu;
  1007. result_ctx->rng = ctx->rng;
  1008. }
  1009. return result;
  1010. }
  1011. static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) {
  1012. delete (llama_sampler_mirostat_v2 *) smpl->ctx;
  1013. }
  1014. static struct llama_sampler_i llama_sampler_mirostat_v2_i = {
  1015. /* .name = */ llama_sampler_mirostat_v2_name,
  1016. /* .accept = */ nullptr,
  1017. /* .apply = */ llama_sampler_mirostat_v2_apply,
  1018. /* .reset = */ llama_sampler_mirostat_v2_reset,
  1019. /* .clone = */ llama_sampler_mirostat_v2_clone,
  1020. /* .free = */ llama_sampler_mirostat_v2_free,
  1021. };
  1022. struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) {
  1023. auto seed_cur = get_rng_seed(seed);
  1024. return new llama_sampler {
  1025. /* .iface = */ &llama_sampler_mirostat_v2_i,
  1026. /* .ctx = */ new llama_sampler_mirostat_v2 {
  1027. /* .seed = */ seed,
  1028. /* .seed_cur = */ seed_cur,
  1029. /* .tau = */ tau,
  1030. /* .eta = */ eta,
  1031. /* .mu = */ 2.0f*tau,
  1032. /* .rng = */ std::mt19937(seed_cur),
  1033. },
  1034. };
  1035. }
  // grammar

  // Grammar-constrained sampler state. `grammar` is nullptr when no grammar was
  // given; grammar_str / grammar_root are kept so reset() can rebuild it.
  struct llama_sampler_grammar {
      const struct llama_vocab * vocab;

      std::string grammar_str;
      std::string grammar_root;

      struct llama_grammar * grammar;
  };

  static const char * llama_sampler_grammar_name(const struct llama_sampler * /*smpl*/) {
      return "grammar";
  }
  1046. static void llama_sampler_grammar_accept_impl(struct llama_sampler * smpl, llama_token token) {
  1047. auto * ctx = (llama_sampler_grammar *) smpl->ctx;
  1048. if (ctx->grammar) {
  1049. llama_grammar_accept_impl(*ctx->grammar, token);
  1050. }
  1051. }
  1052. static void llama_sampler_grammar_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  1053. auto * ctx = (llama_sampler_grammar *) smpl->ctx;
  1054. if (ctx->grammar) {
  1055. llama_grammar_apply_impl(*ctx->grammar, cur_p);
  1056. }
  1057. }
  1058. static void llama_sampler_grammar_reset(struct llama_sampler * smpl) {
  1059. auto * ctx = (llama_sampler_grammar *) smpl->ctx;
  1060. if (!ctx->grammar) {
  1061. return;
  1062. }
  1063. auto * grammar_new = llama_grammar_init_impl(ctx->grammar->vocab, ctx->grammar_str.c_str(), ctx->grammar_root.c_str());
  1064. llama_grammar_free_impl(ctx->grammar);
  1065. ctx->grammar = grammar_new;
  1066. }
  1067. static struct llama_sampler * llama_sampler_grammar_clone(const struct llama_sampler * smpl) {
  1068. const auto * ctx = (const llama_sampler_grammar *) smpl->ctx;
  1069. auto * result = llama_sampler_init_grammar_impl(*ctx->vocab, nullptr, nullptr);
  1070. // copy the state
  1071. {
  1072. auto * result_ctx = (llama_sampler_grammar *) result->ctx;
  1073. if (ctx->grammar) {
  1074. result_ctx->grammar_str = ctx->grammar_str;
  1075. result_ctx->grammar_root = ctx->grammar_root;
  1076. result_ctx->grammar = llama_grammar_clone_impl(*ctx->grammar);
  1077. }
  1078. }
  1079. return result;
  1080. }
  1081. static void llama_sampler_grammar_free(struct llama_sampler * smpl) {
  1082. const auto * ctx = (llama_sampler_grammar *) smpl->ctx;
  1083. if (ctx->grammar) {
  1084. llama_grammar_free_impl(ctx->grammar);
  1085. }
  1086. delete ctx;
  1087. }
  1088. static struct llama_sampler_i llama_sampler_grammar_i = {
  1089. /* .name = */ llama_sampler_grammar_name,
  1090. /* .accept = */ llama_sampler_grammar_accept_impl,
  1091. /* .apply = */ llama_sampler_grammar_apply,
  1092. /* .reset = */ llama_sampler_grammar_reset,
  1093. /* .clone = */ llama_sampler_grammar_clone,
  1094. /* .free = */ llama_sampler_grammar_free,
  1095. };
  1096. struct llama_sampler * llama_sampler_init_grammar_impl(const struct llama_vocab & vocab, const char * grammar_str, const char * grammar_root) {
  1097. auto * ctx = new llama_sampler_grammar;
  1098. if (grammar_str != nullptr && grammar_str[0] != '\0') {
  1099. *ctx = {
  1100. /* .vocab = */ &vocab,
  1101. /* .grammar_str = */ grammar_str,
  1102. /* .grammar_root = */ grammar_root,
  1103. /* .grammar = */ llama_grammar_init_impl(&vocab, grammar_str, grammar_root),
  1104. };
  1105. } else {
  1106. *ctx = {
  1107. /* .vocab = */ &vocab,
  1108. /* .grammar_str = */ {},
  1109. /* .grammar_root = */ {},
  1110. /* .grammar = */ nullptr,
  1111. };
  1112. }
  1113. return new llama_sampler {
  1114. /* .iface = */ &llama_sampler_grammar_i,
  1115. /* .ctx = */ ctx,
  1116. };
  1117. }
  1118. // penalties
  1119. struct llama_sampler_penalties {
  1120. const int32_t n_vocab;
  1121. const llama_token special_eos_id;
  1122. const llama_token linefeed_id;
  1123. const int32_t penalty_last_n;
  1124. const float penalty_repeat;
  1125. const float penalty_freq;
  1126. const float penalty_present;
  1127. const bool penalize_nl;
  1128. const bool ignore_eos;
  1129. ring_buffer<llama_token> prev;
  1130. };
  1131. static const char * llama_sampler_penalties_name(const struct llama_sampler * /*smpl*/) {
  1132. return "penalties";
  1133. }
  1134. static void llama_sampler_penalties_accept(struct llama_sampler * smpl, llama_token token) {
  1135. auto * ctx = (llama_sampler_penalties *) smpl->ctx;
  1136. if (ctx->penalty_last_n == 0) {
  1137. return;
  1138. }
  1139. ctx->prev.push_back(token);
  1140. }
  1141. static void llama_sampler_penalties_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  1142. auto * ctx = (llama_sampler_penalties *) smpl->ctx;
  1143. if (ctx->ignore_eos) {
  1144. assert(ctx->special_eos_id >= 0);
  1145. // optimistically check if the candidates are not yet sorted/shuffled/truncated
  1146. if (cur_p->size > (size_t) ctx->special_eos_id && cur_p->data[ctx->special_eos_id].id == ctx->special_eos_id) {
  1147. cur_p->data[ctx->special_eos_id].logit = -INFINITY;
  1148. } else {
  1149. // else, search for the special EOS token
  1150. for (size_t i = 0; i < cur_p->size; ++i) {
  1151. if (cur_p->data[i].id == ctx->special_eos_id) {
  1152. cur_p->data[i].logit = -INFINITY;
  1153. break;
  1154. }
  1155. }
  1156. }
  1157. }
  1158. if ((ctx->penalty_last_n == 0) ||
  1159. (ctx->penalty_repeat == 1.0f && ctx->penalty_freq == 0.0f && ctx->penalty_present == 0.0f)) {
  1160. return;
  1161. }
  1162. bool nl_found = false;
  1163. size_t nl_idx = 0;
  1164. float nl_logit = -INFINITY;
  1165. if (!ctx->penalize_nl) {
  1166. assert(ctx->linefeed_id >= 0);
  1167. // optimistically check if the candidates are not yet sorted/shuffled/truncated
  1168. if (cur_p->size > (size_t) ctx->linefeed_id && cur_p->data[ctx->linefeed_id].id == ctx->linefeed_id) {
  1169. nl_found = true;
  1170. nl_idx = ctx->linefeed_id;
  1171. nl_logit = cur_p->data[ctx->linefeed_id].logit;
  1172. } else {
  1173. // else, search for the linefeed token
  1174. for (size_t i = 0; i < cur_p->size; ++i) {
  1175. if (cur_p->data[i].id == ctx->linefeed_id) {
  1176. nl_found = true;
  1177. nl_idx = i;
  1178. nl_logit = cur_p->data[i].logit;
  1179. break;
  1180. }
  1181. }
  1182. }
  1183. }
  1184. // Create a frequency map to count occurrences of each token in last_tokens
  1185. // TODO: optimize this by maintaining the token count in the sampler context
  1186. using llama_token_cnt = std::unordered_map<llama_token, int>;
  1187. llama_token_cnt token_count;
  1188. for (int i = 0; i < std::min<int>(ctx->penalty_last_n, ctx->prev.size()); ++i) {
  1189. token_count[ctx->prev.rat(i)]++;
  1190. }
  1191. // Apply frequency and presence penalties to the cur_p
  1192. for (size_t i = 0; i < cur_p->size; ++i) {
  1193. const auto token_iter = token_count.find(cur_p->data[i].id);
  1194. if (token_iter == token_count.end()) {
  1195. continue;
  1196. }
  1197. const int count = token_iter->second;
  1198. // The academic publication that described this technique actually just only divided, but that would cause tokens with negative logits to become more likely, which is obviously wrong.
  1199. // This is common fix for this problem, which is to multiply by the penalty instead of dividing.
  1200. if (cur_p->data[i].logit <= 0) {
  1201. cur_p->data[i].logit *= ctx->penalty_repeat;
  1202. } else {
  1203. cur_p->data[i].logit /= ctx->penalty_repeat;
  1204. }
  1205. cur_p->data[i].logit -= float(count) * ctx->penalty_freq + float(count > 0) * ctx->penalty_present;
  1206. }
  1207. cur_p->sorted = false;
  1208. if (!ctx->penalize_nl && nl_found) {
  1209. // restore the logit of the newline token if it was penalized
  1210. cur_p->data[nl_idx].logit = nl_logit;
  1211. }
  1212. }
  1213. static void llama_sampler_penalties_reset(struct llama_sampler * smpl) {
  1214. auto * ctx = (llama_sampler_penalties *) smpl->ctx;
  1215. ctx->prev.clear();
  1216. }
  1217. static struct llama_sampler * llama_sampler_penalties_clone(const struct llama_sampler * smpl) {
  1218. const auto * ctx = (const llama_sampler_penalties *) smpl->ctx;
  1219. auto * result = llama_sampler_init_penalties(
  1220. ctx->n_vocab,
  1221. ctx->special_eos_id,
  1222. ctx->linefeed_id,
  1223. ctx->penalty_last_n,
  1224. ctx->penalty_repeat,
  1225. ctx->penalty_freq,
  1226. ctx->penalty_present,
  1227. ctx->penalize_nl,
  1228. ctx->ignore_eos);
  1229. // copy the state
  1230. {
  1231. auto * result_ctx = (llama_sampler_penalties *) result->ctx;
  1232. result_ctx->prev = ctx->prev;
  1233. }
  1234. return result;
  1235. }
  1236. static void llama_sampler_penalties_free(struct llama_sampler * smpl) {
  1237. delete (llama_sampler_penalties *) smpl->ctx;
  1238. }
  1239. static struct llama_sampler_i llama_sampler_penalties_i = {
  1240. /* .name = */ llama_sampler_penalties_name,
  1241. /* .accept = */ llama_sampler_penalties_accept,
  1242. /* .apply = */ llama_sampler_penalties_apply,
  1243. /* .reset = */ llama_sampler_penalties_reset,
  1244. /* .clone = */ llama_sampler_penalties_clone,
  1245. /* .free = */ llama_sampler_penalties_free,
  1246. };
  1247. struct llama_sampler * llama_sampler_init_penalties(
  1248. int32_t n_vocab,
  1249. llama_token special_eos_id,
  1250. llama_token linefeed_id,
  1251. int32_t penalty_last_n,
  1252. float penalty_repeat,
  1253. float penalty_freq,
  1254. float penalty_present,
  1255. bool penalize_nl,
  1256. bool ignore_eos) {
  1257. if (linefeed_id == LLAMA_TOKEN_NULL) {
  1258. penalize_nl = true;
  1259. }
  1260. if (special_eos_id == LLAMA_TOKEN_NULL) {
  1261. ignore_eos = false;
  1262. }
  1263. penalty_last_n = std::max(penalty_last_n, 0);
  1264. return new llama_sampler {
  1265. /* .iface = */ &llama_sampler_penalties_i,
  1266. /* .ctx = */ new llama_sampler_penalties {
  1267. /* .n_vocab = */ n_vocab,
  1268. /* .special_eos_id = */ special_eos_id,
  1269. /* .linefeed_id = */ linefeed_id,
  1270. /* .penalty_last_n = */ penalty_last_n,
  1271. /* .penalty_repeat = */ penalty_repeat,
  1272. /* .penalty_freq = */ penalty_freq,
  1273. /* .penalty_present = */ penalty_present,
  1274. /* .penalize_nl = */ penalize_nl,
  1275. /* .ignore_eos = */ ignore_eos,
  1276. /* .prev = */ ring_buffer<llama_token>(penalty_last_n),
  1277. },
  1278. };
  1279. }
  1280. // logit-bias
  1281. struct llama_sampler_logit_bias {
  1282. const int32_t n_vocab;
  1283. const std::vector<llama_logit_bias> logit_bias;
  1284. std::vector<llama_logit_bias> to_search;
  1285. };
  1286. static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
  1287. return "logit-bias";
  1288. }
  1289. static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
  1290. auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
  1291. if (ctx->logit_bias.empty()) {
  1292. return;
  1293. }
  1294. ctx->to_search.clear();
  1295. // update the candidates that have not been shuffled in the vocabulary (i.e. idx == id)
  1296. for (const auto & lb : ctx->logit_bias) {
  1297. if (lb.token >= 0 && cur_p->size > (size_t) lb.token && cur_p->data[lb.token].id == lb.token) {
  1298. cur_p->data[lb.token].logit += lb.bias;
  1299. } else {
  1300. ctx->to_search.push_back(lb);
  1301. }
  1302. }
  1303. if (ctx->to_search.empty()) {
  1304. return;
  1305. }
  1306. // search for the remaining candidates that were not found in the previous step
  1307. for (size_t i = 0; i < cur_p->size; ++i) {
  1308. for (const auto & lb : ctx->to_search) {
  1309. if (cur_p->data[i].id == lb.token) {
  1310. cur_p->data[i].logit += lb.bias;
  1311. break;
  1312. }
  1313. }
  1314. }
  1315. }
  1316. static struct llama_sampler * llama_sampler_logit_bias_clone(const struct llama_sampler * smpl) {
  1317. const auto * ctx = (const llama_sampler_logit_bias *) smpl->ctx;
  1318. return llama_sampler_init_logit_bias(ctx->n_vocab, ctx->logit_bias.size(), ctx->logit_bias.data());
  1319. }
  1320. static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) {
  1321. delete (llama_sampler_logit_bias *) smpl->ctx;
  1322. }
  1323. static struct llama_sampler_i llama_sampler_logit_bias_i = {
  1324. /* .name = */ llama_sampler_logit_bias_name,
  1325. /* .accept = */ nullptr,
  1326. /* .apply = */ llama_sampler_logit_bias_apply,
  1327. /* .reset = */ nullptr,
  1328. /* .clone = */ llama_sampler_logit_bias_clone,
  1329. /* .free = */ llama_sampler_logit_bias_free,
  1330. };
  1331. struct llama_sampler * llama_sampler_init_logit_bias(
  1332. int32_t n_vocab,
  1333. int32_t n_logit_bias,
  1334. const llama_logit_bias * logit_bias) {
  1335. return new llama_sampler {
  1336. /* .iface = */ &llama_sampler_logit_bias_i,
  1337. /* .ctx = */ new llama_sampler_logit_bias {
  1338. /* .n_vocab = */ n_vocab,
  1339. /* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
  1340. /* .to_search = */ {},
  1341. },
  1342. };
  1343. }
  1344. // utils
  1345. uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
  1346. if (smpl->iface == &llama_sampler_dist_i) {
  1347. return ((const llama_sampler_dist *) smpl->ctx)->seed_cur;
  1348. }
  1349. if (smpl->iface == &llama_sampler_mirostat_i) {
  1350. return ((const llama_sampler_mirostat *) smpl->ctx)->seed_cur;
  1351. }
  1352. if (smpl->iface == &llama_sampler_mirostat_v2_i) {
  1353. return ((const llama_sampler_mirostat_v2 *) smpl->ctx)->seed_cur;
  1354. }
  1355. if (smpl->iface == &llama_sampler_chain_i) {
  1356. const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
  1357. for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
  1358. const uint32_t seed = llama_sampler_get_seed(*it);
  1359. if (seed != LLAMA_DEFAULT_SEED) {
  1360. return seed;
  1361. }
  1362. }
  1363. }
  1364. return LLAMA_DEFAULT_SEED;
  1365. }
  1366. // perf
  1367. struct llama_perf_sampler_data llama_perf_sampler(const struct llama_sampler * chain) {
  1368. struct llama_perf_sampler_data data = {};
  1369. if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
  1370. GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
  1371. }
  1372. const auto * ctx = (const struct llama_sampler_chain *) chain->ctx;
  1373. data.t_sample_ms = 1e-3 * ctx->t_sample_us;
  1374. data.n_sample = std::max(0, ctx->n_sample);
  1375. return data;
  1376. }
  1377. void llama_perf_sampler_print(const struct llama_sampler * chain) {
  1378. const auto data = llama_perf_sampler(chain);
  1379. LLAMA_LOG_INFO("%s: sampling time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
  1380. __func__, data.t_sample_ms, data.n_sample, data.t_sample_ms / data.n_sample, 1e3 / data.t_sample_ms * data.n_sample);
  1381. }
  1382. void llama_perf_sampler_reset(struct llama_sampler * chain) {
  1383. if (chain == nullptr || chain->iface != &llama_sampler_chain_i) {
  1384. GGML_ABORT("%s: invalid sampler passed - requires a sampler created with llama_sampler_chain_init()\n", __func__);
  1385. }
  1386. auto * ctx = (struct llama_sampler_chain *) chain->ctx;
  1387. ctx->t_sample_us = ctx->n_sample = 0;
  1388. }