|
#include "lm/filter/phrase.hh" |
|
|
|
#include "lm/filter/format.hh" |
|
|
|
#include <algorithm> |
|
#include <functional> |
|
#include <iostream> |
|
#include <queue> |
|
#include <string> |
|
#include <vector> |
|
|
|
#include <cctype> |
|
|
|
namespace lm { |
|
namespace phrase { |
|
|
|
unsigned int ReadMultiple(std::istream &in, Substrings &out) { |
|
bool sentence_content = false; |
|
unsigned int sentence_id = 0; |
|
std::vector<Hash> phrase; |
|
std::string word; |
|
while (in) { |
|
char c; |
|
|
|
while (!isspace(c = in.get()) && in) word += c; |
|
|
|
if (!in) c = '\n'; |
|
|
|
if (!word.empty()) { |
|
phrase.push_back(util::MurmurHashNative(word.data(), word.size())); |
|
word.clear(); |
|
} |
|
if (c == ' ') continue; |
|
|
|
if (!phrase.empty()) { |
|
sentence_content = true; |
|
out.AddPhrase(sentence_id, phrase.begin(), phrase.end()); |
|
phrase.clear(); |
|
} |
|
if (c == '\t' || c == '\v') continue; |
|
|
|
if (sentence_content) { |
|
++sentence_id; |
|
sentence_content = false; |
|
} |
|
} |
|
if (!in.eof()) in.exceptions(std::istream::failbit | std::istream::badbit); |
|
return sentence_id + sentence_content; |
|
} |
|
|
|
namespace { |
|
typedef unsigned int Sentence; |
|
typedef std::vector<Sentence> Sentences; |
|
} |
|
|
|
namespace detail { |
|
|
|
const StringPiece kEndSentence("</s>"); |
|
|
|
class Arc { |
|
public: |
|
Arc() {} |
|
|
|
|
|
void SetPhrase(detail::Vertex &from, detail::Vertex &to, const Sentences &intersect) { |
|
Set(to, intersect); |
|
from_ = &from; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
void SetRight(detail::Vertex &to, const Sentences &complete) { |
|
Set(to, complete); |
|
from_ = NULL; |
|
} |
|
|
|
Sentence Current() const { |
|
return *current_; |
|
} |
|
|
|
bool Empty() const { |
|
return current_ == last_; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void LowerBound(const Sentence to); |
|
|
|
private: |
|
void Set(detail::Vertex &to, const Sentences &sentences); |
|
|
|
const Sentence *current_; |
|
const Sentence *last_; |
|
detail::Vertex *from_; |
|
}; |
|
|
|
struct ArcGreater : public std::binary_function<const Arc *, const Arc *, bool> { |
|
bool operator()(const Arc *first, const Arc *second) const { |
|
return first->Current() > second->Current(); |
|
} |
|
}; |
|
|
|
class Vertex { |
|
public: |
|
Vertex() : current_(0) {} |
|
|
|
Sentence Current() const { |
|
return current_; |
|
} |
|
|
|
bool Empty() const { |
|
return incoming_.empty(); |
|
} |
|
|
|
void LowerBound(const Sentence to); |
|
|
|
private: |
|
friend class Arc; |
|
|
|
void AddIncoming(Arc *arc) { |
|
if (!arc->Empty()) incoming_.push(arc); |
|
} |
|
|
|
unsigned int current_; |
|
std::priority_queue<Arc*, std::vector<Arc*>, ArcGreater> incoming_; |
|
}; |
|
|
|
void Arc::LowerBound(const Sentence to) { |
|
current_ = std::lower_bound(current_, last_, to); |
|
|
|
|
|
if (!from_ || Empty() || (Current() > to)) return; |
|
assert(Current() == to); |
|
from_->LowerBound(to); |
|
if (from_->Empty()) { |
|
current_ = last_; |
|
return; |
|
} |
|
assert(from_->Current() >= to); |
|
if (from_->Current() > to) { |
|
current_ = std::lower_bound(current_ + 1, last_, from_->Current()); |
|
} |
|
} |
|
|
|
void Arc::Set(Vertex &to, const Sentences &sentences) { |
|
current_ = &*sentences.begin(); |
|
last_ = &*sentences.end(); |
|
to.AddIncoming(this); |
|
} |
|
|
|
void Vertex::LowerBound(const Sentence to) { |
|
if (Empty()) return; |
|
|
|
while (true) { |
|
Arc *top = incoming_.top(); |
|
if (top->Current() > to) { |
|
current_ = top->Current(); |
|
return; |
|
} |
|
|
|
|
|
incoming_.pop(); |
|
top->LowerBound(to); |
|
if (!top->Empty()) { |
|
incoming_.push(top); |
|
if (top->Current() == to) { |
|
current_ = to; |
|
return; |
|
} |
|
} else if (Empty()) { |
|
return; |
|
} |
|
} |
|
} |
|
|
|
} |
|
|
|
namespace { |
|
|
|
void BuildGraph(const Substrings &phrase, const std::vector<Hash> &hashes, detail::Vertex *const vertices, detail::Arc *free_arc) { |
|
using detail::Vertex; |
|
using detail::Arc; |
|
assert(!hashes.empty()); |
|
|
|
const Hash *const first_word = &*hashes.begin(); |
|
const Hash *const last_word = &*hashes.end() - 1; |
|
|
|
Hash hash = 0; |
|
const Sentences *found; |
|
|
|
{ |
|
Vertex *vertex = vertices; |
|
for (const Hash *word = first_word; ; ++word, ++vertex) { |
|
hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word); |
|
|
|
if (word == last_word) { |
|
if (phrase.FindSubstring(hash, found)) |
|
(free_arc++)->SetRight(*vertex, *found); |
|
break; |
|
} |
|
if (!phrase.FindRight(hash, found)) break; |
|
(free_arc++)->SetRight(*vertex, *found); |
|
} |
|
} |
|
|
|
|
|
Vertex *vertex_from = vertices; |
|
for (const Hash *word_from = first_word + 1; word_from != &*hashes.end(); ++word_from, ++vertex_from) { |
|
hash = 0; |
|
Vertex *vertex_to = vertex_from + 1; |
|
for (const Hash *word_to = word_from; ; ++word_to, ++vertex_to) { |
|
|
|
hash = util::MurmurHashNative(&hash, sizeof(uint64_t), *word_to); |
|
|
|
if (word_to == last_word) { |
|
if (phrase.FindLeft(hash, found)) |
|
(free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); |
|
break; |
|
} |
|
if (!phrase.FindPhrase(hash, found)) break; |
|
(free_arc++)->SetPhrase(*vertex_from, *vertex_to, *found); |
|
} |
|
} |
|
} |
|
|
|
} |
|
|
|
namespace detail { |
|
|
|
|
|
ConditionCommon::ConditionCommon(const Substrings &substrings) : substrings_(substrings) {} |
|
|
|
|
|
ConditionCommon::ConditionCommon(const ConditionCommon &from) : substrings_(from.substrings_) {} |
|
|
|
ConditionCommon::~ConditionCommon() {} |
|
|
|
detail::Vertex &ConditionCommon::MakeGraph() { |
|
assert(!hashes_.empty()); |
|
vertices_.clear(); |
|
vertices_.resize(hashes_.size()); |
|
arcs_.clear(); |
|
|
|
arcs_.resize(((hashes_.size() + 1) * hashes_.size()) / 2); |
|
BuildGraph(substrings_, hashes_, &*vertices_.begin(), &*arcs_.begin()); |
|
return vertices_[hashes_.size() - 1]; |
|
} |
|
|
|
} |
|
|
|
bool Union::Evaluate() { |
|
detail::Vertex &last_vertex = MakeGraph(); |
|
unsigned int lower = 0; |
|
while (true) { |
|
last_vertex.LowerBound(lower); |
|
if (last_vertex.Empty()) return false; |
|
if (last_vertex.Current() == lower) return true; |
|
lower = last_vertex.Current(); |
|
} |
|
} |
|
|
|
template <class Output> void Multiple::Evaluate(const StringPiece &line, Output &output) { |
|
detail::Vertex &last_vertex = MakeGraph(); |
|
unsigned int lower = 0; |
|
while (true) { |
|
last_vertex.LowerBound(lower); |
|
if (last_vertex.Empty()) return; |
|
if (last_vertex.Current() == lower) { |
|
output.SingleAddNGram(lower, line); |
|
++lower; |
|
} else { |
|
lower = last_vertex.Current(); |
|
} |
|
} |
|
} |
|
|
|
template void Multiple::Evaluate<CountFormat::Multiple>(const StringPiece &line, CountFormat::Multiple &output); |
|
template void Multiple::Evaluate<ARPAFormat::Multiple>(const StringPiece &line, ARPAFormat::Multiple &output); |
|
template void Multiple::Evaluate<MultipleOutputBuffer>(const StringPiece &line, MultipleOutputBuffer &output); |
|
|
|
} |
|
} |
|
|