Skip to content

Commit

Permalink
Support KnnQuery's method_parameters and rescore
Browse files Browse the repository at this point in the history
Signed-off-by: Thomas Farr <tsfarr@amazon.com>
  • Loading branch information
Xtansia committed Feb 3, 2025
1 parent fcde8ec commit 25f0e41
Show file tree
Hide file tree
Showing 4 changed files with 258 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)

## [Unreleased 2.x]
### Added
- Added support for `KnnQuery`'s `method_parameters` and `rescore` properties ([#1407](https://github.com/opensearch-project/opensearch-java/pull/1407))

### Dependencies
- Bump `org.junit:junit-bom` from 5.11.3 to 5.11.4 ([#1367](https://github.com/opensearch-project/opensearch-java/pull/1367))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@
package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;

import java.util.Map;
import java.util.function.Function;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;

import org.opensearch.client.json.JsonData;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
Expand All @@ -31,6 +36,10 @@ public class KnnQuery extends QueryBase implements QueryVariant {
private final Float maxDistance;
@Nullable
private final Query filter;
@Nonnull
private final Map<String, JsonData> methodParameters;
@Nullable
private final KnnQueryRescore rescore;

private KnnQuery(Builder builder) {
super(builder);
Expand All @@ -41,6 +50,8 @@ private KnnQuery(Builder builder) {
this.minScore = builder.minScore;
this.maxDistance = builder.maxDistance;
this.filter = builder.filter;
this.methodParameters = ApiTypeHelper.unmodifiable(builder.methodParameters);
this.rescore = builder.rescore;
}

public static KnnQuery of(Function<Builder, ObjectBuilder<KnnQuery>> fn) {
Expand Down Expand Up @@ -108,6 +119,16 @@ public final Query filter() {
return this.filter;
}

@Nonnull
public final Map<String, JsonData> methodParameters() {
return this.methodParameters;
}

@Nullable
public final KnnQueryRescore rescore() {
return this.rescore;
}

@Override
protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject(this.field);
Expand Down Expand Up @@ -138,11 +159,23 @@ protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
this.filter.serialize(generator, mapper);
}

if (ApiTypeHelper.isDefined(this.methodParameters)) {
for (Map.Entry<String, JsonData> entry : this.methodParameters.entrySet()) {
generator.writeKey(entry.getKey());
entry.getValue().serialize(generator, mapper);
}
}

if (this.rescore != null) {
generator.writeKey("rescore");
this.rescore.serialize(generator, mapper);
}

generator.writeEnd();
}

public Builder toBuilder() {
return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter);
return toBuilder(new Builder()).field(field).vector(vector).k(k).minScore(minScore).maxDistance(maxDistance).filter(filter).methodParameters(methodParameters).rescore(rescore);
}

/**
Expand All @@ -161,6 +194,10 @@ public static class Builder extends QueryBase.AbstractBuilder<Builder> implement
private Float maxDistance;
@Nullable
private Query filter;
@Nullable
private Map<String, JsonData> methodParameters;
@Nullable
private KnnQueryRescore rescore;

/**
* Required - The target field.
Expand Down Expand Up @@ -227,6 +264,25 @@ public Builder filter(@Nullable Query filter) {
return this;
}

public Builder methodParameters(@Nonnull Map<String, JsonData> value) {
this.methodParameters = _mapPutAll(this.methodParameters, value);
return this;
}

public Builder methodParameters(String key, JsonData value) {
this.methodParameters = _mapPut(this.methodParameters, key, value);
return this;
}

public Builder rescore(@Nullable KnnQueryRescore value) {
this.rescore = value;
return this;
}

public Builder rescore(Function<KnnQueryRescore.Builder, ObjectBuilder<KnnQueryRescore>> fn) {
return this.rescore(fn.apply(new KnnQueryRescore.Builder()).build());
}

@Override
protected Builder self() {
return this;
Expand Down Expand Up @@ -264,6 +320,8 @@ protected static void setupKnnQueryDeserializer(ObjectDeserializer<Builder> op)
op.add(Builder::minScore, JsonpDeserializer.floatDeserializer(), "min_score");
op.add(Builder::maxDistance, JsonpDeserializer.floatDeserializer(), "max_distance");
op.add(Builder::filter, Query._DESERIALIZER, "filter");
op.add(Builder::methodParameters, JsonpDeserializer.stringMapDeserializer(JsonData._DESERIALIZER), "method_parameters");
op.add(Builder::rescore, KnnQueryRescore._DESERIALIZER, "rescore");

op.setKey(Builder::field, JsonpDeserializer.stringDeserializer());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.JsonpSerializable;
import org.opensearch.client.json.PlainJsonSerializable;
import org.opensearch.client.json.UnionDeserializer;
import org.opensearch.client.util.ApiTypeHelper;
import org.opensearch.client.util.ObjectBuilder;
import org.opensearch.client.util.ObjectBuilderBase;
import org.opensearch.client.util.TaggedUnion;
import org.opensearch.client.util.TaggedUnionUtils;

@JsonpDeserializable
public class KnnQueryRescore implements TaggedUnion<KnnQueryRescore.Kind, Object>, PlainJsonSerializable {
public enum Kind {
Enable,
Context
}

private final Kind _kind;
private final Object _value;

@Override
public Kind _kind() {
return _kind;
}

@Override
public Object _get() {
return _value;
}

private KnnQueryRescore(Kind kind, Object value) {
this._kind = kind;
this._value = value;
}

private KnnQueryRescore(Builder builder) {
this._kind = ApiTypeHelper.requireNonNull(builder._kind, builder, "<variant kind>");
this._value = ApiTypeHelper.requireNonNull(builder._value, builder, "<variant value>");
}

public static KnnQueryRescore of(Function<Builder, ObjectBuilder<KnnQueryRescore>> fn) {
return fn.apply(new Builder()).build();
}

public boolean isEnable() {
return _kind == Kind.Enable;
}

public Boolean enable() {
return TaggedUnionUtils.get(this, Kind.Enable);
}

public boolean isContext() {
return _kind == Kind.Context;
}

public RescoreContext context() {
return TaggedUnionUtils.get(this, Kind.Context);
}

@Override
public void serialize(JsonGenerator generator, JsonpMapper mapper) {
if (_value instanceof JsonpSerializable) {
((JsonpSerializable) _value).serialize(generator, mapper);
} else {
switch (_kind) {
case Enable:
generator.write((Boolean) _value);
break;
}
}
}

public static class Builder extends ObjectBuilderBase implements ObjectBuilder<KnnQueryRescore> {
private Kind _kind;
private Object _value;

public ObjectBuilder<KnnQueryRescore> enable(Boolean v) {
this._kind = Kind.Enable;
this._value = v;
return this;
}

public ObjectBuilder<KnnQueryRescore> context(RescoreContext v) {
this._kind = Kind.Context;
this._value = v;
return this;
}

@Override
public KnnQueryRescore build() {
_checkSingleUse();
return new KnnQueryRescore(this);
}
}

private static JsonpDeserializer<KnnQueryRescore> buildKnnQueryRescoreDeserializer() {
return new UnionDeserializer.Builder<KnnQueryRescore, Kind, Object>(KnnQueryRescore::new, false)
.addMember(Kind.Enable, JsonpDeserializer.booleanDeserializer())
.addMember(Kind.Context, RescoreContext._DESERIALIZER)
.build();
}

public static final JsonpDeserializer<KnnQueryRescore> _DESERIALIZER = JsonpDeserializer.lazy(KnnQueryRescore::buildKnnQueryRescoreDeserializer);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.client.opensearch._types.query_dsl;

import jakarta.json.stream.JsonGenerator;
import org.opensearch.client.json.JsonpDeserializable;
import org.opensearch.client.json.JsonpDeserializer;
import org.opensearch.client.json.JsonpMapper;
import org.opensearch.client.json.ObjectBuilderDeserializer;
import org.opensearch.client.json.ObjectDeserializer;
import org.opensearch.client.json.PlainJsonSerializable;
import org.opensearch.client.util.ObjectBuilder;
import org.opensearch.client.util.ObjectBuilderBase;

import javax.annotation.Nullable;
import java.util.function.Function;

@JsonpDeserializable
public class RescoreContext implements PlainJsonSerializable {
@Nullable
private final Float oversampleFactor;

private RescoreContext(Builder builder) {
this.oversampleFactor = builder.oversampleFactor;
}

public static RescoreContext of(Function<Builder, ObjectBuilder<RescoreContext>> fn) {
return fn.apply(new Builder()).build();
}

@Nullable
public Float oversampleFactor() {
return this.oversampleFactor;
}

@Override
public void serialize(JsonGenerator generator, JsonpMapper mapper) {
generator.writeStartObject();
serializeInternal(generator, mapper);
generator.writeEnd();
}

protected void serializeInternal(JsonGenerator generator, JsonpMapper mapper) {
if (this.oversampleFactor != null) {
generator.writeKey("oversample_factor");
generator.write(this.oversampleFactor);
}
}

public static class Builder extends ObjectBuilderBase implements ObjectBuilder<RescoreContext> {
@Nullable
private Float oversampleFactor;

public Builder oversampleFactor(@Nullable Float value) {
this.oversampleFactor = value;
return this;
}

@Override
public RescoreContext build() {
_checkSingleUse();
return new RescoreContext(this);
}
}

public static final JsonpDeserializer<RescoreContext> _DESERIALIZER = ObjectBuilderDeserializer.lazy(
Builder::new,
RescoreContext::setupRescoreContextDeserializer
);

protected static void setupRescoreContextDeserializer(ObjectDeserializer<Builder> op) {
op.add(Builder::oversampleFactor, JsonpDeserializer.floatDeserializer(), "oversample_factor");
}
}

0 comments on commit 25f0e41

Please sign in to comment.