Skip to content

Commit

Permalink
Quantize in range [1, 2^b)
Browse files Browse the repository at this point in the history
Due to quantization, some scores can be 0, but our frequency encoding
(which is used for scores) assumes positive values. To fix it, we
quantize into a range starting at 1 instead.

Fixes: #572
  • Loading branch information
elshize committed Feb 8, 2024
1 parent aa8900b commit 4c790af
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 20 deletions.
18 changes: 4 additions & 14 deletions include/pisa/linear_quantizer.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include <cmath>
#include <cstdint>

#include <fmt/format.h>
Expand All @@ -9,21 +8,12 @@
namespace pisa {

// Linear quantizer: scales a float score by a precomputed factor and rounds it
// up to an unsigned integer level.
//
// NOTE(review): this listing looks like a diff rendered without +/- markers
// (header says "4 additions & 14 deletions"), so the inline definitions and the
// out-of-line declarations below appear to be the before/after versions of the
// same members rather than overloads -- confirm against the real header.
struct LinearQuantizer {
// Inline constructor: the scale is chosen so that `max` maps to 2^bits - 1.
// Throws std::runtime_error when `bits` is 0 or greater than 32. Note that the
// check runs only after the member initializers have already evaluated
// `1U << bits`.
explicit LinearQuantizer(float max, std::uint8_t bits)
: m_max(max), m_scale(static_cast<float>((1U << bits) - 1U) / max) {
if (bits > 32 or bits == 0) {
throw std::runtime_error(fmt::format(
"Linear quantizer must take a number of bits between 1 and 32 but {} passed", bits
));
}
}

// Inline call operator: requires `value <= m_max` and returns
// ceil(value * m_scale); a `value` of 0 therefore yields 0 for a finite scale.
[[nodiscard]] auto operator()(float value) const -> std::uint32_t {
Expects(value <= m_max);
return std::ceil(value * m_scale);
}
// Declaration-only constructor taking the maximum input value and the bit width.
LinearQuantizer(float max, std::uint8_t bits);
// Declaration-only call operator: quantizes `value` to an unsigned integer.
[[nodiscard]] auto operator()(float value) const -> std::uint32_t;
// Declaration-only accessor; presumably exposes `m_range` -- confirm in the
// implementation file.
[[nodiscard]] auto range() const noexcept -> std::uint32_t;

private:
// NOTE(review): not initialized by the inline constructor above; presumably
// the upper bound of the quantized values -- confirm.
std::uint32_t m_range;
// Upper bound that `operator()` asserts its argument against.
float m_max;
// Multiplier applied to the input before rounding up.
float m_scale;
};
Expand Down
5 changes: 4 additions & 1 deletion include/pisa/scorer/quantized.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class QuantizingScorer {
[this, scorer = m_scorer->term_scorer(term_id)](std::uint32_t doc, std::uint32_t freq) {
auto score = scorer(doc, freq);
assert(score >= 0.0);
return this->m_quantizer(score);
auto quantized = this->m_quantizer(score);
assert(quantized >= 0);
assert(quantized <= this->m_quantizer.range());
return quantized;
};
}
};
Expand Down
25 changes: 25 additions & 0 deletions src/linear_quantizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include <cmath>

#include "linear_quantizer.hpp"

namespace pisa {

LinearQuantizer::LinearQuantizer(float max, std::uint8_t bits)
: m_range((1U << bits) - 1U), m_max(max), m_scale(static_cast<float>(m_range - 1) / max) {
if (bits > 32 or bits < 2) {
throw std::runtime_error(fmt::format(
"Linear quantizer must take a number of bits between 2 and 32 but {} passed", bits
));
}
}

/// Quantizes `value` into the integer interval [1, range()].
///
/// The result is ceil(value * scale) + 1, so 0 maps to 1 and larger inputs map
/// to proportionally larger levels, capped at `m_range`.
///
/// @param value  Score to quantize; must satisfy 0 <= value <= max (checked
///               with `Expects`).
/// @return The quantized level, never 0 and never above `m_range`.
auto LinearQuantizer::operator()(float value) const -> std::uint32_t {
    Expects(0 <= value && value <= m_max);
    float const steps = std::ceil(value * m_scale);
    // The negated comparison is also true for NaN (0 * inf when the scale is
    // not finite), which must not reach the integer conversion below.
    if (!(steps > 0.0F)) {
        return 1;
    }
    // Add the offset in double: every integer up to 2^32 is exact there,
    // whereas float loses the "+ 1" above 2^24.
    double const level = static_cast<double>(steps) + 1.0;
    // Float rounding in `value * m_scale` can land one step past the top level
    // (or past UINT32_MAX for wide bit widths); clamp before converting, since
    // an out-of-range float-to-integer conversion is undefined behavior.
    if (level >= static_cast<double>(m_range)) {
        return m_range;
    }
    return static_cast<std::uint32_t>(level);
}

auto LinearQuantizer::range() const noexcept -> std::uint32_t {
return m_range;
}

} // namespace pisa
6 changes: 3 additions & 3 deletions test/test_compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "type_safe.hpp"
#include "wand_utils.hpp"

TEST_CASE("Compress block index", "[index][compress]") {
TEST_CASE("Compress index", "[index][compress]") {
std::string encoding = GENERATE(
"ef",
"single",
Expand Down Expand Up @@ -38,10 +38,10 @@ TEST_CASE("Compress block index", "[index][compress]") {
);
}

TEST_CASE("Compress quantized block index", "[index][compress]") {
TEST_CASE("Compress quantized index", "[index][compress]") {
auto input = PISA_SOURCE_DIR "/test/test_data/test_collection";

std::string scorer = GENERATE("bm25");
std::string scorer = GENERATE("bm25", "qld");
CAPTURE(scorer);
auto scorer_params = ScorerParams(scorer);

Expand Down
10 changes: 8 additions & 2 deletions test/test_linear_quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
TEST_CASE("LinearQuantizer", "[scoring][unit]") {
SECTION("construct") {
WHEN("number of bits is 0 or 33") {
std::uint8_t bits = GENERATE(0, 33);
std::uint8_t bits = GENERATE(0, 1, 33);
THEN("constructor fails") {
REQUIRE_THROWS(pisa::LinearQuantizer(10.0, bits));
}
Expand All @@ -18,7 +18,13 @@ TEST_CASE("LinearQuantizer", "[scoring][unit]") {
std::uint8_t bits = GENERATE(3, 8, 12, 16, 19, 32);
float max = GENERATE(1.0, 100.0, std::numeric_limits<float>::max());
pisa::LinearQuantizer quantizer(max, bits);
REQUIRE(quantizer(0) == 0);
REQUIRE(quantizer(0) == 1);
REQUIRE(quantizer(max) == (1 << bits) - 1);
}
SECTION("max is 0") {
std::uint8_t bits = GENERATE(3, 8, 12, 16, 19, 32);
float max = 0.0;
pisa::LinearQuantizer quantizer(max, bits);
REQUIRE(quantizer(0) == 1);
}
}

0 comments on commit 4c790af

Please sign in to comment.