From 3588d195335fa3dc06b002e5e468baa27e79f8fa Mon Sep 17 00:00:00 2001 From: twosom <72733442+twosom@users.noreply.github.com> Date: Fri, 28 Jun 2024 23:08:52 +0900 Subject: [PATCH] Add spark mapstate (#31669) * add isEmpty test in testMap * add map state in spark runner * update comment on SparkStateInternalsTest * modify isEmpty test to assertFalse / assertTrue --- .../beam/runners/core/StateInternalsTest.java | 6 + .../spark/stateful/SparkStateInternals.java | 148 +++++++++++++++++- .../stateful/SparkStateInternalsTest.java | 10 +- 3 files changed, 152 insertions(+), 12 deletions(-) diff --git a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java index a4cd504eee71..e15249969f26 100644 --- a/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java +++ b/runners/core-java/src/test/java/org/apache/beam/runners/core/StateInternalsTest.java @@ -386,10 +386,16 @@ public void testMap() throws Exception { value.entries().readLater().read(), containsInAnyOrder(MapEntry.of("B", 2), MapEntry.of("D", 4), MapEntry.of("E", 5))); + // isEmpty + assertFalse(value.isEmpty().read()); + // clear value.clear(); assertThat(value.entries().read(), Matchers.emptyIterable()); assertThat(underTest.state(NAMESPACE_1, STRING_MAP_ADDR), equalTo(value)); + + // isEmpty + assertTrue(value.isEmpty().read()); } @Test diff --git a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java index 3ad955e78aa7..731cadb89f0c 100644 --- a/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java +++ b/runners/spark/src/main/java/org/apache/beam/runners/spark/stateful/SparkStateInternals.java @@ -19,7 +19,11 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.function.Function; import org.apache.beam.runners.core.StateInternals; import org.apache.beam.runners.core.StateNamespace; import org.apache.beam.runners.core.StateTag; @@ -28,12 +32,14 @@ import org.apache.beam.sdk.coders.Coder; import org.apache.beam.sdk.coders.InstantCoder; import org.apache.beam.sdk.coders.ListCoder; +import org.apache.beam.sdk.coders.MapCoder; import org.apache.beam.sdk.state.BagState; import org.apache.beam.sdk.state.CombiningState; import org.apache.beam.sdk.state.MapState; import org.apache.beam.sdk.state.MultimapState; import org.apache.beam.sdk.state.OrderedListState; import org.apache.beam.sdk.state.ReadableState; +import org.apache.beam.sdk.state.ReadableStates; import org.apache.beam.sdk.state.SetState; import org.apache.beam.sdk.state.State; import org.apache.beam.sdk.state.StateContext; @@ -44,6 +50,7 @@ import org.apache.beam.sdk.transforms.windowing.TimestampCombiner; import org.apache.beam.sdk.util.CombineFnUtil; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashBasedTable; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Table; import org.checkerframework.checker.nullness.qual.Nullable; import org.joda.time.Instant; @@ -119,11 +126,10 @@ public SetState bindSet(StateTag> spec, Coder elemCoder) { @Override public MapState bindMap( - StateTag> spec, + StateTag> address, Coder mapKeyCoder, Coder mapValueCoder) { - throw new UnsupportedOperationException( - String.format("%s is not supported", MapState.class.getSimpleName())); + return new SparkMapState<>(namespace, address, MapCoder.of(mapKeyCoder, mapValueCoder)); } @Override @@ -359,6 +365,142 @@ public AccumT mergeAccumulators(Iterable accumulators) { } } + private final class SparkMapState + extends AbstractState> implements MapState { + + private SparkMapState( + StateNamespace namespace, + StateTag address, + Coder> coder) { + super(namespace, address, coder); + } + + @Override + public ReadableState get(MapKeyT key) { + return getOrDefault(key, null); + } + + @Override + public ReadableState getOrDefault(MapKeyT key, @Nullable MapValueT defaultValue) { + return new ReadableState() { + @Override + public MapValueT read() { + Map sparkMapState = readValue(); + if (sparkMapState == null) { + return defaultValue; + } + return sparkMapState.getOrDefault(key, defaultValue); + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + + @Override + public void put(MapKeyT key, MapValueT value) { + Map sparkMapState = readValue(); + if (sparkMapState == null) { + sparkMapState = new HashMap<>(); + } + sparkMapState.put(key, value); + writeValue(sparkMapState); + } + + @Override + public ReadableState computeIfAbsent( + MapKeyT key, Function mappingFunction) { + Map sparkMapState = readValue(); + MapValueT current = sparkMapState.get(key); + if (current == null) { + put(key, mappingFunction.apply(key)); + } + return ReadableStates.immediate(current); + } + + @Override + public void remove(MapKeyT key) { + Map sparkMapState = readValue(); + sparkMapState.remove(key); + writeValue(sparkMapState); + } + + @Override + public ReadableState> keys() { + return new ReadableState>() { + @Override + public Iterable read() { + Map sparkMapState = readValue(); + if (sparkMapState == null) { + return Collections.emptyList(); + } + return sparkMapState.keySet(); + } + + @Override + public ReadableState> readLater() { + return this; + } + }; + } + + @Override + public ReadableState> values() { + return new ReadableState>() { + @Override + public Iterable read() { + Map sparkMapState = readValue(); + if (sparkMapState == null) { + return Collections.emptyList(); + } + Iterable result = readValue().values(); + return result != null ? ImmutableList.copyOf(result) : Collections.emptyList(); + } + + @Override + public ReadableState> readLater() { + return this; + } + }; + } + + @Override + public ReadableState>> entries() { + return new ReadableState>>() { + @Override + public Iterable> read() { + Map sparkMapState = readValue(); + if (sparkMapState == null) { + return Collections.emptyList(); + } + return sparkMapState.entrySet(); + } + + @Override + public ReadableState>> readLater() { + return this; + } + }; + } + + @Override + public ReadableState isEmpty() { + return new ReadableState() { + @Override + public Boolean read() { + return stateTable.get(namespace.stringKey(), address.getId()) == null; + } + + @Override + public ReadableState readLater() { + return this; + } + }; + } + } + private final class SparkBagState extends AbstractState> implements BagState { private SparkBagState(StateNamespace namespace, StateTag> address, Coder coder) { super(namespace, address, ListCoder.of(coder)); diff --git a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java index 6118f96fcc33..f6f2b8d6df6f 100644 --- a/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java +++ b/runners/spark/src/test/java/org/apache/beam/runners/spark/stateful/SparkStateInternalsTest.java @@ -25,7 +25,7 @@ /** * Tests for {@link SparkStateInternals}. This is based on {@link StateInternalsTest}. Ignore set - * and map tests. + * tests. */ @RunWith(JUnit4.class) public class SparkStateInternalsTest extends StateInternalsTest { @@ -51,15 +51,7 @@ public void testMergeSetIntoSource() {} @Ignore public void testMergeSetIntoNewNamespace() {} - @Override - @Ignore - public void testMap() {} - @Override @Ignore public void testSetReadable() {} - - @Override - @Ignore - public void testMapReadable() {} }