Skip to content

Commit

Permalink
Quantize in range [1, 2^b)
Browse files Browse the repository at this point in the history
Due to quantization, some scores can be 0, but our frequency encoding
(which is used for scores) assumes positive values. To fix it, we
quantize into a range starting at 1 instead.

Changelog-changed: Scores are quantized starting at 1 instead of 0
Fixes: #572
Signed-off-by: Michal Siedlaczek <michal@siedlaczek.me>
  • Loading branch information
elshize committed Feb 12, 2024
1 parent 62d4dcc commit 25ff231
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 21 deletions.
18 changes: 4 additions & 14 deletions include/pisa/linear_quantizer.hpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
#pragma once

#include <cmath>
#include <cstdint>

#include <fmt/format.h>
Expand All @@ -9,21 +8,12 @@
namespace pisa {

struct LinearQuantizer {
explicit LinearQuantizer(float max, std::uint8_t bits)
: m_max(max), m_scale(static_cast<float>((1U << bits) - 1U) / max) {
if (bits > 32 or bits == 0) {
throw std::runtime_error(fmt::format(
"Linear quantizer must take a number of bits between 1 and 32 but {} passed", bits
));
}
}

[[nodiscard]] auto operator()(float value) const -> std::uint32_t {
Expects(value <= m_max);
return std::ceil(value * m_scale);
}
LinearQuantizer(float max, std::uint8_t bits);
[[nodiscard]] auto operator()(float value) const -> std::uint32_t;
[[nodiscard]] auto range() const noexcept -> std::uint32_t;

private:
std::uint32_t m_range;
float m_max;
float m_scale;
};
Expand Down
5 changes: 4 additions & 1 deletion include/pisa/scorer/quantized.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ class QuantizingScorer {
[this, scorer = m_scorer->term_scorer(term_id)](std::uint32_t doc, std::uint32_t freq) {
auto score = scorer(doc, freq);
assert(score >= 0.0);
return this->m_quantizer(score);
auto quantized = this->m_quantizer(score);
assert(quantized >= 0);
assert(quantized <= this->m_quantizer.range());
return quantized;
};
}
};
Expand Down
30 changes: 30 additions & 0 deletions src/linear_quantizer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include <cmath>

#include "linear_quantizer.hpp"

namespace pisa {

LinearQuantizer::LinearQuantizer(float max, std::uint8_t bits)
: m_range((1U << bits) - 1U), m_max(max), m_scale(static_cast<float>(m_range - 1) / max) {
if (max <= 0.0) {
throw std::runtime_error(
fmt::format("Max score for linear quantizer must be positive but {} passed", max)
);
}
if (bits > 32 or bits < 2) {
throw std::runtime_error(fmt::format(
"Linear quantizer must take a number of bits between 2 and 32 but {} passed", bits
));
}
}

auto LinearQuantizer::operator()(float value) const -> std::uint32_t {
Expects(0 <= value && value <= m_max);
return std::round(value * m_scale) + 1;
}

auto LinearQuantizer::range() const noexcept -> std::uint32_t {
return m_range;
}

} // namespace pisa
6 changes: 3 additions & 3 deletions test/test_compress.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include "type_safe.hpp"
#include "wand_utils.hpp"

TEST_CASE("Compress block index", "[index][compress]") {
TEST_CASE("Compress index", "[index][compress]") {
std::string encoding = GENERATE(
"ef",
"single",
Expand Down Expand Up @@ -38,10 +38,10 @@ TEST_CASE("Compress block index", "[index][compress]") {
);
}

TEST_CASE("Compress quantized block index", "[index][compress]") {
TEST_CASE("Compress quantized index", "[index][compress]") {
auto input = PISA_SOURCE_DIR "/test/test_data/test_collection";

std::string scorer = GENERATE("bm25");
std::string scorer = GENERATE("bm25", "qld");
CAPTURE(scorer);
auto scorer_params = ScorerParams(scorer);

Expand Down
13 changes: 10 additions & 3 deletions test/test_linear_quantizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,25 @@

TEST_CASE("LinearQuantizer", "[scoring][unit]") {
SECTION("construct") {
WHEN("number of bits is 0 or 33") {
std::uint8_t bits = GENERATE(0, 33);
WHEN("number of bits is 0, 1 or 33") {
std::uint8_t bits = GENERATE(0, 1, 33);
THEN("constructor fails") {
REQUIRE_THROWS(pisa::LinearQuantizer(10.0, bits));
}
}
}
SECTION("max is 0") {
std::uint8_t bits = 8;
float max = 0.0;
REQUIRE_THROWS(pisa::LinearQuantizer(max, bits));
}
SECTION("scores") {
std::uint8_t bits = GENERATE(3, 8, 12, 16, 19, 32);
float max = GENERATE(1.0, 100.0, std::numeric_limits<float>::max());
CAPTURE((int)bits);
CAPTURE(max);
pisa::LinearQuantizer quantizer(max, bits);
REQUIRE(quantizer(0) == 0);
REQUIRE(quantizer(0) == 1);
REQUIRE(quantizer(max) == (1 << bits) - 1);
}
}

0 comments on commit 25ff231

Please sign in to comment.