diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java index c619b563b0..39402cf236 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java +++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAG.java @@ -38,6 +38,7 @@ import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty; import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty; +import org.apache.nemo.common.ir.vertex.utility.IntermediateAccumulatorVertex; import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex; import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex; import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex; @@ -799,6 +800,39 @@ public void insert(final TaskSizeSplitterVertex toInsert) { modifiedDAG = builder.build(); } + public void insert(final IntermediateAccumulatorVertex accumulatorVertex, final IREdge targetEdge) { + // Create a completely new DAG with the vertex inserted. + final DAGBuilder builder = new DAGBuilder<>(); + + builder.addVertex(accumulatorVertex); + modifiedDAG.topologicalDo(v -> { + builder.addVertex(v); + + modifiedDAG.getIncomingEdgesOf(v).forEach(e -> { + if (e == targetEdge) { + // Edge to the accumulatorVertex + final IREdge toAV = new IREdge(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE, + e.getSrc(), accumulatorVertex); + e.copyExecutionPropertiesTo(toAV); + toAV.setProperty(CommunicationPatternProperty.of(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE)); + + // Edge from the accumulatorVertex + final IREdge fromAV = new IREdge(CommunicationPatternProperty.Value.SHUFFLE, accumulatorVertex, e.getDst()); + e.copyExecutionPropertiesTo(fromAV); + + // Connect the new edges + builder.connectVertices(toAV); + builder.connectVertices(fromAV); + } else { + // Simply connect vertices as before + builder.connectVertices(e); + } + }); + }); + + modifiedDAG = builder.build(); + } + /** * Reshape unsafely, without guarantees on preserving application semantics. * TODO #330: Refactor Unsafe Reshaping Passes diff --git a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java index ae8a8b3889..639bd5c4d6 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java +++ b/common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java @@ -35,6 +35,7 @@ import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.ir.vertex.executionproperty.*; import org.apache.nemo.common.ir.vertex.transform.SignalTransform; +import org.apache.nemo.common.ir.vertex.utility.IntermediateAccumulatorVertex; import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex; import org.apache.nemo.common.ir.vertex.utility.RelayVertex; import org.slf4j.Logger; @@ -79,6 +80,7 @@ private IRDAGChecker() { addLoopVertexCheckers(); addScheduleGroupCheckers(); addCacheCheckers(); + addIntermediateAccumulatorVertexCheckers(); } /** @@ -284,23 +286,25 @@ void addShuffleEdgeCheckers() { final NeighborChecker shuffleChecker = ((v, inEdges, outEdges) -> { for (final IREdge inEdge : inEdges) { if (CommunicationPatternProperty.Value.SHUFFLE + .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get()) + || CommunicationPatternProperty.Value.PARTIAL_SHUFFLE .equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())) { // Shuffle edges must have the following properties if (!inEdge.getPropertyValue(KeyExtractorProperty.class).isPresent() || !inEdge.getPropertyValue(KeyEncoderProperty.class).isPresent() || !inEdge.getPropertyValue(KeyDecoderProperty.class).isPresent()) { - return failure("Shuffle edge does not have a Key-related property: " + inEdge.getId()); + return failure("(Partial)Shuffle edge does not have a Key-related property: " + inEdge.getId()); } } else { // Non-shuffle edges must not have the following properties final Optional> partitioner = inEdge.getPropertyValue(PartitionerProperty.class); if (partitioner.isPresent() && partitioner.get().left().equals(PartitionerProperty.Type.HASH)) { - return failure("Only shuffle can have the hash partitioner", + return failure("Only (partial)shuffle can have the hash partitioner", inEdge, CommunicationPatternProperty.class, PartitionerProperty.class); } if (inEdge.getPropertyValue(PartitionSetProperty.class).isPresent()) { - return failure("Only shuffle can select partition sets", + return failure("Only (partial)shuffle can select partition sets", inEdge, CommunicationPatternProperty.class, PartitionSetProperty.class); } } @@ -486,6 +490,34 @@ void addEncodingCompressionCheckers() { singleEdgeCheckerList.add(compressAndDecompress); } + void addIntermediateAccumulatorVertexCheckers() { + final NeighborChecker shuffleExecutorSet = ((v, inEdges, outEdges) -> { + if (v instanceof IntermediateAccumulatorVertex) { + if (inEdges.size() != 1 || outEdges.size() != 1) { + return failure("Intermediate accumulator vertex must have only one in/out edge.", v); + } else if (inEdges.stream().anyMatch(e -> + !e.getPropertyValue(CommunicationPatternProperty.class).get() + .equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE))) { + return failure("Intermediate accumulator vertex must have partial shuffle inEdge.", v); + } else if (!v.getPropertyValue(ShuffleExecutorSetProperty.class).isPresent()) { + return failure("Intermediate accumulator vertex must have shuffle executor set property.", v); + } else if (v.getPropertyValue(ParallelismProperty.class).get() + < v.getPropertyValue(ShuffleExecutorSetProperty.class).get().size()) { + return failure("Parallelism must be greater or equal to the number of shuffle executor set", v); + } + } else { + if (inEdges.stream().anyMatch(e -> e.getPropertyValue(CommunicationPatternProperty.class).get() + .equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE))) { + return failure("Only intermediate accumulator vertex can have partial shuffle inEdge.", v); + } else if (v.getPropertyValue(ShuffleExecutorSetProperty.class).isPresent()) { + return failure("Only intermediate accumulator vertex can have shuffle executor set property.", v); + } + } + return success(); + }); + neighborCheckerList.add(shuffleExecutorSet); + } + /** * Group outgoing edges by the additional output tag property. * @param outEdges the outedges to group. diff --git a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java index 23909b43bd..accb8ac054 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java +++ b/common/src/main/java/org/apache/nemo/common/ir/edge/executionproperty/CommunicationPatternProperty.java @@ -52,6 +52,7 @@ public static CommunicationPatternProperty of(final Value value) { public enum Value { ONE_TO_ONE, BROADCAST, - SHUFFLE + SHUFFLE, + PARTIAL_SHUFFLE } } diff --git a/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java b/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java index d0ef23549d..19f52432a1 100644 --- a/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java +++ b/common/src/main/java/org/apache/nemo/common/ir/executionproperty/ExecutionPropertyMap.java @@ -73,6 +73,7 @@ public static ExecutionPropertyMap of( map.put(EncoderProperty.of(EncoderFactory.DUMMY_ENCODER_FACTORY)); map.put(DecoderProperty.of(DecoderFactory.DUMMY_DECODER_FACTORY)); switch (commPattern) { + case PARTIAL_SHUFFLE: case SHUFFLE: map.put(DataFlowProperty.of(DataFlowProperty.Value.PULL)); map.put(PartitionerProperty.of(PartitionerProperty.Type.HASH)); diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java new file mode 100644 index 0000000000..aed688ad07 --- /dev/null +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/executionproperty/ShuffleExecutorSetProperty.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.common.ir.vertex.executionproperty; + +import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty; + +import java.util.ArrayList; +import java.util.Set; + +/** + * List of set of node names to limit the scheduling of the tasks of the vertex to while shuffling. + * For example, [[1, 2, 3], [4, 5, 6]] limits shuffle to occur just within nodes 1, 2, 3 and nodes 4, 5, 6 separately. + * This occurs by limiting the source executors where the tasks read their input data from, depending on + * where the task is located at. + * ShuffleExecutorSetProperty is set only for the IntermediateAccumulatorVertex, and + * other vertices should not have this property. + */ +public final class ShuffleExecutorSetProperty extends VertexExecutionProperty>> { + + /** + * Default constructor. + * @param value value of the execution property. + */ + private ShuffleExecutorSetProperty(final ArrayList> value) { + super(value); + } + + /** + * Static method for constructing {@link ShuffleExecutorSetProperty}. + * + * @param setsOfExecutors the list of executors to schedule the tasks of the vertex on. + * Leave empty to make it effectless. + * @return the new execution property + */ + public static ShuffleExecutorSetProperty of(final Set> setsOfExecutors) { + return new ShuffleExecutorSetProperty(new ArrayList<>(setsOfExecutors)); + } +} diff --git a/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/IntermediateAccumulatorVertex.java b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/IntermediateAccumulatorVertex.java new file mode 100644 index 0000000000..bf1ec2142c --- /dev/null +++ b/common/src/main/java/org/apache/nemo/common/ir/vertex/utility/IntermediateAccumulatorVertex.java @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.common.ir.vertex.utility; + +import org.apache.nemo.common.ir.vertex.OperatorVertex; +import org.apache.nemo.common.ir.vertex.transform.Transform; + +/** + * During combine transform, accumulates data among physically nearby containers prior to shuffling across WAN. + */ +public final class IntermediateAccumulatorVertex extends OperatorVertex { + /** + * Constructor. + */ + public IntermediateAccumulatorVertex(final Transform t) { + super(t); + } +} diff --git a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java index 6113c85266..2e28ab93e7 100644 --- a/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java +++ b/common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java @@ -31,6 +31,7 @@ import org.apache.nemo.common.ir.vertex.OperatorVertex; import org.apache.nemo.common.ir.vertex.SourceVertex; import org.apache.nemo.common.ir.vertex.executionproperty.*; +import org.apache.nemo.common.ir.vertex.utility.IntermediateAccumulatorVertex; import org.apache.nemo.common.ir.vertex.utility.TaskSizeSplitterVertex; import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageAggregatorVertex; import org.apache.nemo.common.ir.vertex.utility.runtimepass.MessageGeneratorVertex; @@ -327,6 +328,16 @@ public void testSplitterVertex() { mustPass(); } + @Test + public void testAccumulatorVertex() { + final IntermediateAccumulatorVertex cv = + new IntermediateAccumulatorVertex(new EmptyComponents.EmptyTransform("iav")); + cv.setProperty(ShuffleExecutorSetProperty.of(new HashSet<>())); + cv.setProperty(ParallelismProperty.of(5)); + irdag.insert(cv, shuffleEdge); + mustPass(); + } + private MessageAggregatorVertex insertNewTriggerVertex(final IRDAG dag, final IREdge edgeToGetStatisticsOf) { final MessageGeneratorVertex mb = new MessageGeneratorVertex<>((l, r) -> null); final MessageAggregatorVertex ma = new MessageAggregatorVertex<>(() -> new Object(), (l, r) -> null); diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java index 28aafb2c90..3ba6f0e7fe 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslationContext.java @@ -277,7 +277,7 @@ private CommunicationPatternProperty.Value getCommPattern(final IRVertex src, fi } // If GBKTransform represents a partial CombinePerKey transformation, we do NOT need to shuffle its input, // since its output will be shuffled before going through a final CombinePerKey transformation. - if ((dstTransform instanceof GBKTransform && !((GBKTransform) dstTransform).getIsPartialCombining()) + if ((dstTransform instanceof CombineTransform && !((CombineTransform) dstTransform).getIsPartialCombining()) || dstTransform instanceof GroupByKeyTransform) { return CommunicationPatternProperty.Value.SHUFFLE; } diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java index cd9d7ad223..3794c3687f 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/PipelineTranslator.java @@ -358,10 +358,12 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato } final CombineFnBase.GlobalCombineFn combineFn = perKey.getFn(); + final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); + final PCollection mainInput = (PCollection) Iterables.getOnlyElement( - TransformInputs.nonAdditionalInputs(beamNode.toAppliedPTransform(ctx.getPipeline()))); + TransformInputs.nonAdditionalInputs(pTransform)); final PCollection inputs = (PCollection) Iterables.getOnlyElement( - TransformInputs.nonAdditionalInputs(beamNode.toAppliedPTransform(ctx.getPipeline()))); + TransformInputs.nonAdditionalInputs(pTransform)); final KvCoder inputCoder = (KvCoder) inputs.getCoder(); final Coder accumulatorCoder; @@ -386,48 +388,52 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato finalCombine = new OperatorVertex(new CombineFnFinalTransform<>(combineFn)); } else { // Stream data processing, using GBKTransform - final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline()); final CombineFnBase.GlobalCombineFn partialCombineFn = new PartialCombineFn( (Combine.CombineFn) combineFn, accumulatorCoder); + final CombineFnBase.GlobalCombineFn intermediateCombineFn = new IntermediateCombineFn( + (Combine.CombineFn) combineFn, accumulatorCoder); final CombineFnBase.GlobalCombineFn finalCombineFn = new FinalCombineFn( (Combine.CombineFn) combineFn, accumulatorCoder); + final SystemReduceFn partialSystemReduceFn = SystemReduceFn.combining( inputCoder.getKeyCoder(), AppliedCombineFn.withInputCoder(partialCombineFn, - ctx.getPipeline().getCoderRegistry(), inputCoder, - null, - mainInput.getWindowingStrategy())); + ctx.getPipeline().getCoderRegistry(), + inputCoder, + null, mainInput.getWindowingStrategy())); + final SystemReduceFn intermediateSystemReduceFn = + SystemReduceFn.combining( + inputCoder.getKeyCoder(), + AppliedCombineFn.withInputCoder(intermediateCombineFn, + ctx.getPipeline().getCoderRegistry(), + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), + null, mainInput.getWindowingStrategy())); final SystemReduceFn finalSystemReduceFn = SystemReduceFn.combining( inputCoder.getKeyCoder(), AppliedCombineFn.withInputCoder(finalCombineFn, ctx.getPipeline().getCoderRegistry(), - KvCoder.of(inputCoder.getKeyCoder(), - accumulatorCoder), + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), null, mainInput.getWindowingStrategy())); final TupleTag partialMainOutputTag = new TupleTag<>(); - final GBKTransform partialCombineStreamTransform = - new GBKTransform(inputCoder, - Collections.singletonMap(partialMainOutputTag, KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)), - partialMainOutputTag, - mainInput.getWindowingStrategy(), - ctx.getPipelineOptions(), - partialSystemReduceFn, - DoFnSchemaInformation.create(), - DisplayData.from(beamNode.getTransform()), - true); - final GBKTransform finalCombineStreamTransform = - new GBKTransform(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), + final CombineTransformFactory combineTransformFactory = + new CombineTransformFactory(inputCoder, + partialMainOutputTag, + KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder), getOutputCoders(pTransform), Iterables.getOnlyElement(beamNode.getOutputs().keySet()), mainInput.getWindowingStrategy(), ctx.getPipelineOptions(), + partialSystemReduceFn, + intermediateSystemReduceFn, finalSystemReduceFn, DoFnSchemaInformation.create(), - DisplayData.from(beamNode.getTransform()), - false); + DisplayData.from(beamNode.getTransform())); + + final CombineTransform partialCombineStreamTransform = combineTransformFactory.getPartialCombineTransform(); + final CombineTransform finalCombineStreamTransform = combineTransformFactory.getFinalCombineTransform(); partialCombine = new OperatorVertex(partialCombineStreamTransform); finalCombine = new OperatorVertex(finalCombineStreamTransform); @@ -564,7 +570,7 @@ private static Transform createGBKTransform( return new GroupByKeyTransform(); } else { // GroupByKey Transform when using a non-global windowing strategy. - return new GBKTransform<>( + return new CombineTransform<>( (KvCoder) mainInput.getCoder(), getOutputCoders(pTransform), mainOutputTag, diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransform.java similarity index 86% rename from compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java rename to compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransform.java index 9194edbbcd..61a4a27eeb 100644 --- a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransform.java +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransform.java @@ -46,9 +46,9 @@ * @param input type * @param output type */ -public final class GBKTransform +public final class CombineTransform extends AbstractDoFnTransform, KeyedWorkItem, KV> { - private static final Logger LOG = LoggerFactory.getLogger(GBKTransform.class.getName()); + private static final Logger LOG = LoggerFactory.getLogger(CombineTransform.class.getName()); private final SystemReduceFn reduceFn; private transient InMemoryTimerInternalsFactory inMemoryTimerInternalsFactory; private transient InMemoryStateInternalsFactory inMemoryStateInternalsFactory; @@ -58,16 +58,31 @@ public final class GBKTransform private boolean dataReceived = false; private transient OutputCollector originOc; private final boolean isPartialCombining; + private final CombineTransform intermediateCombine; - public GBKTransform(final Coder> inputCoder, - final Map, Coder> outputCoders, - final TupleTag> mainOutputTag, - final WindowingStrategy windowingStrategy, - final PipelineOptions options, - final SystemReduceFn reduceFn, - final DoFnSchemaInformation doFnSchemaInformation, - final DisplayData displayData, - final boolean isPartialCombining) { + public CombineTransform(final Coder> inputCoder, + final Map, Coder> outputCoders, + final TupleTag> mainOutputTag, + final WindowingStrategy windowingStrategy, + final PipelineOptions options, + final SystemReduceFn reduceFn, + final DoFnSchemaInformation doFnSchemaInformation, + final DisplayData displayData, + final boolean isPartialCombining) { + this(inputCoder, outputCoders, mainOutputTag, windowingStrategy, options, reduceFn, + doFnSchemaInformation, displayData, isPartialCombining, null); + } + + public CombineTransform(final Coder> inputCoder, + final Map, Coder> outputCoders, + final TupleTag> mainOutputTag, + final WindowingStrategy windowingStrategy, + final PipelineOptions options, + final SystemReduceFn reduceFn, + final DoFnSchemaInformation doFnSchemaInformation, + final DisplayData displayData, + final boolean isPartialCombining, + final CombineTransform intermediateCombine) { super(null, inputCoder, outputCoders, @@ -81,6 +96,7 @@ public GBKTransform(final Coder> inputCoder, Collections.emptyMap()); /* does not have side inputs */ this.reduceFn = reduceFn; this.isPartialCombining = isPartialCombining; + this.intermediateCombine = intermediateCombine; } /** @@ -273,6 +289,13 @@ public boolean getIsPartialCombining() { return isPartialCombining; } + /** + * Get the intermediate combine transform if possible. + * @return the intermediate transform if possible. + */ + public Optional getIntermediateCombine() { + return Optional.ofNullable(intermediateCombine); + } /** Wrapper class for {@link OutputCollector}. */ public class GBKOutputCollector implements OutputCollector>> { diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java new file mode 100644 index 0000000000..86916cbbcf --- /dev/null +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformFactory.java @@ -0,0 +1,136 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.compiler.frontend.beam.transform; + +import org.apache.beam.runners.core.SystemReduceFn; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.options.PipelineOptions; +import org.apache.beam.sdk.transforms.DoFnSchemaInformation; +import org.apache.beam.sdk.transforms.display.DisplayData; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Collections; +import java.util.Map; + +/** + * Factory for the combine transform that combines the results during the group by key. + * @param the type of the key. + * @param the type of the input values. + * @param the type of the accumulators. + * @param the type of the output. + */ +public class CombineTransformFactory { + private static final Logger LOG = LoggerFactory.getLogger(CombineTransformFactory.class.getName()); + + private final SystemReduceFn combineFn; + private final SystemReduceFn intermediateCombineFn; + private final SystemReduceFn finalReduceFn; + + private final Coder> inputCoder; + private final TupleTag> partialMainOutputTag; + private final Coder> accumulatorCoder; + private final Map, Coder> outputCoders; + + private final TupleTag> mainOutputTag; + private final WindowingStrategy windowingStrategy; + private final PipelineOptions options; + + private final DoFnSchemaInformation doFnSchemaInformation; + private final DisplayData displayData; + + public CombineTransformFactory(final Coder> inputCoder, + final TupleTag> partialMainOutputTag, + final Coder> accumulatorCoder, + final Map, Coder> outputCoders, + final TupleTag> mainOutputTag, + final WindowingStrategy windowingStrategy, + final PipelineOptions options, + final SystemReduceFn combineFn, + final SystemReduceFn intermediateCombineFn, + final SystemReduceFn finalReduceFn, + final DoFnSchemaInformation doFnSchemaInformation, + final DisplayData displayData) { + this.combineFn = combineFn; + this.intermediateCombineFn = intermediateCombineFn; + this.finalReduceFn = finalReduceFn; + + this.inputCoder = inputCoder; + this.partialMainOutputTag = partialMainOutputTag; + this.accumulatorCoder = accumulatorCoder; + this.outputCoders = outputCoders; + + this.mainOutputTag = mainOutputTag; + this.windowingStrategy = windowingStrategy; + this.options = options; + + this.doFnSchemaInformation = doFnSchemaInformation; + this.displayData = displayData; + } + + + /** + * Get the partial combine transform of the combine transform. + * @return the partial combine transform for the combine transform. + */ + public CombineTransform getPartialCombineTransform() { + return new CombineTransform<>(inputCoder, + Collections.singletonMap(partialMainOutputTag, accumulatorCoder), + partialMainOutputTag, + windowingStrategy, + options, + combineFn, + doFnSchemaInformation, + displayData, true); + } + + /** + * Get the intermediate combine transform of the combine transform. + * @return the intermediate combine transform for the combine transform. + */ + public CombineTransform getIntermediateCombineTransform() { + return new CombineTransform<>(accumulatorCoder, + Collections.singletonMap(partialMainOutputTag, accumulatorCoder), + partialMainOutputTag, + windowingStrategy, + options, + intermediateCombineFn, + doFnSchemaInformation, + displayData, false); + } + + /** + * Get the final combine transform of the combine transform. + * @return the final combine transform for the combine transform. + */ + public CombineTransform getFinalCombineTransform() { + return new CombineTransform<>(accumulatorCoder, + outputCoders, + mainOutputTag, + windowingStrategy, + options, + finalReduceFn, + doFnSchemaInformation, + displayData, false, this.getIntermediateCombineTransform()); + } +} diff --git a/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java new file mode 100644 index 0000000000..73a272b0a0 --- /dev/null +++ b/compiler/frontend/beam/src/main/java/org/apache/nemo/compiler/frontend/beam/transform/IntermediateCombineFn.java @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.compiler.frontend.beam.transform; + +import org.apache.beam.sdk.coders.CannotProvideCoderException; +import org.apache.beam.sdk.coders.Coder; +import org.apache.beam.sdk.coders.CoderRegistry; +import org.apache.beam.sdk.transforms.Combine; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Arrays; + +/** + * Wrapper class for {@link Combine.CombineFn}. + * When adding input, it merges its accumulator and input accumulator into a single accumulator. + * After then, it returns the accumulator for it to be merged later on by the {@link FinalCombineFn}. + * @param accumulator type. + */ +public final class IntermediateCombineFn extends Combine.CombineFn { + private static final Logger LOG = LoggerFactory.getLogger(IntermediateCombineFn.class.getName()); + private final Combine.CombineFn originFn; + private final Coder accumCoder; + + public IntermediateCombineFn(final Combine.CombineFn originFn, + final Coder accumCoder) { + this.originFn = originFn; + this.accumCoder = accumCoder; + } + + @Override + public Coder getAccumulatorCoder(final CoderRegistry registry, final Coder inputCoder) + throws CannotProvideCoderException { + return accumCoder; + } + + @Override + public AccumT createAccumulator() { + return originFn.createAccumulator(); + } + + @Override + public AccumT addInput(final AccumT mutableAccumulator, final AccumT input) { + final AccumT result = originFn.mergeAccumulators(Arrays.asList(mutableAccumulator, input)); + return result; + } + + @Override + public AccumT mergeAccumulators(final Iterable accumulators) { + return originFn.mergeAccumulators(accumulators); + } + + @Override + public AccumT extractOutput(final AccumT accumulator) { + return accumulator; + } +} diff --git a/compiler/optimizer/pom.xml b/compiler/optimizer/pom.xml index 49a546d86e..e75c3ca093 100644 --- a/compiler/optimizer/pom.xml +++ b/compiler/optimizer/pom.xml @@ -73,5 +73,17 @@ under the License. jackson-databind ${jackson.version} + + + org.apache.nemo + nemo-compiler-frontend-beam + ${project.version} + + + org.apache.nemo + nemo-compiler-frontend-beam + 0.4-SNAPSHOT + compile + diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/PipeTransferForAllEdgesPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/PipeTransferForAllEdgesPass.java index cddbb53f21..85b0157d60 100644 --- a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/PipeTransferForAllEdgesPass.java +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/annotating/PipeTransferForAllEdgesPass.java @@ -19,6 +19,8 @@ package org.apache.nemo.compiler.optimizer.pass.compiletime.annotating; import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.common.ir.edge.executionproperty.DataFlowProperty; +import org.apache.nemo.common.ir.edge.executionproperty.DataPersistenceProperty; import org.apache.nemo.common.ir.edge.executionproperty.DataStoreProperty; /** @@ -37,9 +39,11 @@ public PipeTransferForAllEdgesPass() { public IRDAG apply(final IRDAG dag) { dag.getVertices().forEach(vertex -> dag.getIncomingEdgesOf(vertex).stream() - .forEach(edge -> edge.setPropertyPermanently( - DataStoreProperty.of(DataStoreProperty.Value.PIPE))) - ); + .forEach(edge -> { + edge.setPropertyPermanently(DataStoreProperty.of(DataStoreProperty.Value.PIPE)); + edge.setPropertyPermanently(DataFlowProperty.of(DataFlowProperty.Value.PUSH)); + edge.setPropertyPermanently(DataPersistenceProperty.of(DataPersistenceProperty.Value.DISCARD)); + })); return dag; } } diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/IntermediateAccumulatorInsertionPass.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/IntermediateAccumulatorInsertionPass.java new file mode 100644 index 0000000000..37fd295608 --- /dev/null +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/IntermediateAccumulatorInsertionPass.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping; + +import com.fasterxml.jackson.databind.ObjectMapper; +import org.apache.nemo.common.Util; +import org.apache.nemo.common.exception.SchedulingException; +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.common.ir.edge.IREdge; +import org.apache.nemo.common.ir.edge.executionproperty.CommunicationPatternProperty; +import org.apache.nemo.common.ir.vertex.OperatorVertex; +import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty; +import org.apache.nemo.common.ir.vertex.executionproperty.ShuffleExecutorSetProperty; +import org.apache.nemo.common.ir.vertex.utility.IntermediateAccumulatorVertex; +import org.apache.nemo.common.test.ExampleTestArgs; +import org.apache.nemo.compiler.frontend.beam.transform.CombineTransform; +import org.apache.nemo.compiler.optimizer.pass.compiletime.Requires; + +import java.io.File; +import java.util.*; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +/** + * Pass for inserting intermediate aggregator for partial shuffle. + */ +@Requires(ParallelismProperty.class) +public final class IntermediateAccumulatorInsertionPass extends ReshapingPass { + private final String networkFilePath; + + /** + * Default constructor. + */ + public IntermediateAccumulatorInsertionPass() { + this(false); + } + + /** + * Constructor for unit test. + * @param isUnitTest indicates unit test. + */ + public IntermediateAccumulatorInsertionPass(final boolean isUnitTest) { + super(IntermediateAccumulatorInsertionPass.class); + if (isUnitTest) { + this.networkFilePath = ExampleTestArgs.getFileBasePath() + "inputs/example_labeldict.json"; + } else { + this.networkFilePath = Util.fetchProjectRootPath() + "/tools/network_profiling/labeldict.json"; + } + } + + /** + * Insert accumulator vertex based on network hierarchy. + * + * @param irdag irdag to apply pass. + * @return modified irdag. + */ + @Override + public IRDAG apply(final IRDAG irdag) { + try { + ObjectMapper mapper = new ObjectMapper(); + Map> map = mapper.readValue(new File(networkFilePath), Map.class); + + irdag.topologicalDo(v -> { + if (v instanceof OperatorVertex && ((OperatorVertex) v).getTransform() instanceof CombineTransform) { + final CombineTransform finalCombineStreamTransform = (CombineTransform) ((OperatorVertex) v).getTransform(); + if (finalCombineStreamTransform.getIntermediateCombine().isPresent()) { + irdag.getIncomingEdgesOf(v).forEach(e -> { + if (CommunicationPatternProperty.Value.SHUFFLE + .equals(e.getPropertyValue(CommunicationPatternProperty.class) + .orElse(CommunicationPatternProperty.Value.ONE_TO_ONE))) { + handleDataTransferFor(irdag, map, finalCombineStreamTransform, e, 10F); + } + }); + } + } + }); + + return irdag; + } catch (final Exception e) { + throw new SchedulingException(e); + } + } + + private static void handleDataTransferFor(final IRDAG irdag, + final Map> map, + final CombineTransform finalCombineStreamTransform, + final IREdge targetEdge, + final Float threshold) { + final int srcParallelism = targetEdge.getSrc().getPropertyValue(ParallelismProperty.class).get(); + + final int mapSize = map.size(); + final int numOfNodes = (mapSize + 1) / 2; + Float previousDistance = 0F; + + for (int i = numOfNodes; i < mapSize; i++) { + final float currentDistance = Float.parseFloat(map.get(String.valueOf(i)).get(1)); + if (previousDistance != 0 && currentDistance > threshold * previousDistance + && srcParallelism * 2 / 3 >= mapSize - i + 1) { + final Integer targetNumberOfSets = mapSize - i; + final Set> setsOfExecutors = getTargetNumberOfExecutorSetsFrom(map, targetNumberOfSets); + + final CombineTransform intermediateCombineStreamTransform = + (CombineTransform) finalCombineStreamTransform.getIntermediateCombine().get(); + final IntermediateAccumulatorVertex accumulatorVertex = + new IntermediateAccumulatorVertex(intermediateCombineStreamTransform); + + targetEdge.getDst().copyExecutionPropertiesTo(accumulatorVertex); + accumulatorVertex.setProperty(ParallelismProperty.of(srcParallelism * 2 / 3)); + accumulatorVertex.setProperty(ShuffleExecutorSetProperty.of(setsOfExecutors)); + + irdag.insert(accumulatorVertex, targetEdge); + break; + } + previousDistance = currentDistance; + } + } + + private static Set> getTargetNumberOfExecutorSetsFrom(final Map> map, + final Integer targetNumber) { + final Set> result = new HashSet<>(); + final Integer index = map.size() - targetNumber; + final List indicesToCheck = IntStream.range(0, index) + .map(i -> -i).sorted().map(i -> -i) + .mapToObj(String::valueOf) + .collect(Collectors.toList()); + + Arrays.asList(map.get(String.valueOf(index)).get(0).split("\\+")) + .forEach(key -> result.add(recursivelyExtractExecutorsFrom(map, key, indicesToCheck))); + + while (!indicesToCheck.isEmpty()) { + result.add(recursivelyExtractExecutorsFrom(map, indicesToCheck.get(0), indicesToCheck)); + } + + return result; + } + + private static HashSet recursivelyExtractExecutorsFrom(final Map> map, + final String key, + final List indicesToCheck) { + indicesToCheck.remove(key); + final HashSet result = new HashSet<>(); + final List indices = Arrays.asList(map.get(key).get(0).split("\\+")); + if (indices.size() == 1) { + result.add(indices.get(0)); + } else { + indices.forEach(index -> result.addAll(recursivelyExtractExecutorsFrom(map, index, indicesToCheck))); + } + return result; + } +} diff --git a/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/IntermediateAccumulatorPolicy.java b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/IntermediateAccumulatorPolicy.java new file mode 100644 index 0000000000..4f9c3f09f6 --- /dev/null +++ b/compiler/optimizer/src/main/java/org/apache/nemo/compiler/optimizer/policy/IntermediateAccumulatorPolicy.java @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.nemo.compiler.optimizer.policy; + +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.PipeTransferForAllEdgesPass; +import org.apache.nemo.compiler.optimizer.pass.compiletime.composite.DefaultCompositePass; +import org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping.IntermediateAccumulatorInsertionPass; +import org.apache.nemo.compiler.optimizer.pass.runtime.Message; + +/** + * A policy to perform intermediate data accumulation in shuffle edges (e.g. WAN networks). + */ +public final class IntermediateAccumulatorPolicy implements Policy { + public static final PolicyBuilder BUILDER = + new PolicyBuilder() + .registerCompileTimePass(new DefaultCompositePass()) + .registerCompileTimePass(new PipeTransferForAllEdgesPass()) + .registerCompileTimePass(new IntermediateAccumulatorInsertionPass()); + + private final Policy policy; + + /** + * Default constructor. + */ + public IntermediateAccumulatorPolicy() { + this.policy = BUILDER.build(); + } + + @Override + public IRDAG runCompileTimeOptimization(final IRDAG dag, final String dagDirectory) { + return this.policy.runCompileTimeOptimization(dag, dagDirectory); + } + + @Override + public IRDAG runRunTimeOptimizations(final IRDAG dag, final Message message) { + return this.policy.runRunTimeOptimizations(dag, message); + } +} diff --git a/compiler/optimizer/src/test/java/org/apache/nemo/compiler/optimizer/policy/PolicyBuilderTest.java b/compiler/optimizer/src/test/java/org/apache/nemo/compiler/optimizer/policy/PolicyBuilderTest.java index 606b9851a5..d2bfc7f26b 100644 --- a/compiler/optimizer/src/test/java/org/apache/nemo/compiler/optimizer/policy/PolicyBuilderTest.java +++ b/compiler/optimizer/src/test/java/org/apache/nemo/compiler/optimizer/policy/PolicyBuilderTest.java @@ -45,6 +45,12 @@ public void testDataSkewPolicy() { assertEquals(1, DataSkewPolicy.BUILDER.getRunTimePasses().size()); } + @Test + public void testIntermediateAccumulatorPolicy() { + assertEquals(11, IntermediateAccumulatorPolicy.BUILDER.getCompileTimePasses().size()); + assertEquals(0, IntermediateAccumulatorPolicy.BUILDER.getRunTimePasses().size()); + } + @Test public void testShouldFailPolicy() { try { diff --git a/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java b/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java index 3f5001a0b7..a347bc94b1 100644 --- a/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java +++ b/compiler/test/src/main/java/org/apache/nemo/compiler/CompilerTestUtil.java @@ -143,4 +143,25 @@ public static IRDAG compileMLRDAG() throws Exception { .addUserArgs(input, numFeatures, numClasses, numIteration); return compileDAG(mlrArgBuilder.build()); } + + public static IRDAG compileWindowedWordcountIntermediateAccumulationDAG() throws Exception { + final String input = ROOT_DIR + "/examples/resources/inputs/test_input_windowed_wordcount"; + final String windowType = "fixed"; + final String inputType = "bounded"; + final String main = "org.apache.nemo.examples.beam.WindowedWordCount"; + final String output = ROOT_DIR + "/examples/resources/inputs/test_output"; + final String scheduler = "org.apache.nemo.runtime.master.scheduler.StreamingScheduler"; + final String resourceJson = ROOT_DIR + "/examples/resources/executors/beam_test_executor_resources.json"; + final String jobId = "testIntermediateAccumulatorInsertionPass"; + final String policy = "org.apache.nemo.compiler.optimizer.policy.StreamingPolicy"; + + final ArgBuilder edgarArgBuilder = new ArgBuilder() + .addScheduler(scheduler) + .addUserMain(main) + .addUserArgs(output, windowType, inputType, input) + .addResourceJson(resourceJson) + .addJobId(jobId) + .addOptimizationPolicy(policy); + return compileDAG(edgarArgBuilder.build()); + } } diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java index acf7cc4d1b..af6721e245 100644 --- a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java +++ b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineFnTest.java @@ -34,6 +34,14 @@ public class CombineFnTest extends TestCase { public static final class CountFn extends Combine.CombineFn { public static final class Accum { int sum = 0; + + @Override + public boolean equals(Object o) { + if (Accum.class != o.getClass()) { + return false; + } + return (sum == ((Accum) o).sum); + } } @Override @@ -120,6 +128,55 @@ public void testPartialCombineFn() { } } + @Test + public void testIntermediateCombineFn() { + // Initialize intermediate combine function. + final IntermediateCombineFn intermediateCombineFn = + new IntermediateCombineFn<>(combineFn, accumCoder); + + // Create accumulator. + final CountFn.Accum accum1 = intermediateCombineFn.createAccumulator(); + final CountFn.Accum accum2 = intermediateCombineFn.createAccumulator(); + final CountFn.Accum accum3 = intermediateCombineFn.createAccumulator(); + + final CountFn.Accum expectedMergedAccum = intermediateCombineFn.createAccumulator(); + expectedMergedAccum.sum = 6; + + // Check whether accumulators are initialized correctly. + assertEquals(0, accum1.sum); + assertEquals(0, accum2.sum); + assertEquals(0, accum3.sum); + + // Change the parameter for the sake of unit testing. + accum1.sum = 1; + accum2.sum = 2; + accum3.sum = 3; + + // Add input. Intermediate combineFn's addInput method takes accumulators as input + // and merges them into a single accumulator. + final CountFn.Accum addedAccum = intermediateCombineFn.addInput(accum1, accum2); + + // Check whether inputs are added correctly. + assertEquals(3, addedAccum.sum); + + // Merge accumulators. + CountFn.Accum mergedAccum = intermediateCombineFn.mergeAccumulators(Arrays.asList(accum1, accum2, accum3)); + + // Check whether accumulators are merged correctly. + assertEquals(expectedMergedAccum, mergedAccum); + + // Extract output. + assertEquals(expectedMergedAccum, intermediateCombineFn.extractOutput(mergedAccum)); + + // Get accumulator coder. Check if the accumulator coder from intermediate combineFn is equal + // to the one from original combineFn. + try { + assertEquals(accumCoder, intermediateCombineFn.getAccumulatorCoder(CoderRegistry.createDefault(), INTEGER_CODER)); + } catch (CannotProvideCoderException e) { + throw new RuntimeException("Failed to provide an accumulator coder"); + } + } + @Test public void testFinalCombineFn() { // Initialize final combine function. diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformTest.java similarity index 57% rename from compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java rename to compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformTest.java index 3c08c50fb2..0f67164f3d 100644 --- a/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/GBKTransformTest.java +++ b/compiler/test/src/test/java/org/apache/nemo/compiler/frontend/beam/transform/CombineTransformTest.java @@ -18,54 +18,57 @@ */ package org.apache.nemo.compiler.frontend.beam.transform; -import com.google.common.collect.Iterables; -import junit.framework.TestCase; import org.apache.beam.runners.core.SystemReduceFn; import org.apache.beam.sdk.coders.*; import org.apache.beam.sdk.options.PipelineOptionsFactory; import org.apache.beam.sdk.transforms.Combine; +import org.apache.beam.sdk.transforms.CombineFnBase; import org.apache.beam.sdk.transforms.DoFnSchemaInformation; import org.apache.beam.sdk.transforms.display.DisplayData; import org.apache.beam.sdk.transforms.windowing.*; import org.apache.beam.sdk.util.AppliedCombineFn; import org.apache.beam.sdk.util.WindowedValue; -import org.apache.beam.sdk.values.*; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.apache.beam.sdk.values.WindowingStrategy; import org.apache.nemo.common.ir.vertex.transform.Transform; import org.apache.nemo.common.punctuation.Watermark; import org.apache.nemo.compiler.frontend.beam.NemoPipelineOptions; import org.joda.time.Duration; import org.joda.time.Instant; +import org.junit.Assert; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import java.util.*; import static org.apache.beam.sdk.transforms.windowing.PaneInfo.Timing.*; import static org.apache.beam.sdk.values.WindowingStrategy.AccumulationMode.ACCUMULATING_FIRED_PANES; import static org.mockito.Mockito.mock; -public class GBKTransformTest extends TestCase { - private static final Logger LOG = LoggerFactory.getLogger(GBKTransformTest.class.getName()); +public class CombineTransformTest { + private static final Logger LOG = LoggerFactory.getLogger(CombineTransformTest.class.getName()); private final static Coder STRING_CODER = StringUtf8Coder.of(); private final static Coder INTEGER_CODER = BigEndianIntegerCoder.of(); private void checkOutput(final KV expected, final KV result) { // check key - assertEquals(expected.getKey(), result.getKey()); + Assert.assertEquals(expected.getKey(), result.getKey()); // check value - assertEquals(expected.getValue(), result.getValue()); + Assert.assertEquals(expected.getValue(), result.getValue()); } private void checkOutput2(final KV> expected, final KV> result) { // check key - assertEquals(expected.getKey(), result.getKey()); + Assert.assertEquals(expected.getKey(), result.getKey()); // check value final List resultValue = new ArrayList<>(); final List expectedValue = new ArrayList<>(expected.getValue()); result.getValue().iterator().forEachRemaining(resultValue::add); Collections.sort(resultValue); Collections.sort(expectedValue); - assertEquals(expectedValue, resultValue); + Assert.assertEquals(expectedValue, resultValue); } @@ -123,7 +126,7 @@ public Coder getAccumulatorCoder(CoderRegistry registry, Coder outputTag = new TupleTag<>("main-output"); final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(10)) .every(Duration.standardSeconds(5)); @@ -142,7 +145,7 @@ public void test_combine() { final Watermark watermark3 = new Watermark(18000); final Watermark watermark4 = new Watermark(21000); - AppliedCombineFn applied_combine_fn = + AppliedCombineFn appliedCombineFn = AppliedCombineFn.withInputCoder( combine_fn, CoderRegistry.createDefault(), @@ -151,14 +154,14 @@ public void test_combine() { WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES) ); - final GBKTransform combine_transform = - new GBKTransform( + final CombineTransform combineTransform = + new CombineTransform( KvCoder.of(STRING_CODER, INTEGER_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)), outputTag, WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES), PipelineOptionsFactory.as(NemoPipelineOptions.class), - SystemReduceFn.combining(STRING_CODER, applied_combine_fn), + SystemReduceFn.combining(STRING_CODER, appliedCombineFn), DoFnSchemaInformation.create(), DisplayData.none(), false); @@ -168,69 +171,69 @@ public void test_combine() { // window3 : [5000, 15000) // window4 : [10000, 20000) List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window1 = sortedWindows.get(0); final IntervalWindow window2 = sortedWindows.get(1); sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts5)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window3 = sortedWindows.get(0); final IntervalWindow window4 = sortedWindows.get(1); // Prepare to test CombineStreamTransform final Transform.Context context = mock(Transform.Context.class); final TestOutputCollector> oc = new TestOutputCollector(); - combine_transform.prepare(context, oc); + combineTransform.prepare(context, oc); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("c", 1), ts2, slidingWindows.assignWindows(ts2), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts3, slidingWindows.assignWindows(ts3), PaneInfo.NO_FIRING)); // Emit outputs of window1 - combine_transform.onWatermark(watermark1); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark1); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); - assertEquals(2, oc.outputs.size()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(2, oc.outputs.size()); checkOutput(KV.of("a", 1), oc.outputs.get(0).getValue()); checkOutput(KV.of("c", 1), oc.outputs.get(1).getValue()); oc.outputs.clear(); oc.watermarks.clear(); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts4, slidingWindows.assignWindows(ts4), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("c", 1), ts5, slidingWindows.assignWindows(ts5), PaneInfo.NO_FIRING)); // Emit outputs of window2 - combine_transform.onWatermark(watermark2); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark2); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window2), oc.outputs.get(0).getWindows()); - assertEquals(3, oc.outputs.size()); + Assert.assertEquals(Collections.singletonList(window2), oc.outputs.get(0).getWindows()); + Assert.assertEquals(3, oc.outputs.size()); checkOutput(KV.of("a", 2), oc.outputs.get(0).getValue()); checkOutput(KV.of("b", 1), oc.outputs.get(1).getValue()); checkOutput(KV.of("c", 1), oc.outputs.get(2).getValue()); oc.outputs.clear(); oc.watermarks.clear(); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts6, slidingWindows.assignWindows(ts6), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts7, slidingWindows.assignWindows(ts7), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts8, slidingWindows.assignWindows(ts8), PaneInfo.NO_FIRING)); // Emit outputs of window3 - combine_transform.onWatermark(watermark3); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark3); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window3), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window3), oc.outputs.get(0).getWindows()); checkOutput(KV.of("a", 1), oc.outputs.get(0).getValue()); checkOutput(KV.of("b", 2), oc.outputs.get(1).getValue()); checkOutput(KV.of("c", 1), oc.outputs.get(2).getValue()); @@ -238,15 +241,15 @@ public void test_combine() { oc.watermarks.clear(); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("c", 3), ts9, slidingWindows.assignWindows(ts9), PaneInfo.NO_FIRING)); // Emit outputs of window3 - combine_transform.onWatermark(watermark4); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + combineTransform.onWatermark(watermark4); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window4), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window4), oc.outputs.get(0).getWindows()); checkOutput(KV.of("a", 1), oc.outputs.get(0).getValue()); checkOutput(KV.of("b", 2), oc.outputs.get(1).getValue()); checkOutput(KV.of("c", 4), oc.outputs.get(2).getValue()); @@ -255,10 +258,176 @@ public void test_combine() { oc.watermarks.clear(); } + private void clearOutputCollectors(final List outputCollectors) { + for (TestOutputCollector oc : outputCollectors) { + oc.outputs.clear(); + oc.watermarks.clear(); + } + } + + @SuppressWarnings("unchecked") + private void processIntermediateCombineElementsPerWindow(final List elements, + final CombineTransform partialCombineTransform, + final CombineTransform intermediateCombineTransform, + final CombineTransform finalCombineTransform, + final Watermark watermark, + final TestOutputCollector ocPartial, + final TestOutputCollector ocIntermediate) { + for (WindowedValue element : elements) { + partialCombineTransform.onData(element); + } + partialCombineTransform.onWatermark(watermark); + for (WindowedValue output: ocPartial.outputs) { + intermediateCombineTransform.onData(output); + } + intermediateCombineTransform.onWatermark(watermark); + for (WindowedValue output: ocIntermediate.outputs) { + finalCombineTransform.onData(output); + } + finalCombineTransform.onWatermark(watermark); + } + + /** + * Test intermediate combine. + */ + @Test + @SuppressWarnings("unchecked") + public void testIntermediateCombine() { + final TupleTag outputTag = new TupleTag<>("main-output"); + final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(10)) + .every(Duration.standardSeconds(5)); + final WindowingStrategy windowingStrategy = WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES); + + final Instant ts1 = new Instant(1000); + final Instant ts2 = new Instant(2000); + final Instant ts3 = new Instant(6000); + final Instant ts4 = new Instant(8000); + final Instant ts5 = new Instant(11000); + final Instant ts6 = new Instant(14000); + final Instant ts7 = new Instant(16000); + final Instant ts8 = new Instant(17000); + final Instant ts9 = new Instant(19000); + final Watermark watermark1 = new Watermark(7000); + final Watermark watermark2 = new Watermark(12000); + final Watermark watermark3 = new Watermark(18000); + final Watermark watermark4 = new Watermark(21000); + + final KvCoder inputCoder = KvCoder.of(STRING_CODER, INTEGER_CODER); + final Coder accumulatorCoder; + try { + accumulatorCoder = combine_fn.getAccumulatorCoder(CoderRegistry.createDefault(), INTEGER_CODER); + } catch (CannotProvideCoderException e) { + throw new RuntimeException(e); + } + + final CombineFnBase.GlobalCombineFn partialCombineFn = new PartialCombineFn(combine_fn, accumulatorCoder); + final CombineFnBase.GlobalCombineFn intermediateCombineFn = new IntermediateCombineFn(combine_fn, accumulatorCoder); + final CombineFnBase.GlobalCombineFn finalCombineFn = new FinalCombineFn(combine_fn, accumulatorCoder); + + final SystemReduceFn partialSystemReduceFn = SystemReduceFn.combining(STRING_CODER, + AppliedCombineFn.withInputCoder(partialCombineFn, CoderRegistry.createDefault(), + inputCoder, null, windowingStrategy)); + final SystemReduceFn intermediateSystemReduceFn = SystemReduceFn.combining(STRING_CODER, + AppliedCombineFn.withInputCoder(intermediateCombineFn, CoderRegistry.createDefault(), + KvCoder.of(STRING_CODER, accumulatorCoder), null, windowingStrategy)); + final SystemReduceFn finalSystemReduceFn = SystemReduceFn.combining(STRING_CODER, + AppliedCombineFn.withInputCoder(finalCombineFn, CoderRegistry.createDefault(), + KvCoder.of(STRING_CODER, accumulatorCoder), null, windowingStrategy)); + + final CombineTransformFactory combineTransformFactory = new CombineTransformFactory( + inputCoder, new TupleTag<>(), KvCoder.of(STRING_CODER, accumulatorCoder), + Collections.singletonMap(outputTag, inputCoder), outputTag, windowingStrategy, + PipelineOptionsFactory.as(NemoPipelineOptions.class), partialSystemReduceFn, intermediateSystemReduceFn, + finalSystemReduceFn, DoFnSchemaInformation.create(), DisplayData.none()); + + final CombineTransform partialCombineTransform = combineTransformFactory.getPartialCombineTransform(); + final CombineTransform intermediateCombineTransform = combineTransformFactory.getIntermediateCombineTransform(); + final CombineTransform finalCombineTransform = combineTransformFactory.getFinalCombineTransform(); + + // window1 : [-5000, 5000) in millisecond + // window2 : [0, 10000) + // window3 : [5000, 15000) + // window4 : [10000, 20000) + List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); + sortedWindows.sort(IntervalWindow::compareTo); + final IntervalWindow window1 = sortedWindows.get(0); + final IntervalWindow window2 = sortedWindows.get(1); + sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts5)); + sortedWindows.sort(IntervalWindow::compareTo); + final IntervalWindow window3 = sortedWindows.get(0); + final IntervalWindow window4 = sortedWindows.get(1); + + // Prepare to test CombineStreamTransform + final Transform.Context context = mock(Transform.Context.class); + final TestOutputCollector> ocPartial = new TestOutputCollector<>(); + final TestOutputCollector> ocIntermediate = new TestOutputCollector<>(); + final TestOutputCollector> ocFinal = new TestOutputCollector<>(); + partialCombineTransform.prepare(context, ocPartial); + intermediateCombineTransform.prepare(context, ocIntermediate); + finalCombineTransform.prepare(context, ocFinal); + + processIntermediateCombineElementsPerWindow( + Arrays.asList( + WindowedValue.of(KV.of("a", 1), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("c", 1), ts2, slidingWindows.assignWindows(ts2), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("b", 1), ts3, slidingWindows.assignWindows(ts3), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark1, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window1), ocFinal.outputs.get(0).getWindows()); + Assert.assertEquals(2, ocFinal.outputs.size()); + checkOutput(KV.of("a", 1), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("c", 1), ocFinal.outputs.get(1).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + + processIntermediateCombineElementsPerWindow( + Arrays.asList( + WindowedValue.of(KV.of("a", 1), ts4, slidingWindows.assignWindows(ts4), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("c", 1), ts5, slidingWindows.assignWindows(ts5), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark2, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window2), ocFinal.outputs.get(0).getWindows()); + Assert.assertEquals(3, ocFinal.outputs.size()); + checkOutput(KV.of("a", 2), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("b", 1), ocFinal.outputs.get(1).getValue()); + checkOutput(KV.of("c", 1), ocFinal.outputs.get(2).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + + processIntermediateCombineElementsPerWindow( + Arrays.asList( + WindowedValue.of(KV.of("b", 1), ts6, slidingWindows.assignWindows(ts6), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("b", 1), ts7, slidingWindows.assignWindows(ts7), PaneInfo.NO_FIRING), + WindowedValue.of(KV.of("a", 1), ts8, slidingWindows.assignWindows(ts8), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark3, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window3), ocFinal.outputs.get(0).getWindows()); + checkOutput(KV.of("a", 1), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("b", 2), ocFinal.outputs.get(1).getValue()); + checkOutput(KV.of("c", 1), ocFinal.outputs.get(2).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + + processIntermediateCombineElementsPerWindow( + Arrays.asList(WindowedValue.of(KV.of("c", 3), ts9, slidingWindows.assignWindows(ts9), PaneInfo.NO_FIRING)), + partialCombineTransform, intermediateCombineTransform, finalCombineTransform, + watermark4, ocPartial, ocIntermediate); + ocFinal.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); + + Assert.assertEquals(Collections.singletonList(window4), ocFinal.outputs.get(0).getWindows()); + checkOutput(KV.of("a", 1), ocFinal.outputs.get(0).getValue()); + checkOutput(KV.of("b", 2), ocFinal.outputs.get(1).getValue()); + checkOutput(KV.of("c", 4), ocFinal.outputs.get(2).getValue()); + clearOutputCollectors(Arrays.asList(ocPartial, ocIntermediate, ocFinal)); + } + // Test with late data @Test @SuppressWarnings("unchecked") - public void test_combine_lateData() { + public void testCombineLateData() { final TupleTag outputTag = new TupleTag<>("main-output"); final Duration lateness = Duration.standardSeconds(2); final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(10)) @@ -271,7 +440,7 @@ public void test_combine_lateData() { final Watermark watermark1 = new Watermark(6500); final Watermark watermark2 = new Watermark(8000); - AppliedCombineFn applied_combine_fn = + AppliedCombineFn appliedCombineFn = AppliedCombineFn.withInputCoder( combine_fn, CoderRegistry.createDefault(), @@ -280,14 +449,14 @@ public void test_combine_lateData() { WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES).withAllowedLateness(lateness) ); - final GBKTransform combine_transform = - new GBKTransform( + final CombineTransform combineTransform = + new CombineTransform( KvCoder.of(STRING_CODER, INTEGER_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, INTEGER_CODER)), outputTag, WindowingStrategy.of(slidingWindows).withMode(ACCUMULATING_FIRED_PANES).withAllowedLateness(lateness), PipelineOptionsFactory.as(NemoPipelineOptions.class), - SystemReduceFn.combining(STRING_CODER, applied_combine_fn), + SystemReduceFn.combining(STRING_CODER, appliedCombineFn), DoFnSchemaInformation.create(), DisplayData.none(), false); @@ -297,52 +466,52 @@ public void test_combine_lateData() { // window3 : [5000, 15000) // window4 : [10000, 20000) List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window1 = sortedWindows.get(0); final IntervalWindow window2 = sortedWindows.get(1); sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts4)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); final IntervalWindow window3 = sortedWindows.get(0); final IntervalWindow window4 = sortedWindows.get(1); // Prepare to test final Transform.Context context = mock(Transform.Context.class); final TestOutputCollector> oc = new TestOutputCollector(); - combine_transform.prepare(context, oc); + combineTransform.prepare(context, oc); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 1), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING)); - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("b", 1), ts2, slidingWindows.assignWindows(ts2), PaneInfo.NO_FIRING)); // On-time firing of window1. Skipping checking outputs since test1 checks output from non-late data - combine_transform.onWatermark(watermark1); + combineTransform.onWatermark(watermark1); oc.outputs.clear(); // Late data in window 1. Should be accumulated since EOW + allowed lateness > current Watermark - combine_transform.onData(WindowedValue.of( + combineTransform.onData(WindowedValue.of( KV.of("a", 5), ts1, slidingWindows.assignWindows(ts1), PaneInfo.NO_FIRING)); // Check outputs - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); - assertEquals(1,oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); checkOutput(KV.of("a", 6), oc.outputs.get(0).getValue()); oc.outputs.clear(); oc.watermarks.clear(); // Late data in window 1. Should NOT be accumulated to outputs of window1 since EOW + allowed lateness > current Watermark - combine_transform.onWatermark(watermark2); - combine_transform.onData(WindowedValue.of( + combineTransform.onWatermark(watermark2); + combineTransform.onData(WindowedValue.of( KV.of("a", 10), ts3, slidingWindows.assignWindows(ts3), PaneInfo.NO_FIRING)); - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // Check outputs - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); - assertEquals(1, oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); checkOutput(KV.of("a", 10), oc.outputs.get(0).getValue()); oc.outputs.clear(); oc.watermarks.clear(); @@ -369,14 +538,14 @@ public void test_combine_lateData() { @Test @SuppressWarnings("unchecked") - public void test_gbk() { + public void testGBK() { final TupleTag outputTag = new TupleTag<>("main-output"); final SlidingWindows slidingWindows = SlidingWindows.of(Duration.standardSeconds(2)) .every(Duration.standardSeconds(1)); - final GBKTransform> doFnTransform = - new GBKTransform( + final CombineTransform> doFnTransform = + new CombineTransform( KvCoder.of(STRING_CODER, STRING_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))), outputTag, @@ -403,7 +572,7 @@ public void test_gbk() { List sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts1)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); // [0---1000) final IntervalWindow window0 = sortedWindows.get(0); @@ -412,7 +581,7 @@ public void test_gbk() { sortedWindows.clear(); sortedWindows = new ArrayList<>(slidingWindows.assignWindows(ts4)); - Collections.sort(sortedWindows, IntervalWindow::compareTo); + sortedWindows.sort(IntervalWindow::compareTo); // [1000--3000) final IntervalWindow window2 = sortedWindows.get(1); @@ -436,21 +605,21 @@ public void test_gbk() { // output // 1: ["hello", "world"] // 2: ["hello"] - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // windowed result for key 1 - assertEquals(Arrays.asList(window0), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window0), oc.outputs.get(0).getWindows()); checkOutput2(KV.of("1", Arrays.asList("hello", "world")), oc.outputs.get(0).getValue()); // windowed result for key 2 - assertEquals(Arrays.asList(window0), oc.outputs.get(1).getWindows()); - checkOutput2(KV.of("2", Arrays.asList("hello")), oc.outputs.get(1).getValue()); + Assert.assertEquals(Collections.singletonList(window0), oc.outputs.get(1).getWindows()); + checkOutput2(KV.of("2", Collections.singletonList("hello")), oc.outputs.get(1).getValue()); - assertEquals(2, oc.outputs.size()); - assertEquals(2, oc.watermarks.size()); + Assert.assertEquals(2, oc.outputs.size()); + Assert.assertEquals(2, oc.watermarks.size()); // check output watermark - assertEquals(1000, + Assert.assertEquals(1000, oc.watermarks.get(0).getTimestamp()); oc.outputs.clear(); @@ -462,8 +631,8 @@ public void test_gbk() { doFnTransform.onWatermark(watermark2); - assertEquals(0, oc.outputs.size()); // do not emit anything - assertEquals(0, oc.watermarks.size()); + Assert.assertEquals(0, oc.outputs.size()); // do not emit anything + Assert.assertEquals(0, oc.watermarks.size()); doFnTransform.onData(WindowedValue.of( KV.of("3", "a"), ts5, slidingWindows.assignWindows(ts5), PaneInfo.NO_FIRING)); @@ -481,23 +650,23 @@ public void test_gbk() { // 1: ["hello", "world", "a"] // 2: ["hello"] // 3: ["a", "a", "b"] - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); // windowed result for key 1 - assertEquals(Arrays.asList(window1), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(0).getWindows()); checkOutput2(KV.of("1", Arrays.asList("hello", "world", "a")), oc.outputs.get(0).getValue()); // windowed result for key 2 - assertEquals(Arrays.asList(window1), oc.outputs.get(1).getWindows()); - checkOutput2(KV.of("2", Arrays.asList("hello")), oc.outputs.get(1).getValue()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(1).getWindows()); + checkOutput2(KV.of("2", Collections.singletonList("hello")), oc.outputs.get(1).getValue()); // windowed result for key 3 - assertEquals(Arrays.asList(window1), oc.outputs.get(2).getWindows()); + Assert.assertEquals(Collections.singletonList(window1), oc.outputs.get(2).getWindows()); checkOutput2(KV.of("3", Arrays.asList("a", "a", "b")), oc.outputs.get(2).getValue()); // check output watermark - assertEquals(2000, + Assert.assertEquals(2000, oc.watermarks.get(0).getTimestamp()); oc.outputs.clear(); @@ -516,20 +685,20 @@ public void test_gbk() { // output // 1: ["a", "a"] // 3: ["a", "a", "b", "b"] - Collections.sort(oc.outputs, (o1, o2) -> o1.getValue().getKey().compareTo(o2.getValue().getKey())); + oc.outputs.sort(Comparator.comparing(o -> o.getValue().getKey())); - assertEquals(2, oc.outputs.size()); + Assert.assertEquals(2, oc.outputs.size()); // windowed result for key 1 - assertEquals(Arrays.asList(window2), oc.outputs.get(0).getWindows()); + Assert.assertEquals(Collections.singletonList(window2), oc.outputs.get(0).getWindows()); checkOutput2(KV.of("1", Arrays.asList("a", "a")), oc.outputs.get(0).getValue()); // windowed result for key 3 - assertEquals(Arrays.asList(window2), oc.outputs.get(1).getWindows()); + Assert.assertEquals(Collections.singletonList(window2), oc.outputs.get(1).getWindows()); checkOutput2(KV.of("3", Arrays.asList("a", "a", "b", "b")), oc.outputs.get(1).getValue()); // check output watermark - assertEquals(3000, + Assert.assertEquals(3000, oc.watermarks.get(0).getTimestamp()); doFnTransform.close(); @@ -539,7 +708,7 @@ public void test_gbk() { * Test complex triggers that emit early and late firing. */ @Test - public void test_gbk_eventTimeTrigger() { + public void testGBKEventTimeTrigger() { final Duration lateness = Duration.standardSeconds(1); final AfterWatermark.AfterWatermarkEarlyAndLate trigger = AfterWatermark.pastEndOfWindow() // early firing @@ -561,8 +730,8 @@ public void test_gbk_eventTimeTrigger() { final TupleTag outputTag = new TupleTag<>("main-output"); - final GBKTransform> doFnTransform = - new GBKTransform( + final CombineTransform> doFnTransform = + new CombineTransform( KvCoder.of(STRING_CODER, STRING_CODER), Collections.singletonMap(outputTag, KvCoder.of(STRING_CODER, IterableCoder.of(STRING_CODER))), outputTag, @@ -591,8 +760,8 @@ public void test_gbk_eventTimeTrigger() { // early firing is not related to the watermark progress doFnTransform.onWatermark(new Watermark(2)); - assertEquals(1, oc.outputs.size()); - assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); oc.outputs.clear(); doFnTransform.onData(WindowedValue.of( @@ -607,8 +776,8 @@ public void test_gbk_eventTimeTrigger() { // GBKTransform emits data when receiving watermark // TODO #250: element-wise processing doFnTransform.onWatermark(new Watermark(5)); - assertEquals(1, oc.outputs.size()); - assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(EARLY, oc.outputs.get(0).getPane().getTiming()); // ACCUMULATION MODE checkOutput2(KV.of("1", Arrays.asList("hello", "world")), oc.outputs.get(0).getValue()); oc.outputs.clear(); @@ -617,8 +786,8 @@ public void test_gbk_eventTimeTrigger() { doFnTransform.onData(WindowedValue.of( KV.of("1", "!!"), new Instant(3), window.assignWindow(new Instant(3)), PaneInfo.NO_FIRING)); doFnTransform.onWatermark(new Watermark(5001)); - assertEquals(1, oc.outputs.size()); - assertEquals(ON_TIME, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(ON_TIME, oc.outputs.get(0).getPane().getTiming()); // ACCUMULATION MODE checkOutput2(KV.of("1", Arrays.asList("hello", "world", "!!")), oc.outputs.get(0).getValue()); oc.outputs.clear(); @@ -634,8 +803,8 @@ public void test_gbk_eventTimeTrigger() { KV.of("1", "bye!"), new Instant(1000), window.assignWindow(new Instant(1000)), PaneInfo.NO_FIRING)); doFnTransform.onWatermark(new Watermark(6000)); - assertEquals(1, oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); // The data should be accumulated to the previous window because it allows 1 second lateness checkOutput2(KV.of("1", Arrays.asList("hello", "world", "!!", "bye!")), oc.outputs.get(0).getValue()); oc.outputs.clear(); @@ -651,8 +820,8 @@ public void test_gbk_eventTimeTrigger() { KV.of("1", "hello again!"), new Instant(4800), window.assignWindow(new Instant(4800)), PaneInfo.NO_FIRING)); doFnTransform.onWatermark(new Watermark(6300)); - assertEquals(1, oc.outputs.size()); - assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); + Assert.assertEquals(1, oc.outputs.size()); + Assert.assertEquals(LATE, oc.outputs.get(0).getPane().getTiming()); checkOutput2(KV.of("1", Arrays.asList("hello again!")), oc.outputs.get(0).getValue()); oc.outputs.clear(); doFnTransform.close(); diff --git a/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/IntermediateAccumulatorInsertionPassTest.java b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/IntermediateAccumulatorInsertionPassTest.java new file mode 100644 index 0000000000..7fff8a9f9b --- /dev/null +++ b/compiler/test/src/test/java/org/apache/nemo/compiler/optimizer/pass/compiletime/reshaping/IntermediateAccumulatorInsertionPassTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.nemo.compiler.optimizer.pass.compiletime.reshaping; + +import org.apache.nemo.client.JobLauncher; +import org.apache.nemo.common.ir.IRDAG; +import org.apache.nemo.common.ir.vertex.executionproperty.ShuffleExecutorSetProperty; +import org.apache.nemo.compiler.CompilerTestUtil; +import org.apache.nemo.compiler.optimizer.pass.compiletime.annotating.*; +import org.apache.nemo.compiler.optimizer.policy.Policy; +import org.apache.nemo.compiler.optimizer.policy.PolicyBuilder; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.powermock.core.classloader.annotations.PrepareForTest; +import org.powermock.modules.junit4.PowerMockRunner; + +import static junit.framework.TestCase.assertTrue; +import static org.apache.nemo.common.dag.DAG.EMPTY_DAG_DIRECTORY; + +/** + * Test {@link IntermediateAccumulatorInsertionPass}. + */ +@RunWith(PowerMockRunner.class) +@PrepareForTest(JobLauncher.class) +public class IntermediateAccumulatorInsertionPassTest { + private IRDAG compiledDAG; + + @Before + public void setUp() throws Exception { + compiledDAG = CompilerTestUtil.compileWindowedWordcountIntermediateAccumulationDAG(); + } + + @Test + public void testIntermediateAccumulatorInsertionPass() { + final PolicyBuilder builder = new PolicyBuilder(); + builder.registerCompileTimePass(new DefaultParallelismPass(16, 2)) + .registerCompileTimePass(new DefaultEdgeEncoderPass()) + .registerCompileTimePass(new DefaultEdgeDecoderPass()) + .registerCompileTimePass(new DefaultDataStorePass()) + .registerCompileTimePass(new DefaultDataPersistencePass()) + .registerCompileTimePass(new DefaultScheduleGroupPass()) + .registerCompileTimePass(new CompressionPass()) + .registerCompileTimePass(new ResourceLocalityPass()) + .registerCompileTimePass(new ResourceSlotPass()) + .registerCompileTimePass(new PipeTransferForAllEdgesPass()) + .registerCompileTimePass(new IntermediateAccumulatorInsertionPass(true)); + final Policy policy = builder.build(); + compiledDAG = policy.runCompileTimeOptimization(compiledDAG, EMPTY_DAG_DIRECTORY); + assertTrue(compiledDAG.getTopologicalSort().stream() + .anyMatch(v -> v.getPropertyValue(ShuffleExecutorSetProperty.class).isPresent())); + assertTrue(compiledDAG.checkIntegrity().isPassed()); + } +} diff --git a/examples/resources/inputs/example_labeldict.json b/examples/resources/inputs/example_labeldict.json new file mode 100644 index 0000000000..daa1b4aa52 --- /dev/null +++ b/examples/resources/inputs/example_labeldict.json @@ -0,0 +1,7 @@ +{ + "0": ["mulan-16.maas", "0"], + "1": ["mulan-23.maas", "0"], + "2": ["mulan-m", "0"], + "3": ["1+2", "0.00003721"], + "4": ["0+3", "2.19395143"] +} diff --git a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java index cab2ed2f43..194fc258c3 100644 --- a/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java +++ b/runtime/executor/src/main/java/org/apache/nemo/runtime/executor/datatransfer/PipeInputReader.java @@ -71,7 +71,8 @@ public List> read() { if (comValue.equals(CommunicationPatternProperty.Value.ONE_TO_ONE)) { return Collections.singletonList(pipeManagerWorker.read(dstTaskIndex, runtimeEdge, dstTaskIndex)); } else if (comValue.equals(CommunicationPatternProperty.Value.BROADCAST) - || comValue.equals(CommunicationPatternProperty.Value.SHUFFLE)) { + || comValue.equals(CommunicationPatternProperty.Value.SHUFFLE) + || comValue.equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE)) { final int numSrcTasks = InputReader.getSourceParallelism(this); final List> futures = new ArrayList<>(); for (int srcTaskIdx = 0; srcTaskIdx < numSrcTasks; srcTaskIdx++) {