123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130 |
- /**
- * llama.cpp - commit ba1cb19cdd0d92e012e0f6e009e0620f854b6afd - do not edit this file
- *
- * MIT License
- *
- * Copyright (c) 2023-2024 The ggml authors
- *
- * Permission is hereby granted, free of charge, to any person obtaining a copy
- * of this software and associated documentation files (the "Software"), to deal
- * in the Software without restriction, including without limitation the rights
- * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
- * copies of the Software, and to permit persons to whom the Software is
- * furnished to do so, subject to the following conditions:
- *
- * The above copyright notice and this permission notice shall be included in all
- * copies or substantial portions of the Software.
- *
- * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
- * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
- * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
- * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
- * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
- * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
- * SOFTWARE.
- */
- #pragma once
- #include "llama.h"
- #include "common.h"
- #include <string>
- #include <vector>
- // common_sampler extends llama_sampler with additional functionality:
- //
- // - grammar support
- // - custom sampler logic based on the parameters
- // - history of the last accepted tokens
- // - performance metrics
- //
- // The goal is to have a common implementation of the sampling logic shared across the examples.
- // For example, depending on the temperature, the sampling chain can be very simple (greedy) or more
- // complex (top-k, top-p, etc).
- //
- // Another example is related to the grammar. In general, applying the grammar constraints to the full
- // vocabulary can be very taxing. To improve performance, the grammar can be applied only to the sampled
- // token, to verify whether it fits the grammar. Only if the token doesn't fit the grammar are the
- // grammar constraints applied to the full vocabulary and the token resampled.
- //
- // The common_sampler also maintains a container with the last accepted tokens. In the future, this can
- // be moved into the core llama library.
- //
- // For convenience, the common_sampler also maintains a container with the current candidate tokens.
- // This can be used to access the probabilities of the rest of the non-sampled tokens.
- //
- // TODO: measure grammar performance
- //
- // opaque handle: the struct is only forward-declared here, so callers work with it exclusively through
- // the functions declared in this header (its definition presumably lives in the matching .cpp)
- struct common_sampler;
- // counterparts of the llama_sampler API (init / free / accept / reset / clone)
- //
- // create a sampler for `model`, configured from the sampling parameters in `params`
- // NOTE(review): the returned object is presumably owned by the caller and released with
- // common_sampler_free; failure behavior (e.g. a nullptr return) is not visible from this header - confirm
- struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
- // destroy a sampler obtained from common_sampler_init or common_sampler_clone (presumed pairing - confirm)
- void common_sampler_free(struct common_sampler * gsmpl);
- // record `token` as accepted (presumably appended to the sampler's history of accepted tokens)
- // if accept_grammar is true, the token is accepted both by the sampling chain and the grammar;
- // presumably only the sampling chain sees it when accept_grammar is false - confirm
- void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar);
- // reset the sampler state
- // NOTE(review): exactly which state (sampling chain, grammar, history) is cleared is defined by the
- // implementation, not by this header
- void common_sampler_reset (struct common_sampler * gsmpl);
- // create a new sampler copied from `gsmpl` (presumably released separately with common_sampler_free)
- struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
- // print performance metrics for the context and/or the sampler
- // either argument may be nullptr to skip printing the metrics for that object
- // NOTE(review): the output destination (stdout / log) is not visible from this header
- void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
- // extended sampling implementation:
- //
- // - set logits
- // - apply the configured sampler chain
- // - check if the token fits the grammar (if any)
- // - if not: resample by first applying the grammar constraints and then sampling again (slower path)
- //
- // if grammar_first is true, the grammar is applied before the samplers (slower)
- // useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
- //
- // parameters:
- //   gsmpl         - the sampler to sample with
- //   ctx           - context whose logits are sampled (presumably after the last decode - confirm)
- //   idx           - which output's logits to read from ctx (assumed to be an output index of the
- //                   last decoded batch - TODO confirm)
- //   grammar_first - see above
- //
- // returns the sampled token. The token is NOT accepted by this call: call common_sampler_accept
- // afterwards to record it (common_sampler_sample_and_accept_n combines both steps)
- //
- llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
- // generalized version of common_sampler_sample
- //
- // will cross-reference the sampled tokens with a batch of draft tokens and accept those that match
- // if the sampler disagrees at some point, we stop and return the accepted tokens up to now
- //
- // common_sampler_sample_and_accept_n(gsmpl, ctx, { idx }, llama_tokens{});
- //
- // is equivalent to
- //
- // llama_token token = common_sampler_sample(gsmpl, ctx, idx);
- // common_sampler_accept(gsmpl, token, true);
- //
- // (the empty draft is spelled llama_tokens{} on purpose: with a bare `{}` as the last argument,
- //  overload resolution prefers the draft-only overload below, binding `{ idx }` to `draft` and `{}`
- //  to `grammar_first`)
- //
- // requires: idxs.size() == draft.size() + 1
- //
- // returns at least 1 token, up to idxs.size(); as in the example above, the returned tokens have
- // already been accepted (including by the grammar), so they must not be passed to
- // common_sampler_accept again
- //
- std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
- // same as above, assuming idxs == [ 0, 1, 2, ..., draft.size() ]
- std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
- // get the seed associated with this sampler
- // NOTE(review): confirm whether this is the seed from the configuration or the one actually in use
- // (e.g. when a random seed was requested)
- uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
- // helpers
- //
- // access the internal list of current candidate tokens, e.g. to read the probabilities of the
- // tokens that were not sampled; the returned pointer refers to storage inside `gsmpl` and is not
- // owned by the caller
- // NOTE(review): presumably refreshed by each sampling call and only valid while `gsmpl` is alive - confirm
- llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl);
- // get the last accepted token
- // NOTE(review): the value returned when no token has been accepted yet is not visible here
- llama_token common_sampler_last(const struct common_sampler * gsmpl);
- // print the sampler chain into a string
- std::string common_sampler_print(const struct common_sampler * gsmpl);
- // get a string representation of the last accepted tokens
- //   n   - how many of the most recently accepted tokens to include (behavior when n exceeds the
- //         number of accepted tokens is not visible here - confirm)
- //   ctx - presumably used to convert token ids to text - confirm
- std::string common_sampler_prev_str(common_sampler * gsmpl, llama_context * ctx, int n);
- // conversions between common_sampler_type values and their textual forms
- //
- // single-character and string name of a sampler type
- char common_sampler_type_to_chr(enum common_sampler_type cnstr);
- std::string common_sampler_type_to_str(enum common_sampler_type cnstr);
- // parse a list of sampler names / a string of sampler characters into sampler types
- // allow_alt_names presumably enables alternative spellings of the names; handling of unknown
- // names or characters is not visible from this header - confirm in the implementation
- std::vector<enum common_sampler_type> common_sampler_types_from_names(const std::vector<std::string> & names, bool allow_alt_names);
- std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std::string & chars);
|