NOTE(review): recovered from a copy of this patch in which line breaks were
collapsed and every "<...>" span had been stripped. Template arguments, the
"<W:%f>" weight format and the include names in the two new files are
reconstructed from usage and must be checked against the original commit.
Stripped "#include" context lines in fst_processor.h and node.h could not be
recovered and are left bare. Blob hashes on the "index" lines are stale.

diff --git a/lttoolbox/CMakeLists.txt b/lttoolbox/CMakeLists.txt
index 9d5b72f..2f210c8 100644
--- a/lttoolbox/CMakeLists.txt
+++ b/lttoolbox/CMakeLists.txt
@@ -21,6 +21,7 @@ set(LIBLTTOOLBOX_HEADERS
 	node.h
 	pattern_list.h
 	regexp_compiler.h
+	reusable_state.h
 	serialiser.h
 	sorted_vector.h
 	sorted_vector.hpp
@@ -53,6 +54,7 @@ set(LIBLTTOOLBOX_SOURCES
 	node.cc
 	pattern_list.cc
 	regexp_compiler.cc
+	reusable_state.cc
 	sorted_vector.cc
 	state.cc
 	string_utils.cc
diff --git a/lttoolbox/fst_processor.cc b/lttoolbox/fst_processor.cc
index b219639..ff9c03c 100644
--- a/lttoolbox/fst_processor.cc
+++ b/lttoolbox/fst_processor.cc
@@ -597,6 +597,20 @@ FSTProcessor::filterFinals(const State& state, UStringView casefrom)
                             uppercase, firstupper, 0);
 }
 
+UString
+FSTProcessor::filterFinals(const ReusableState& state, UStringView casefrom)
+{
+  bool firstupper = false, uppercase = false;
+  if (!dictionaryCase && !casefrom.empty()) {
+    firstupper = u_isupper(casefrom[0]);
+    uppercase = (casefrom.size() > 1 &&
+                 firstupper && u_isupper(casefrom[casefrom.size()-1]));
+  }
+  return state.filterFinals(all_finals, alphabet, escaped_chars,
+                            displayWeightsMode, maxAnalyses, maxWeightClasses,
+                            uppercase, firstupper, 0);
+}
+
 void
 FSTProcessor::writeEscaped(UStringView str, UFILE *output)
 {
@@ -886,7 +900,8 @@ FSTProcessor::analysis(InputFile& input, UFILE *output)
   bool last_incond = false;
   bool last_postblank = false;
   bool last_preblank = false;
-  State current_state = initial_state;
+  ReusableState current_state;
+  current_state.init(&root);
   UString lf; // analysis (lexical form and tags)
   UString sf; // surface form
   UString lf_spcmp; // space compound analysis
@@ -1141,7 +1156,7 @@ FSTProcessor::analysis(InputFile& input, UFILE *output)
         }
       }
 
-      current_state = initial_state;
+      current_state.init(&root);
       lf.clear();
       sf.clear();
       last_start = input_buffer.getPos();
@@ -1343,7 +1358,8 @@ FSTProcessor::generation(InputFile& input, UFILE *output, GenerationMode mode)
     generation_wrapper_null_flush(input, output, mode);
   }
 
-  State current_state = initial_state;
+  ReusableState current_state;
+  current_state.init(&root);
   UString sf;
 
   outOfWord = false;
@@ -1468,7 +1484,7 @@ FSTProcessor::generation(InputFile& input, UFILE *output, GenerationMode mode)
         }
       }
 
-      current_state = initial_state;
+      current_state.init(&root);
       sf.clear();
     }
     else if(u_isspace(val) && sf.size() == 0)
@@ -1525,7 +1541,8 @@ FSTProcessor::transliteration(InputFile& input, UFILE *output)
   size_t cur_word = 0;
   size_t cur_pos = 0;
   size_t match_pos = 0;
-  State current_state = initial_state;
+  ReusableState current_state;
+  current_state.init(&root);
   UString last_match;
   int space_diff = 0;
 
@@ -1705,7 +1722,7 @@ FSTProcessor::transliteration(InputFile& input, UFILE *output)
         firstupper = false;
         have_first = false;
         have_second = false;
-        current_state = initial_state;
+        current_state.init(&root);
       }
     }
   }
diff --git a/lttoolbox/fst_processor.h b/lttoolbox/fst_processor.h
index 940789a..ca24a66 100644
--- a/lttoolbox/fst_processor.h
+++ b/lttoolbox/fst_processor.h
@@ -24,6 +24,7 @@
 #include
 #include
 #include
+#include <lttoolbox/reusable_state.h>
 #include
 #include
 #include
@@ -328,6 +329,7 @@ class FSTProcessor
    * Assumes that casefrom is non-empty
    */
   UString filterFinals(const State& state, UStringView casefrom);
+  UString filterFinals(const ReusableState& state, UStringView casefrom);
 
   /**
    * Write a string to an output stream,
@@ -450,11 +452,11 @@ class FSTProcessor
    *
    * @return running with --case-sensitive or state size exceeds max
    */
-  bool beCaseSensitive(const State& state) {
+  bool beCaseSensitive(size_t size) {
     if(caseSensitive) {
       return true;
     }
-    else if(state.size() < max_case_insensitive_state_size) {
+    else if(size < max_case_insensitive_state_size) {
       return false;             // ie. do case-folding
     }
     else {
@@ -467,6 +469,11 @@ class FSTProcessor
     }
   }
 
+  bool beCaseSensitive(const State& s) { return beCaseSensitive(s.size()); }
+  bool beCaseSensitive(const ReusableState& s) {
+    return beCaseSensitive(s.size());
+  }
+
 public:
 
   /*
diff --git a/lttoolbox/node.h b/lttoolbox/node.h
index 34ae538..b922154 100644
--- a/lttoolbox/node.h
+++ b/lttoolbox/node.h
@@ -23,6 +23,7 @@
 #include
 
 class State;
+class ReusableState;
 class Node;
 
 
@@ -35,6 +36,7 @@ class Dest
   double *out_weight;
 
   friend class State;
+  friend class ReusableState;
   friend class Node;
 
   void copy(Dest const &d)
@@ -112,6 +114,7 @@ class Node
 {
 private:
   friend class State;
+  friend class ReusableState;
 
   /**
    * The outgoing transitions of this node.
diff --git a/lttoolbox/reusable_state.cc b/lttoolbox/reusable_state.cc
new file mode 100644
index 0000000..aa55c95
--- /dev/null
+++ b/lttoolbox/reusable_state.cc
@@ -0,0 +1,430 @@
+#include <lttoolbox/reusable_state.h>
+#include <algorithm>
+#include <climits>
+
+// Visit every step on the path ending at `pos`, newest first. Index 0 is the
+// root created by init() and terminates the walk without being visited.
+#define WalkBack(var, pos, block) {             \
+    size_t index = pos;                         \
+    while (index != 0) {                        \
+      auto& var = get(index);                   \
+      block                                     \
+      index = var.prev;                         \
+    }                                           \
+  }
+
+// Run `block` once per step `i` of the current frontier [start, end), then
+// make the steps it appended the new frontier and take their epsilon closure.
+#define StepLoop(block) {                           \
+    size_t new_start = end;                         \
+    for (size_t i = start; i < new_start; i++) {    \
+      block                                         \
+    }                                               \
+    start = new_start;                              \
+    epsilonClosure();                               \
+  }
+
+ReusableState::ReusableState() {}
+
+ReusableState::~ReusableState()
+{
+  for (size_t i = 0; i < steps.size(); i++) {
+    delete steps[i];
+  }
+  steps.clear();
+}
+
+// Return the step slot at `index`, allocating blocks as needed. Slots are
+// recycled between words, so a returned slot may hold stale data.
+ReusableState::Step& ReusableState::get_or_create(size_t index)
+{
+  size_t a = index >> STATE_STEP_BLOCK_SIZE_EXP;
+  size_t b = index & (STATE_STEP_BLOCK_SIZE-1);
+  while (a >= steps.size()) {
+    auto block = new std::array<Step, STATE_STEP_BLOCK_SIZE>;
+    steps.push_back(block);
+  }
+  return (*(steps[a]))[b];
+}
+
+// Return the step slot at `index`; the slot must already exist.
+const ReusableState::Step& ReusableState::get(size_t index) const
+{
+  size_t a = index >> STATE_STEP_BLOCK_SIZE_EXP;
+  size_t b = index & (STATE_STEP_BLOCK_SIZE-1);
+  return (*(steps[a]))[b];
+}
+
+// Follow every transition labelled `input` out of step `pos`, appending one
+// step per destination. Output symbols equal to a non-zero `old_sym` are
+// replaced by `new_sym`. Returns whether any transition existed.
+bool ReusableState::apply(int32_t input, size_t pos,
+                          int32_t old_sym, int32_t new_sym, bool dirty)
+{
+  auto& prev = get(pos);
+  bool set_dirty = prev.dirty||dirty;
+  auto it = prev.where->transitions.find(input);
+  if (it != prev.where->transitions.end()) {
+    for (int j = 0; j < it->second.size; j++) {
+      Step& next = get_or_create(end);
+      next.where = it->second.dest[j];
+      next.symbol = it->second.out_tag[j];
+      if (old_sym && next.symbol == old_sym) next.symbol = new_sym;
+      next.weight = it->second.out_weight[j];
+      next.dirty = set_dirty;
+      next.prev = pos;
+      end++;
+    }
+    return true;
+  }
+  return false;
+}
+
+void ReusableState::epsilonClosure()
+{
+  // `end` grows while we iterate, so steps appended here are expanded too.
+  for (size_t i = start; i < end; i++) {
+    apply(0, i, 0, 0, false);
+  }
+}
+
+size_t ReusableState::size() const
+{
+  return end - start;
+}
+
+void ReusableState::init(Node* initial)
+{
+  while (steps.size() > STATE_RESET_SIZE) {
+    delete steps[steps.size()-1];
+    steps.pop_back();
+  }
+  start = 0;
+  end = 1;
+  // The slot is recycled: clear it so no field of an earlier step survives.
+  Step& root = get_or_create(0);
+  root = Step();
+  root.where = initial;
+  epsilonClosure();
+}
+
+void ReusableState::reinit(Node* initial)
+{
+  size_t start_was = start;
+  // Slot `end` may have been used before (by an earlier word or by steps
+  // that were pruned away). A stale `prev` would splice that old path onto
+  // everything reached from this root, so reset the slot before using it.
+  Step& root = get_or_create(end);
+  root = Step();
+  root.where = initial;
+  start = end;
+  end++;
+  epsilonClosure();
+  start = start_was;
+}
+
+void ReusableState::step(int32_t input)
+{
+  StepLoop({
+    apply(input, i, 0, 0, false);
+  })
+}
+
+void ReusableState::step(int32_t input, int32_t alt)
+{
+  if (alt == 0 || alt == input) {
+    step(input);
+    return;
+  }
+  StepLoop({
+    apply(input, i, 0, 0, false);
+    apply(alt, i, 0, 0, true);
+  })
+}
+
+void ReusableState::step_override(int32_t input,
+                                  int32_t old_sym, int32_t new_sym)
+{
+  StepLoop({
+    apply(input, i, old_sym, new_sym, false);
+  })
+}
+
+void ReusableState::step_override(int32_t input, int32_t alt,
+                                  int32_t old_sym, int32_t new_sym)
+{
+  if (alt == 0 || alt == input) {
+    step_override(input, old_sym, new_sym);
+    return;
+  }
+  StepLoop({
+    apply(input, i, old_sym, new_sym, false);
+    apply(alt, i, old_sym, new_sym, true);
+  })
+}
+
+void ReusableState::step_careful(int32_t input, int32_t alt)
+{
+  if (alt == 0 || alt == input) {
+    step(input);
+    return;
+  }
+  StepLoop({
+    if (!apply(input, i, 0, 0, false)) {
+      apply(alt, i, 0, 0, true);
+    }
+  })
+}
+
+void ReusableState::step(int32_t input, int32_t alt1, int32_t alt2)
+{
+  if (alt1 == 0 || alt1 == input || alt1 == alt2) {
+    step(input, alt2);
+    return;
+  } else if (alt2 == 0 || alt2 == input) {
+    step(input, alt1);
+    return;
+  }
+  StepLoop({
+    apply(input, i, 0, 0, false);
+    apply(alt1, i, 0, 0, true);
+    apply(alt2, i, 0, 0, true);
+  })
+}
+
+void ReusableState::step(int32_t input, std::set<int> alts)
+{
+  StepLoop({
+    apply(input, i, 0, 0, false);
+    for (auto& a : alts) {
+      if (a == 0 || a == input) continue;
+      apply(a, i, 0, 0, true);
+    }
+  })
+}
+
+void ReusableState::step_case(UChar32 val, UChar32 val2, bool caseSensitive)
+{
+  if (u_isupper(val) && !caseSensitive) {
+    step(val, u_tolower(val), val2);
+  } else {
+    step(val, val2);
+  }
+}
+
+void ReusableState::step_case(UChar32 val, bool caseSensitive)
+{
+  if (!u_isupper(val) || caseSensitive) {
+    step(val);
+  } else {
+    step(val, u_tolower(val));
+  }
+}
+
+void ReusableState::step_case_override(UChar32 val, bool caseSensitive)
+{
+  if (!u_isupper(val) || caseSensitive) {
+    step(val);
+  } else {
+    step_override(val, u_tolower(val), u_tolower(val), val);
+  }
+}
+
+void ReusableState::step_optional(int32_t val)
+{
+  size_t old_start = start;
+  step(val);
+  start = old_start;
+}
+
+bool ReusableState::isFinal(const std::map<Node*, double>& finals) const
+{
+  for (size_t i = start; i < end; i++) {
+    if (finals.find(get(i).where) != finals.end()) return true;
+  }
+  return false;
+}
+
+// Append the output of the path ending at `pos` to `result`, oldest symbol
+// first, and add the path's step weights to `weight`.
+void ReusableState::extract(size_t pos, UString& result, double& weight,
+                            const Alphabet& alphabet,
+                            const std::set<UChar32>& escaped_chars,
+                            bool uppercase) const {
+  std::vector<int32_t> symbols;
+  WalkBack(it, pos, {
+    weight += it.weight;
+    if (it.symbol) symbols.push_back(it.symbol);
+  })
+  for (auto it = symbols.rbegin(); it != symbols.rend(); it++) {
+    if (escaped_chars.find(*it) != escaped_chars.end()) result += '\\';
+    alphabet.getSymbol(result, *it, uppercase);
+  }
+}
+
+// Sort ascending by (weight, string), then keep at most `maxAnalyses` entries
+// and at most `maxWeightClasses` distinct weights. File-local helper.
+static void NFinals(std::vector<std::pair<double, UString>>& results,
+                    size_t maxAnalyses, size_t maxWeightClasses)
+{
+  if (results.empty()) return;
+  std::sort(results.begin(), results.end());
+  if (maxAnalyses < results.size()) {
+    results.erase(results.begin()+maxAnalyses, results.end());
+  }
+  if (maxWeightClasses < results.size()) {
+    double last_weight = results[0].first + 1;
+    for (size_t i = 0; i < results.size(); i++) {
+      if (results[i].first != last_weight) {
+        last_weight = results[i].first;
+        if (maxWeightClasses == 0) {
+          results.erase(results.begin()+i, results.end());
+          return;
+        }
+        maxWeightClasses--;
+      }
+    }
+  }
+}
+
+UString ReusableState::filterFinals(const std::map<Node*, double>& finals,
+                                    const Alphabet& alphabet,
+                                    const std::set<UChar32>& escaped_chars,
+                                    bool display_weights,
+                                    int max_analyses, int max_weight_classes,
+                                    bool uppercase, bool firstupper,
+                                    int firstchar) const
+{
+  std::vector<std::pair<double, UString>> results;
+
+  UString temp;
+  double weight;
+  for (size_t i = start; i < end; i++) {
+    auto fin = finals.find(get(i).where);
+    if (fin != finals.end()) {
+      weight = fin->second;
+      temp.clear();
+      extract(i, temp, weight, alphabet, escaped_chars, uppercase);
+      if (firstupper && get(i).dirty) {
+        int idx = (temp[firstchar] == '~' ? firstchar + 1 : firstchar);
+        temp[idx] = u_toupper(temp[idx]);
+      }
+      results.push_back({weight, temp});
+    }
+  }
+
+  NFinals(results, max_analyses, max_weight_classes);
+
+  temp.clear();
+  std::set<UString> seen;
+  for (auto& it : results) {
+    if (seen.find(it.second) != seen.end()) continue;
+    seen.insert(it.second);
+    temp += '/';
+    temp += it.second;
+    if (display_weights) {
+      UChar wbuf[16]{};
+      // Bounded write: a weight of 10000 or more no longer overruns wbuf; it
+      // is truncated instead. 15 leaves the final zero-initialised unit as
+      // the terminator.
+      u_snprintf(wbuf, 15, "<W:%f>", it.first);
+      temp += wbuf;
+    }
+  }
+  return temp;
+}
+
+bool ReusableState::lastPartHasRequiredSymbol(size_t pos, int32_t symbol,
+                                              int32_t separator)
+{
+  WalkBack(it, pos, {
+    if (it.symbol == symbol) return true;
+    else if (separator && it.symbol == separator) return false;
+  });
+  return false;
+}
+
+bool ReusableState::hasSymbol(int32_t symbol)
+{
+  for (size_t i = start; i < end; i++) {
+    if (lastPartHasRequiredSymbol(i, symbol, 0)) return true;
+  }
+  return false;
+}
+
+void ReusableState::pruneCompounds(int32_t requiredSymbol, int32_t separator,
+                                   int maxElements)
+{
+  int min = maxElements;
+  size_t len = size();
+  std::vector<int> count(len, 0);
+  for (size_t i = 0; i < len; i++) {
+    bool found = false;
+    WalkBack(it, i+start, {
+      if (it.symbol == requiredSymbol && count[i] == 0) found = true;
+      else if (it.symbol == separator) {
+        if (found) count[i]++;
+        else {
+          count[i] = INT_MAX;
+          break;
+        }
+      }
+    });
+    if (count[i] < min) min = count[i];
+  }
+  size_t keep = 0;
+  for (size_t i = 0; i < len; i++) {
+    if (count[i] == min) {
+      size_t src = start + i;
+      size_t dest = start + keep;
+      // move the step that we're keeping, overwriting one that's being
+      // discarded, and shrink the state size
+      // NOTE(review): steps added by epsilonClosure() have `prev` inside
+      // this same frontier; moving their parent leaves that `prev` pointing
+      // at the old slot. Confirm this cannot happen here, or remap `prev`.
+      if (src != dest) get_or_create(dest) = get(src);
+      keep++;
+    }
+  }
+  end = start + keep;
+}
+
+void ReusableState::restartFinals(const std::map<Node*, double>& finals,
+                                  int32_t requiredSymbol, Node* restart,
+                                  int32_t separator)
+{
+  if (restart == nullptr) return;
+  for (size_t i = start, limit = end; i < limit; i++) {
+    auto& step = get(i);
+    if (finals.count(step.where) > 0 &&
+        lastPartHasRequiredSymbol(i, requiredSymbol, separator)) {
+      size_t start_was = start;
+      start = end;
+      end++;
+      auto& newstep = get_or_create(start);
+      newstep.where = restart;
+      newstep.symbol = separator;
+      // The slot is recycled, so every field must be written: the separator
+      // adds no weight, and case-folding seen so far carries over.
+      newstep.weight = 0;
+      newstep.dirty = step.dirty;
+      newstep.prev = i;
+      epsilonClosure();
+      start = start_was;
+    }
+  }
+}
+
+void ReusableState::pruneStatesWithForbiddenSymbol(int32_t symbol)
+{
+  size_t keep = 0;
+  for (size_t i = start; i < end; i++) {
+    if (!lastPartHasRequiredSymbol(i, symbol, 0)) {
+      size_t dest = start + keep;
+      // NOTE(review): same `prev` caveat as the compaction in pruneCompounds().
+      if (i != dest) get_or_create(dest) = get(i);
+      keep++;
+    }
+  }
+  end = start + keep;
+}
diff --git a/lttoolbox/reusable_state.h b/lttoolbox/reusable_state.h
new file mode 100644
index 0000000..94a39fe
--- /dev/null
+++ b/lttoolbox/reusable_state.h
@@ -0,0 +1,95 @@
+#ifndef __LT_REUSABLE_STATE__
+#define __LT_REUSABLE_STATE__
+
+#include <lttoolbox/alphabet.h>
+#include <lttoolbox/node.h>
+#include <lttoolbox/ustring.h>
+
+#include <array>
+#include <map>
+#include <set>
+#include <vector>
+
+#define STATE_STEP_BLOCK_SIZE_EXP 8
+#define STATE_STEP_BLOCK_SIZE (1<<STATE_STEP_BLOCK_SIZE_EXP)
+// NOTE(review): the value of STATE_RESET_SIZE was lost from the copy this
+// patch was recovered from; 8 is a placeholder -- restore the original.
+#define STATE_RESET_SIZE 8
+
+// A transducer search state whose step storage is kept between words:
+// init() rewinds the indices instead of freeing the blocks. The live
+// frontier is the index range [start, end); each step links to its
+// predecessor through `prev`.
+class ReusableState {
+private:
+  // NOTE(review): the Step definition was also lost. Field names and types
+  // follow their uses in reusable_state.cc; order and defaults are assumed.
+  struct Step {
+    size_t prev = 0;
+    int32_t symbol = 0;
+    double weight = 0.000;
+    bool dirty = false;
+    Node* where = nullptr;
+  };
+
+  std::vector<std::array<Step, STATE_STEP_BLOCK_SIZE>*> steps;
+  size_t start = 0;
+  size_t end = 1;
+
+  Step& get_or_create(size_t index);
+  const Step& get(size_t index) const;
+
+  bool apply(int32_t input, size_t pos, int32_t old_sym, int32_t new_sym,
+             bool dirty);
+
+  void epsilonClosure();
+
+  void extract(size_t pos, UString& result, double& weight,
+               const Alphabet& alphabet,
+               const std::set<UChar32>& escaped_chars, bool uppercase) const;
+
+public:
+  ReusableState();
+  ~ReusableState();
+  // The object owns raw blocks; copying would free them twice.
+  ReusableState(const ReusableState&) = delete;
+  ReusableState& operator=(const ReusableState&) = delete;
+
+  size_t size() const;
+  void init(Node* initial);
+  void reinit(Node* initial);
+
+  void step(int32_t input);
+  void step(int32_t input, int32_t alt);
+  void step_override(int32_t input, int32_t old_sym, int32_t new_sym);
+  void step_override(int32_t input, int32_t alt,
+                     int32_t old_sym, int32_t new_sym);
+  void step_careful(int32_t input, int32_t alt);
+  void step(int32_t input, int32_t alt1, int32_t alt2);
+  void step(int32_t input, std::set<int> alts);
+  void step_case(UChar32 val, UChar32 val2, bool caseSensitive);
+  void step_case(UChar32 val, bool caseSensitive);
+  void step_case_override(UChar32 val, bool caseSensitive);
+  void step_optional(int32_t val);
+
+  bool isFinal(const std::map<Node*, double>& finals) const;
+
+  UString filterFinals(const std::map<Node*, double>& finals,
+                       const Alphabet& alphabet,
+                       const std::set<UChar32>& escaped_chars,
+                       bool display_weights,
+                       int max_analyses, int max_weight_classes,
+                       bool uppercase, bool firstupper,
+                       int firstchar = 0) const;
+
+  bool lastPartHasRequiredSymbol(size_t pos, int32_t symbol, int32_t separator);
+  bool hasSymbol(int32_t symbol);
+  void pruneCompounds(int32_t requiredSymbol, int32_t separator,
+                      int maxElements);
+  void restartFinals(const std::map<Node*, double>& finals,
+                     int32_t requiredSymbol, Node* restart_state,
+                     int32_t separator);
+  void pruneStatesWithForbiddenSymbol(int32_t symbol);
+};
+
+#endif