From 7ea8cd2608e1b7550ffca44aff5596dd43fd4aa6 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Tue, 4 Jun 2024 19:02:00 -0400 Subject: [PATCH] Simplify Managed API to avoid dealing with PCollectionRowTuple (#31470) * Managed accepts PInput type * add unit test * spotless * spotless * rename to getSinglePCollection --- .../beam/sdk/values/PCollectionRowTuple.java | 17 +++++++ .../beam/sdk/io/iceberg/IcebergIOIT.java | 10 ++-- ...cebergReadSchemaTransformProviderTest.java | 4 +- ...ebergWriteSchemaTransformProviderTest.java | 14 ++---- .../org/apache/beam/sdk/io/kafka/KafkaIO.java | 2 +- .../KafkaReadSchemaTransformProviderTest.java | 4 +- ...KafkaWriteSchemaTransformProviderTest.java | 7 +-- .../org/apache/beam/sdk/managed/Managed.java | 46 +++++++++++++++---- .../managed/ManagedTransformConstants.java | 3 ++ ...ManagedSchemaTransformTranslationTest.java | 3 +- .../apache/beam/sdk/managed/ManagedTest.java | 34 ++++++++++++-- 11 files changed, 104 insertions(+), 40 deletions(-) diff --git a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java index 0e7c52c4ae7d..a2a3aa74e539 100644 --- a/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java +++ b/sdks/java/core/src/main/java/org/apache/beam/sdk/values/PCollectionRowTuple.java @@ -23,6 +23,7 @@ import java.util.Objects; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.checkerframework.checker.nullness.qual.Nullable; @@ -180,6 +181,22 @@ public PCollection get(String tag) { return pcollection; } + /** + * Like {@link #get(String)}, but is a convenience method to get a single PCollection without + * providing a tag for that output. Use only when there is a single collection in this tuple. + * + *

Throws {@link IllegalStateException} if more than one output exists in the {@link + * PCollectionRowTuple}. + */ + public PCollection getSinglePCollection() { + Preconditions.checkState( + pcollectionMap.size() == 1, + "Expected exactly one output PCollection, but found %s. " + + "Please try retrieving a specified output using get() instead.", + pcollectionMap.size()); + return get(pcollectionMap.entrySet().iterator().next().getKey()); + } + /** * Returns an immutable Map from tag to corresponding {@link PCollection}, for all the members of * this {@link PCollectionRowTuple}. diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java index 06a63909c121..467a2cbaf242 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergIOIT.java @@ -38,7 +38,6 @@ import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.apache.hadoop.conf.Configuration; @@ -216,11 +215,10 @@ public void testRead() throws Exception { .build()) .build(); - PCollectionRowTuple output = - PCollectionRowTuple.empty(readPipeline) - .apply(Managed.read(Managed.ICEBERG).withConfig(config)); + PCollection rows = + readPipeline.apply(Managed.read(Managed.ICEBERG).withConfig(config)).getSinglePCollection(); - PAssert.that(output.get("output")).containsInAnyOrder(expectedRows); + PAssert.that(rows).containsInAnyOrder(expectedRows); readPipeline.run().waitUntilFinish(); } @@ -258,7 +256,7 @@ public void testWrite() { .build(); PCollection input = writePipeline.apply(Create.of(inputRows)).setRowSchema(BEAM_SCHEMA); - PCollectionRowTuple.of("input", input).apply(Managed.write(Managed.ICEBERG).withConfig(config)); + input.apply(Managed.write(Managed.ICEBERG).withConfig(config)); writePipeline.run().waitUntilFinish(); diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java index 27a31f31830d..46168a487dda 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java +++ b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergReadSchemaTransformProviderTest.java @@ -166,9 +166,9 @@ public void testReadUsingManagedTransform() throws Exception { Map configMap = new Yaml().load(yamlConfig); PCollection output = - PCollectionRowTuple.empty(testPipeline) + testPipeline .apply(Managed.read(Managed.ICEBERG).withConfig(configMap)) - .get(OUTPUT_TAG); + .getSinglePCollection(); PAssert.that(output) .satisfies( diff --git a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java index 97aebd5c41f3..9ef3e9945ec9 100644 --- a/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java +++ 
b/sdks/java/io/iceberg/src/test/java/org/apache/beam/sdk/io/iceberg/IcebergWriteSchemaTransformProviderTest.java @@ -134,16 +134,12 @@ public void testWriteUsingManagedTransform() { identifier, CatalogUtil.ICEBERG_CATALOG_TYPE_HADOOP, warehouse.location); Map configMap = new Yaml().load(yamlConfig); - PCollectionRowTuple input = - PCollectionRowTuple.of( - INPUT_TAG, - testPipeline - .apply( - "Records To Add", Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1))) - .setRowSchema( - SchemaAndRowConversions.icebergSchemaToBeamSchema(TestFixtures.SCHEMA))); + PCollection inputRows = + testPipeline + .apply("Records To Add", Create.of(TestFixtures.asRows(TestFixtures.FILE1SNAPSHOT1))) + .setRowSchema(SchemaAndRowConversions.icebergSchemaToBeamSchema(TestFixtures.SCHEMA)); PCollection result = - input.apply(Managed.write(Managed.ICEBERG).withConfig(configMap)).get(OUTPUT_TAG); + inputRows.apply(Managed.write(Managed.ICEBERG).withConfig(configMap)).get(OUTPUT_TAG); PAssert.that(result).satisfies(new VerifyOutputs(identifier, "append")); diff --git a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java index e897ed439cd1..8f995a63a10f 100644 --- a/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java +++ b/sdks/java/io/kafka/src/main/java/org/apache/beam/sdk/io/kafka/KafkaIO.java @@ -2663,7 +2663,7 @@ abstract static class Builder { abstract Builder setProducerConfig(Map producerConfig); abstract Builder setProducerFactoryFn( - SerializableFunction, Producer> fn); + @Nullable SerializableFunction, Producer> fn); abstract Builder setKeySerializer(Class> serializer); diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java index f5ac5bb54ad7..dfe062e1eef4 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaReadSchemaTransformProviderTest.java @@ -36,7 +36,7 @@ import org.apache.beam.sdk.managed.ManagedTransformConstants; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.utils.YamlUtils; -import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.io.ByteStreams; @@ -319,7 +319,7 @@ public void testBuildTransformWithManaged() { // Kafka Read SchemaTransform gets built in ManagedSchemaTransformProvider's expand Managed.read(Managed.KAFKA) .withConfig(YamlUtils.yamlStringToMap(config)) - .expand(PCollectionRowTuple.empty(Pipeline.create())); + .expand(PBegin.in(Pipeline.create())); } } diff --git a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java index 60bff89b3555..f19e91d89261 100644 --- a/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java +++ b/sdks/java/io/kafka/src/test/java/org/apache/beam/sdk/io/kafka/KafkaWriteSchemaTransformProviderTest.java 
@@ -43,7 +43,6 @@ import org.apache.beam.sdk.transforms.SerializableFunction; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.PCollectionTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.sdk.values.TupleTag; @@ -225,10 +224,8 @@ public void testBuildTransformWithManaged() { Managed.write(Managed.KAFKA) .withConfig(YamlUtils.yamlStringToMap(config)) .expand( - PCollectionRowTuple.of( - "input", - Pipeline.create() - .apply(Create.empty(Schema.builder().addByteArrayField("bytes").build())))); + Pipeline.create() + .apply(Create.empty(Schema.builder().addByteArrayField("bytes").build()))); } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java index da4a0853fb39..6f95290e6ee6 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/Managed.java @@ -22,11 +22,16 @@ import java.util.List; import java.util.Map; import javax.annotation.Nullable; +import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.schemas.transforms.SchemaTransform; import org.apache.beam.sdk.schemas.transforms.SchemaTransformProvider; import org.apache.beam.sdk.schemas.utils.YamlUtils; import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.values.PBegin; +import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.PInput; +import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.annotations.VisibleForTesting; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; @@ -47,12 +52,13 @@ * specifies arguments using like so: * *
  * <pre>{@code
- * PCollectionRowTuple output = PCollectionRowTuple.empty(pipeline).apply(
+ * PCollection<Row> rows = pipeline.apply(
  *       Managed.read(ICEBERG)
  *           .withConfig(ImmutableMap.<String, Object>builder()
  *               .put("foo", "abc")
  *               .put("bar", 123)
- *               .build()));
+ *               .build()))
+ *       .getSinglePCollection();
  * }</pre>
  *
  * <p>Instead of specifying configuration arguments directly in the code, one can provide the
@@ -66,11 +72,9 @@
  * <p>The file's path can be passed in to the Managed API like so:
  *
  * <pre>{@code
- * PCollectionRowTuple input = PCollectionRowTuple.of("input", pipeline.apply(Create.of(...)))
+ * PCollection<Row> inputRows = pipeline.apply(Create.of(...));
  *
- * PCollectionRowTuple output = input.apply(
- *     Managed.write(ICEBERG)
- *         .withConfigUrl());
+ * inputRows.apply(Managed.write(ICEBERG).withConfigUrl());
  * }</pre>
*/ public class Managed { @@ -132,8 +136,7 @@ public static ManagedTransform write(String sink) { } @AutoValue - public abstract static class ManagedTransform - extends PTransform { + public abstract static class ManagedTransform extends PTransform { abstract String getIdentifier(); abstract @Nullable Map getConfig(); @@ -183,7 +186,9 @@ ManagedTransform withSupportedIdentifiers(List supportedIdentifiers) { } @Override - public PCollectionRowTuple expand(PCollectionRowTuple input) { + public PCollectionRowTuple expand(PInput input) { + PCollectionRowTuple inputTuple = resolveInput(input); + ManagedSchemaTransformProvider.ManagedConfig managedConfig = ManagedSchemaTransformProvider.ManagedConfig.builder() .setTransformIdentifier(getIdentifier()) @@ -194,7 +199,28 @@ public PCollectionRowTuple expand(PCollectionRowTuple input) { SchemaTransform underlyingTransform = new ManagedSchemaTransformProvider(getSupportedIdentifiers()).from(managedConfig); - return input.apply(underlyingTransform); + return inputTuple.apply(underlyingTransform); + } + + @VisibleForTesting + static PCollectionRowTuple resolveInput(PInput input) { + if (input instanceof PBegin) { + return PCollectionRowTuple.empty(input.getPipeline()); + } else if (input instanceof PCollection) { + PCollection inputCollection = (PCollection) input; + Preconditions.checkArgument( + inputCollection.getCoder() instanceof RowCoder, + "Input PCollection must contain Row elements with a set Schema " + + "(using .setRowSchema()). Instead, found collection %s with coder: %s.", + inputCollection.getName(), + inputCollection.getCoder()); + return PCollectionRowTuple.of( + ManagedTransformConstants.INPUT, (PCollection) inputCollection); + } else if (input instanceof PCollectionRowTuple) { + return (PCollectionRowTuple) input; + } + + throw new IllegalArgumentException("Unsupported input type: " + input.getClass()); } } } diff --git a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java index 141544305a38..51d0b67b4b89 100644 --- a/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java +++ b/sdks/java/managed/src/main/java/org/apache/beam/sdk/managed/ManagedTransformConstants.java @@ -38,6 +38,9 @@ * every single parameter through the Managed interface. 
*/ public class ManagedTransformConstants { + // Standard input PCollection tag + public static final String INPUT = "input"; + public static final String ICEBERG_READ = "beam:schematransform:org.apache.beam:iceberg_read:v1"; public static final String ICEBERG_WRITE = "beam:schematransform:org.apache.beam:iceberg_write:v1"; diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java index f7769a9e1d19..0d122646d899 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedSchemaTransformTranslationTest.java @@ -50,7 +50,6 @@ import org.apache.beam.sdk.util.construction.BeamUrns; import org.apache.beam.sdk.util.construction.PipelineTranslation; import org.apache.beam.sdk.values.PCollection; -import org.apache.beam.sdk.values.PCollectionRowTuple; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.ByteString; import org.apache.beam.vendor.grpc.v1p60p1.com.google.protobuf.InvalidProtocolBufferException; @@ -141,7 +140,7 @@ public void testProtoTranslation() throws Exception { .setIdentifier(TestSchemaTransformProvider.IDENTIFIER) .build() .withConfig(underlyingConfig); - PCollectionRowTuple.of("input", input).apply(transform).get("output"); + input.apply(transform); // Then translate the pipeline to a proto and extract the ManagedSchemaTransform's proto RunnerApi.Pipeline pipelineProto = PipelineTranslation.toProto(p); diff --git a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java index 7ed364d0e174..249faffec567 100644 --- a/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java +++ b/sdks/java/managed/src/test/java/org/apache/beam/sdk/managed/ManagedTest.java @@ -17,17 +17,23 @@ */ package org.apache.beam.sdk.managed; +import static org.junit.Assert.assertThrows; + import java.nio.file.Paths; import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; +import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.managed.testing.TestSchemaTransformProvider; import org.apache.beam.sdk.schemas.Schema; import org.apache.beam.sdk.testing.PAssert; import org.apache.beam.sdk.testing.TestPipeline; import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PBegin; import org.apache.beam.sdk.values.PCollection; import org.apache.beam.sdk.values.PCollectionRowTuple; +import org.apache.beam.sdk.values.PCollectionTuple; +import org.apache.beam.sdk.values.PInput; import org.apache.beam.sdk.values.Row; import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap; import org.junit.Rule; @@ -61,11 +67,33 @@ public void testInvalidTransform() { Row.withSchema(SCHEMA).withFieldValue("str", "b").withFieldValue("int", 2).build(), Row.withSchema(SCHEMA).withFieldValue("str", "c").withFieldValue("int", 3).build()); + @Test + public void testResolveInputToPCollectionRowTuple() { + Pipeline p = Pipeline.create(); + List inputTypes = + Arrays.asList( + PBegin.in(p), + p.apply(Create.of(ROWS).withRowSchema(SCHEMA)), + PCollectionRowTuple.of("pcoll", p.apply(Create.of(ROWS).withRowSchema(SCHEMA)))); + + List badInputTypes = + Arrays.asList( + p.apply(Create.of(1, 2, 3)), + 
p.apply(Create.of(ROWS)), + PCollectionTuple.of("pcoll", p.apply(Create.of(ROWS)))); + + for (PInput input : inputTypes) { + Managed.ManagedTransform.resolveInput(input); + } + for (PInput badInput : badInputTypes) { + assertThrows( + IllegalArgumentException.class, () -> Managed.ManagedTransform.resolveInput(badInput)); + } + } + public void runTestProviderTest(Managed.ManagedTransform writeOp) { PCollection rows = - PCollectionRowTuple.of("input", pipeline.apply(Create.of(ROWS)).setRowSchema(SCHEMA)) - .apply(writeOp) - .get("output"); + pipeline.apply(Create.of(ROWS)).setRowSchema(SCHEMA).apply(writeOp).getSinglePCollection(); Schema outputSchema = rows.getSchema(); PAssert.that(rows)
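
For illustration only (not part of the patch): a minimal sketch of how a pipeline uses the simplified Managed API this change introduces. The class name ManagedApiSketch, the "foo"/"bar" config keys (mirroring the placeholder keys in the Javadoc above), and the single-field schema are hypothetical; a real Iceberg read or write needs a valid catalog and table configuration.

// Hypothetical illustration of the simplified Managed API (not part of this patch).
import java.util.Map;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.managed.Managed;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.Row;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableMap;

public class ManagedApiSketch {
  public static void main(String[] args) {
    Pipeline pipeline = Pipeline.create();

    // Placeholder config, as in the Javadoc example above; a real Iceberg read/write
    // needs valid catalog and table settings.
    Map<String, Object> config =
        ImmutableMap.<String, Object>builder().put("foo", "abc").put("bar", 123).build();

    // Read: apply Managed directly to the pipeline, then unwrap the single output
    // PCollection<Row> with the new PCollectionRowTuple#getSinglePCollection().
    PCollection<Row> rows =
        pipeline
            .apply(Managed.read(Managed.ICEBERG).withConfig(config))
            .getSinglePCollection();

    // Write: apply Managed directly to a PCollection<Row>; ManagedTransform#resolveInput
    // wraps it in a PCollectionRowTuple under the standard "input" tag.
    Schema schema = Schema.builder().addStringField("str").build(); // hypothetical schema
    PCollection<Row> inputRows =
        pipeline.apply(
            Create.of(Row.withSchema(schema).withFieldValue("str", "a").build())
                .withRowSchema(schema));
    inputRows.apply(Managed.write(Managed.ICEBERG).withConfig(config));

    pipeline.run().waitUntilFinish();
  }
}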