diff --git a/include/pisa/linear_quantizer.hpp b/include/pisa/linear_quantizer.hpp index 9d2faf50..c371aa31 100644 --- a/include/pisa/linear_quantizer.hpp +++ b/include/pisa/linear_quantizer.hpp @@ -1,6 +1,5 @@ #pragma once -#include #include #include @@ -9,21 +8,12 @@ namespace pisa { struct LinearQuantizer { - explicit LinearQuantizer(float max, std::uint8_t bits) - : m_max(max), m_scale(static_cast((1U << bits) - 1U) / max) { - if (bits > 32 or bits == 0) { - throw std::runtime_error(fmt::format( - "Linear quantizer must take a number of bits between 1 and 32 but {} passed", bits - )); - } - } - - [[nodiscard]] auto operator()(float value) const -> std::uint32_t { - Expects(value <= m_max); - return std::ceil(value * m_scale); - } + LinearQuantizer(float max, std::uint8_t bits); + [[nodiscard]] auto operator()(float value) const -> std::uint32_t; + [[nodiscard]] auto range() const noexcept -> std::uint32_t; private: + std::uint32_t m_range; float m_max; float m_scale; }; diff --git a/include/pisa/scorer/quantized.hpp b/include/pisa/scorer/quantized.hpp index c0afefc8..659e047e 100644 --- a/include/pisa/scorer/quantized.hpp +++ b/include/pisa/scorer/quantized.hpp @@ -39,7 +39,10 @@ class QuantizingScorer { [this, scorer = m_scorer->term_scorer(term_id)](std::uint32_t doc, std::uint32_t freq) { auto score = scorer(doc, freq); assert(score >= 0.0); - return this->m_quantizer(score); + auto quantized = this->m_quantizer(score); + assert(quantized >= 0); + assert(quantized <= this->m_quantizer.range()); + return quantized; }; } }; diff --git a/src/linear_quantizer.cpp b/src/linear_quantizer.cpp new file mode 100644 index 00000000..a0c0c81c --- /dev/null +++ b/src/linear_quantizer.cpp @@ -0,0 +1,30 @@ +#include + +#include "linear_quantizer.hpp" + +namespace pisa { + +LinearQuantizer::LinearQuantizer(float max, std::uint8_t bits) + : m_range((1U << bits) - 1U), m_max(max), m_scale(static_cast(m_range - 1) / max) { + if (max <= 0.0) { + throw std::runtime_error( + fmt::format("Max score for linear quantizer must be positive but {} passed", max) + ); + } + if (bits > 32 or bits < 2) { + throw std::runtime_error(fmt::format( + "Linear quantizer must take a number of bits between 2 and 32 but {} passed", bits + )); + } +} + +auto LinearQuantizer::operator()(float value) const -> std::uint32_t { + Expects(0 <= value && value <= m_max); + return std::round(value * m_scale) + 1; +} + +auto LinearQuantizer::range() const noexcept -> std::uint32_t { + return m_range; +} + +} // namespace pisa diff --git a/test/test_compress.cpp b/test/test_compress.cpp index 211d656d..81fdf664 100644 --- a/test/test_compress.cpp +++ b/test/test_compress.cpp @@ -9,7 +9,7 @@ #include "type_safe.hpp" #include "wand_utils.hpp" -TEST_CASE("Compress block index", "[index][compress]") { +TEST_CASE("Compress index", "[index][compress]") { std::string encoding = GENERATE( "ef", "single", @@ -38,10 +38,10 @@ TEST_CASE("Compress block index", "[index][compress]") { ); } -TEST_CASE("Compress quantized block index", "[index][compress]") { +TEST_CASE("Compress quantized index", "[index][compress]") { auto input = PISA_SOURCE_DIR "/test/test_data/test_collection"; - std::string scorer = GENERATE("bm25"); + std::string scorer = GENERATE("bm25", "qld"); CAPTURE(scorer); auto scorer_params = ScorerParams(scorer); diff --git a/test/test_linear_quantizer.cpp b/test/test_linear_quantizer.cpp index 2787e92f..1bb1d6f8 100644 --- a/test/test_linear_quantizer.cpp +++ b/test/test_linear_quantizer.cpp @@ -7,18 +7,25 @@ TEST_CASE("LinearQuantizer", "[scoring][unit]") { SECTION("construct") { - WHEN("number of bits is 0 or 33") { - std::uint8_t bits = GENERATE(0, 33); + WHEN("number of bits is 0, 1 or 33") { + std::uint8_t bits = GENERATE(0, 1, 33); THEN("constructor fails") { REQUIRE_THROWS(pisa::LinearQuantizer(10.0, bits)); } } } + SECTION("max is 0") { + std::uint8_t bits = 8; + float max = 0.0; + REQUIRE_THROWS(pisa::LinearQuantizer(max, bits)); + } SECTION("scores") { std::uint8_t bits = GENERATE(3, 8, 12, 16, 19, 32); float max = GENERATE(1.0, 100.0, std::numeric_limits::max()); + CAPTURE((int)bits); + CAPTURE(max); pisa::LinearQuantizer quantizer(max, bits); - REQUIRE(quantizer(0) == 0); + REQUIRE(quantizer(0) == 1); REQUIRE(quantizer(max) == (1 << bits) - 1); } }