sampling.h 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. /**
  2. * llama.cpp - commit 3f1ae2e32cde00c39b96be6d01c2997c29bae555 - do not edit this file
  3. *
  4. * MIT License
  5. *
  6. * Copyright (c) 2023-2024 The ggml authors
  7. *
  8. * Permission is hereby granted, free of charge, to any person obtaining a copy
  9. * of this software and associated documentation files (the "Software"), to deal
  10. * in the Software without restriction, including without limitation the rights
  11. * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  12. * copies of the Software, and to permit persons to whom the Software is
  13. * furnished to do so, subject to the following conditions:
  14. *
  15. * The above copyright notice and this permission notice shall be included in all
  16. * copies or substantial portions of the Software.
  17. *
  18. * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  19. * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  20. * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  21. * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  22. * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  23. * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  24. * SOFTWARE.
  25. */
  26. #pragma once
  27. #include "llama.h"
  28. #include "common.h"
  29. #include <string>
  30. #include <vector>
  31. // gpt_sampler extends llama_sampler with additional functionality:
  32. //
  33. // - grammar support
  34. // - custom sampler logic based on the parameters
  35. // - history of the last accepted tokens
  36. // - performance metrics
  37. //
  // The goal is to have a common implementation of the sampling logic that is shared across the examples.
  39. // For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
  40. // complex (top-k, top-p, etc).
  41. //
  // Another example is related to the grammar. In general, applying the grammar constraints to the full
  // vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
  // token in order to verify whether it fits the grammar. Only if the token doesn't fit the grammar are the
  // grammar constraints applied to the full vocabulary and the token resampled.
  46. //
  47. // The gpt_sampler also maintains a container with the last accepted tokens. In the future, this can
  48. // be moved into the core llama library.
  49. //
  50. // For convenience, the gpt_sampler also maintains a container with the current candidate tokens.
  51. // This can be used to access the probabilities of the rest of the non-sampled tokens.
  52. //
  53. // TODO: measure grammar performance
  54. //
  55. struct gpt_sampler;
  56. // llama_sampler API overloads
  57. struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const struct gpt_sampler_params & params);
  58. void gpt_sampler_free(struct gpt_sampler * gsmpl);
  59. // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar
  60. void gpt_sampler_accept(struct gpt_sampler * gsmpl, llama_token token, bool accept_grammar);
  61. void gpt_sampler_reset (struct gpt_sampler * gsmpl);
  62. struct gpt_sampler * gpt_sampler_clone (struct gpt_sampler * gsmpl);
  // Print performance information for the context and/or the sampler.
  // Either pointer may be nullptr; the corresponding part is then skipped.
  void gpt_perf_print(
          const struct llama_context * ctx,
          const struct gpt_sampler   * gsmpl);
  65. // extended sampling implementation:
  66. //
  67. // - set logits
  68. // - apply the configured sampler chain
  69. // - check if the token fits the grammar (if any)
  70. // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
  71. //
  72. // if grammar_first is true, the grammar is applied before the samplers (slower)
  73. // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
  74. //
  75. llama_token gpt_sampler_sample(struct gpt_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
  // Return the seed associated with `gsmpl`.
  // NOTE(review): whether this is the requested or the effective seed is decided by the implementation — confirm.
  uint32_t gpt_sampler_get_seed(
          const struct gpt_sampler * gsmpl);
  77. // helpers
  78. // access the internal list of current candidate tokens
  79. llama_token_data_array * gpt_sampler_get_candidates(struct gpt_sampler * gsmpl);
  80. // get the last accepted token
  81. llama_token gpt_sampler_last(const struct gpt_sampler * gsmpl);
  82. // print the sampler chain into a string
  83. std::string gpt_sampler_print(const struct gpt_sampler * gsmpl);
  84. // get a string representation of the last accepted tokens
  85. std::string gpt_sampler_prev_str(gpt_sampler * gsmpl, llama_context * ctx, int n);
  86. char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr);
  87. std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr);
  88. std::vector<enum gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
  89. std::vector<enum gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & chars);