|
#include "lm/interpolate/tune_derivatives.hh" |
|
|
|
#include "lm/interpolate/tune_instances.hh" |
|
#include "lm/interpolate/tune_matrix.hh" |
|
#include "util/stream/chain.hh" |
|
#include "util/stream/typed_stream.hh" |
|
|
|
#include <Eigen/Core> |
|
|
|
namespace lm { namespace interpolate { |
|
|
|
// Computes the tuning objective for the current interpolation weights and its
// first two derivatives.
//
// Returns exp((CorrectGradientTerm·weights + sum_n ln Z_context(n)) / #instances),
// a per-instance geometric mean used by the caller as a progress/convergence
// measure.  Writes the gradient into `gradient` and the Hessian into `hessian`.
//
// NOTE(review): the math below assumes in.LNUnigrams() holds per-model log
// unigram probabilities (rows = vocabulary words, cols = models) and
// in.LNBackoffs(n) holds the per-model log backoff for instance n's context —
// confirm against tune_instances.hh.
Accum Derivatives(Instances &in, const Vector &weights, Vector &gradient, Matrix &hessian) {
  // Seed the gradient with the weight-independent "correct word" term; the
  // per-instance normalization contributions are accumulated onto it below.
  gradient = in.CorrectGradientTerm();
  hessian = Matrix::Zero(weights.rows(), weights.rows());

  // Unnormalized weighted unigram mass: exp(sum_i w_i ln p_i(x)) per word x,
  // i.e. p_I(x) * Z_epsilon before normalization.
  Vector weighted_uni((in.LNUnigrams() * weights).array().exp());
  // <s> is never predicted; force its mass to zero explicitly (the exp of its
  // row is not zero on its own, since weights may be negative).
  weighted_uni(in.BOS()) = 0.0;
  // Normalizer of the weighted unigram ("empty context") distribution.
  Accum Z_epsilon = weighted_uni.sum();

  // Per-model expectation of the log unigram probabilities under the weighted
  // unigram distribution: E_{p_I}[ln p_i(x)].
  Vector unigram_cross(in.LNUnigrams().transpose() * weighted_uni / Z_epsilon);

  Accum sum_B_I = 0.0;           // running sum of B_I; scales the deferred Hessian term
  Accum sum_ln_Z_context = 0.0;  // running sum of ln Z_context; feeds the return value

  // Scratch buffers hoisted out of the loop so Eigen storage is reused
  // instead of reallocated every instance/extension.
  Matrix convolve;
  Vector full_cross;
  Matrix hessian_missing_Z_context;
  // ln p_i(word | context) with every model backed off to its unigram.
  Vector ln_p_i_backed;
  // Same, but with fully-scored models overriding their backed-off entries.
  Vector ln_p_i_full;

  // Stream the "extensions": (instance, word, model, ln_prob) records for
  // words some model scores beyond backoff-to-unigram in some context.
  // NOTE(review): the consumption loops below assume the stream is sorted by
  // (instance, word) — confirm in.ReadExtensions provides that order.
  util::stream::Chain chain(util::stream::ChainConfig(in.ReadExtensionsEntrySize(), 2, 64 << 20));
  chain.ActivateProgress();
  in.ReadExtensions(chain);
  util::stream::TypedStream<Extension> extensions(chain.Add());
  chain >> util::stream::kRecycle;

  for (InstanceIndex n = 0; n < in.NumInstances(); ++n) {
    assert(extensions);
    // prod_i b_i(context)^{w_i}: weighted product of the context's backoffs.
    Accum weighted_backoffs = exp(in.LNBackoffs(n).dot(weights));

    // Unnormalized unigram mass of the words this context extends; subtracted
    // from Z_epsilon below to leave only the purely backed-off remainder.
    Accum unnormalized_sum_x_p_I = 0.0;
    // Unnormalized fully-scored mass of those same extended words.
    Accum unnormalized_sum_x_p_I_full = 0.0;

    // This instance's Hessian contribution, still missing its 1/Z_context factor.
    hessian_missing_Z_context = Matrix::Zero(weights.rows(), weights.rows());
    // This instance's gradient contribution (normalized and completed below).
    full_cross = Vector::Zero(weights.rows());

    // Consume every extension record belonging to instance n.
    while (extensions && extensions->instance == n) {
      const WordIndex word = extensions->word;
      unnormalized_sum_x_p_I += weighted_uni(word);

      // Default for this word: every model backs off to its unigram.
      ln_p_i_backed = in.LNUnigrams().row(word) + in.LNBackoffs(n);

      // Override the entries of models that fully score `word` here.  The
      // inner loop advances the stream across all records for (n, word).
      ln_p_i_full = ln_p_i_backed;
      for (; extensions && extensions->word == word && extensions->instance == n; ++extensions) {
        ln_p_i_full(extensions->model) = extensions->ln_prob;
      }

      // Unnormalized weighted probability of `word` using the full scores.
      Accum weighted = exp(ln_p_i_full.dot(weights));
      unnormalized_sum_x_p_I_full += weighted;

      // Gradient numerator: replace the backed-off contribution of `word`
      // with its fully-scored contribution.
      full_cross.noalias() +=
        weighted * ln_p_i_full
        - weighted_uni(word) * weighted_backoffs * in.LNUnigrams().row(word).transpose();

      // Second-moment (Hessian) numerator: the same replacement, as outer products.
      hessian_missing_Z_context.noalias() +=
        weighted * ln_p_i_full * ln_p_i_full.transpose()
        - weighted_uni(word) * weighted_backoffs * ln_p_i_backed * ln_p_i_backed.transpose();
    }

    // Context normalizer: backed-off mass of all non-extended words plus the
    // fully-scored mass of the extended ones.
    Accum Z_context =
      weighted_backoffs * (Z_epsilon - unnormalized_sum_x_p_I)
      + unnormalized_sum_x_p_I_full;
    sum_ln_Z_context += log(Z_context);
    // B_I rescales unigram-distribution expectations into this context's
    // distribution: (Z_epsilon / Z_context) * weighted_backoffs.
    Accum B_I = Z_epsilon / Z_context * weighted_backoffs;
    sum_B_I += B_I;

    // Finish this instance's gradient term: normalize the extension part,
    // then add the backed-off bulk via the precomputed unigram expectation,
    // minus the extended words' share of that expectation.
    full_cross /= Z_context;
    full_cross +=
      B_I * (in.LNBackoffs(n).transpose() + unigram_cross)
      - unnormalized_sum_x_p_I / Z_epsilon * B_I * in.LNBackoffs(n).transpose();
    gradient += full_cross;

    // Cross terms between the backoff vector and the unigram expectation.
    convolve = unigram_cross * in.LNBackoffs(n);
    // Hessian of -ln Z_context: second moment minus the outer product of the
    // first moment (full_cross).  The unigram-only second-moment piece is
    // shared across instances and deferred to the loop after this one,
    // scaled there by sum_B_I.
    hessian.noalias() +=
      B_I * (convolve + convolve.transpose() + in.LNBackoffs(n).transpose() * in.LNBackoffs(n))
      + hessian_missing_Z_context / Z_context
      - full_cross * full_cross.transpose();
  }

  // Deferred Hessian term: the unigram second moment over the whole
  // vocabulary, applied once with the accumulated sum_B_I scale.
  for (Matrix::Index x = 0; x < weighted_uni.rows(); ++x) {
    hessian.noalias() += sum_B_I * weighted_uni(x) / Z_epsilon * in.LNUnigrams().row(x).transpose() * in.LNUnigrams().row(x);
  }
  // Per-instance geometric mean of the objective, reported to the caller.
  return exp((in.CorrectGradientTerm().dot(weights) + sum_ln_Z_context) / static_cast<double>(in.NumInstances()));
}
|
|
|
}} |
|
|