Skip to content
This repository was archived by the owner on Jul 1, 2025. It is now read-only.
Open
34 changes: 34 additions & 0 deletions common/src/main/java/org/apache/nemo/common/ir/IRDAG.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.nemo.common.ir.executionproperty.ResourceSpecification;
import org.apache.nemo.common.ir.vertex.IRVertex;
import org.apache.nemo.common.ir.vertex.LoopVertex;
import org.apache.nemo.common.ir.vertex.OperatorVertex;
import org.apache.nemo.common.ir.vertex.SourceVertex;
import org.apache.nemo.common.ir.vertex.executionproperty.MessageIdVertexProperty;
import org.apache.nemo.common.ir.vertex.executionproperty.ParallelismProperty;
Expand Down Expand Up @@ -799,6 +800,39 @@ public void insert(final TaskSizeSplitterVertex toInsert) {
modifiedDAG = builder.build();
}

/**
 * Inserts the given vertex in the middle of the given edge of the current DAG.
 *
 * The target edge (src to dst) is replaced by two new edges:
 * src to accumulatorVertex with the PARTIAL_SHUFFLE communication pattern, and
 * accumulatorVertex to dst, constructed with the SHUFFLE communication pattern.
 * The execution properties of the target edge are copied onto both new edges.
 * All other vertices and edges are carried over unchanged.
 *
 * @param accumulatorVertex the vertex to insert.
 * @param targetEdge the edge to split; it is matched by reference against the edges of the current DAG.
 * @throws IllegalArgumentException if the target edge is not an edge of the current DAG,
 *                                  in which case the current DAG is left untouched.
 */
public void insert(final OperatorVertex accumulatorVertex, final IREdge targetEdge) {
  // Create a completely new DAG with the vertex inserted.
  final DAGBuilder<IRVertex, IREdge> builder = new DAGBuilder<>();
  // One-element array so that the lambdas below can record whether the target edge was encountered.
  final boolean[] targetEdgeFound = {false};

  builder.addVertex(accumulatorVertex);
  modifiedDAG.topologicalDo(v -> {
    builder.addVertex(v);

    modifiedDAG.getIncomingEdgesOf(v).forEach(e -> {
      if (e == targetEdge) {
        targetEdgeFound[0] = true;

        // Edge to the accumulatorVertex
        final IREdge toAV = new IREdge(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE,
          e.getSrc(), accumulatorVertex);
        e.copyExecutionPropertiesTo(toAV);
        // Re-applied after the copy; presumably the copy carries over the communication pattern
        // of the original edge as well (TODO confirm).
        toAV.setProperty(CommunicationPatternProperty.of(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE));

        // Edge from the accumulatorVertex
        final IREdge fromAV = new IREdge(CommunicationPatternProperty.Value.SHUFFLE, accumulatorVertex, e.getDst());
        e.copyExecutionPropertiesTo(fromAV);

        // Connect the new edges
        builder.connectVertices(toAV);
        builder.connectVertices(fromAV);
      } else {
        // Simply connect vertices as before
        builder.connectVertices(e);
      }
    });
  });

  // Without this check, a missing target edge would leave the accumulator vertex disconnected.
  if (!targetEdgeFound[0]) {
    throw new IllegalArgumentException("The edge to insert the vertex on is not in the DAG: "
      + (targetEdge == null ? "null" : targetEdge.getId()));
  }

  modifiedDAG = builder.build();
}

/**
* Reshape unsafely, without guarantees on preserving application semantics.
* TODO #330: Refactor Unsafe Reshaping Passes
Expand Down
31 changes: 28 additions & 3 deletions common/src/main/java/org/apache/nemo/common/ir/IRDAGChecker.java
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ private IRDAGChecker() {
addLoopVertexCheckers();
addScheduleGroupCheckers();
addCacheCheckers();
addIntermediateAccumulatorVertexCheckers();
}

/**
Expand Down Expand Up @@ -284,23 +285,25 @@ void addShuffleEdgeCheckers() {
final NeighborChecker shuffleChecker = ((v, inEdges, outEdges) -> {
for (final IREdge inEdge : inEdges) {
if (CommunicationPatternProperty.Value.SHUFFLE
.equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())
|| CommunicationPatternProperty.Value.PARTIAL_SHUFFLE
.equals(inEdge.getPropertyValue(CommunicationPatternProperty.class).get())) {
// Shuffle edges must have the following properties
if (!inEdge.getPropertyValue(KeyExtractorProperty.class).isPresent()
|| !inEdge.getPropertyValue(KeyEncoderProperty.class).isPresent()
|| !inEdge.getPropertyValue(KeyDecoderProperty.class).isPresent()) {
return failure("Shuffle edge does not have a Key-related property: " + inEdge.getId());
return failure("(Partial)Shuffle edge does not have a Key-related property: " + inEdge.getId());
}
} else {
// Non-shuffle edges must not have the following properties
final Optional<Pair<PartitionerProperty.Type, Integer>> partitioner =
inEdge.getPropertyValue(PartitionerProperty.class);
if (partitioner.isPresent() && partitioner.get().left().equals(PartitionerProperty.Type.HASH)) {
return failure("Only shuffle can have the hash partitioner",
return failure("Only (partial)shuffle can have the hash partitioner",
inEdge, CommunicationPatternProperty.class, PartitionerProperty.class);
}
if (inEdge.getPropertyValue(PartitionSetProperty.class).isPresent()) {
return failure("Only shuffle can select partition sets",
return failure("Only (partial)shuffle can select partition sets",
inEdge, CommunicationPatternProperty.class, PartitionSetProperty.class);
}
}
Expand Down Expand Up @@ -486,6 +489,28 @@ void addEncodingCompressionCheckers() {
singleEdgeCheckerList.add(compressAndDecompress);
}

/**
 * Registers the neighbor checker for the shuffle executor set property.
 *
 * A vertex that has the shuffle executor set property must have exactly one incoming edge and
 * exactly one outgoing edge, its incoming edge must be a partial shuffle, and its parallelism
 * must not be smaller than the number of shuffle executor sets.
 * A vertex without the property must not have any incoming partial shuffle edge.
 */
void addIntermediateAccumulatorVertexCheckers() {
  final NeighborChecker shuffleExecutorSet = (v, inEdges, outEdges) -> {
    // Case 1: no shuffle executor set property, so no partial shuffle may flow into this vertex.
    if (!v.getPropertyValue(ShuffleExecutorSetProperty.class).isPresent()) {
      final boolean fedByPartialShuffle = inEdges.stream()
        .map(e -> e.getPropertyValue(CommunicationPatternProperty.class).get())
        .anyMatch(pattern -> pattern.equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE));
      if (fedByPartialShuffle) {
        return failure("Intermediate accumulator vertex must have shuffle executor set property", v);
      }
      return success();
    }

    // Case 2: the property is present, so the vertex must look like an intermediate accumulator.
    final boolean accumulatorShaped = inEdges.size() == 1
      && outEdges.size() == 1
      && inEdges.stream().allMatch(e -> e.getPropertyValue(CommunicationPatternProperty.class).get()
        .equals(CommunicationPatternProperty.Value.PARTIAL_SHUFFLE));
    if (!accumulatorShaped) {
      return failure("Only intermediate accumulator vertex can have shuffle executor set property", v);
    }

    final int parallelism = v.getPropertyValue(ParallelismProperty.class).get();
    final int numOfExecutorSets = v.getPropertyValue(ShuffleExecutorSetProperty.class).get().size();
    if (parallelism < numOfExecutorSets) {
      return failure("Parallelism must be greater or equal to the number of shuffle executor set", v);
    }
    return success();
  };
  neighborCheckerList.add(shuffleExecutorSet);
}

/**
* Group outgoing edges by the additional output tag property.
* @param outEdges the outedges to group.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public static CommunicationPatternProperty of(final Value value) {
public enum Value {
  ONE_TO_ONE,
  BROADCAST,
  SHUFFLE,
  // Variant of SHUFFLE used on the edge that feeds an intermediate accumulator vertex.
  PARTIAL_SHUFFLE
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public static ExecutionPropertyMap<EdgeExecutionProperty> of(
map.put(EncoderProperty.of(EncoderFactory.DUMMY_ENCODER_FACTORY));
map.put(DecoderProperty.of(DecoderFactory.DUMMY_DECODER_FACTORY));
switch (commPattern) {
case PARTIAL_SHUFFLE:
case SHUFFLE:
map.put(DataFlowProperty.of(DataFlowProperty.Value.PULL));
map.put(PartitionerProperty.of(PartitionerProperty.Type.HASH));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.nemo.common.ir.vertex.executionproperty;

import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;

import java.util.ArrayList;
import java.util.HashSet;

/**
 * Vertex execution property holding a list of sets of node names.
 * The tasks of the vertex are limited to being scheduled on these nodes while shuffling.
 */
public final class ShuffleExecutorSetProperty extends VertexExecutionProperty<ArrayList<HashSet<String>>> {

  /**
   * Private constructor; use {@link #of(HashSet)} to obtain an instance.
   * @param value value of the execution property.
   */
  private ShuffleExecutorSetProperty(final ArrayList<HashSet<String>> value) {
    super(value);
  }

  /**
   * Static factory for {@link ShuffleExecutorSetProperty}.
   * The given sets are copied into a new list, in the iteration order of the given set.
   *
   * @param setsOfExecutors the sets of executors to schedule the tasks of the vertex on.
   *                        Leave empty for the property to have no effect.
   * @return the new execution property
   */
  public static ShuffleExecutorSetProperty of(final HashSet<HashSet<String>> setsOfExecutors) {
    final ArrayList<HashSet<String>> executorSetList = new ArrayList<>(setsOfExecutors);
    return new ShuffleExecutorSetProperty(executorSetList);
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.nemo.common.ir.vertex.executionproperty;

import org.apache.nemo.common.Pair;
import org.apache.nemo.common.ir.executionproperty.VertexExecutionProperty;

import java.util.HashMap;
import java.util.List;

/**
 * Vertex execution property that keeps track of where tasks are located, by executor ID.
 * The value maps a task index to a list of string pairs.
 * NOTE(review): the meaning of each element of the pair is not visible here; confirm against the callers.
 */
public final class TaskIndexToExecutorIDProperty
  extends VertexExecutionProperty<HashMap<Integer, List<Pair<String, String>>>> {

  /**
   * Private constructor; use {@link #of(HashMap)} to obtain an instance.
   * @param taskIndexToExecutorIDsMap value of the execution property.
   */
  private TaskIndexToExecutorIDProperty(
    final HashMap<Integer, List<Pair<String, String>>> taskIndexToExecutorIDsMap) {
    super(taskIndexToExecutorIDsMap);
  }

  /**
   * Static factory for {@link TaskIndexToExecutorIDProperty}.
   * The given map is stored as is, without copying.
   *
   * @param taskIndexToExecutorIDsMap the map indicating the executor IDs where the tasks are located on.
   * @return the new execution property
   */
  public static TaskIndexToExecutorIDProperty of(
    final HashMap<Integer, List<Pair<String, String>>> taskIndexToExecutorIDsMap) {
    final TaskIndexToExecutorIDProperty property =
      new TaskIndexToExecutorIDProperty(taskIndexToExecutorIDsMap);
    return property;
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,8 @@ public HashPartitioner(final int numOfPartitions,
public Integer partition(final Object element) {
  final int keyHash = keyExtractor.extractKey(element).hashCode();
  // The remainder takes the sign of the hash code, so fold it to a non-negative index.
  final int remainder = keyHash % numOfPartitions;
  return Math.abs(remainder);
}

/**
 * Computes a partition index for the element using the given modulus,
 * rather than the number of partitions this partitioner was constructed with.
 *
 * @param element the element whose key is extracted and hashed.
 * @param numOfSubPartitions the modulus applied to the hash code of the key.
 * @return the absolute value of the key's hash code modulo {@code numOfSubPartitions}.
 */
public Integer partition(final Object element, final int numOfSubPartitions) {
  final int keyHash = keyExtractor.extractKey(element).hashCode();
  // The remainder takes the sign of the hash code, so fold it to a non-negative index.
  final int remainder = keyHash % numOfSubPartitions;
  return Math.abs(remainder);
}
}
9 changes: 9 additions & 0 deletions common/src/test/java/org/apache/nemo/common/ir/IRDAGTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ public void testSplitterVertex() {
mustPass();
}

@Test
public void testAccumulatorVertex() {
  // Build an operator vertex with a shuffle executor set (left empty) and a parallelism of 5.
  final OperatorVertex accumulatorVertex = new OperatorVertex(new EmptyComponents.EmptyTransform("iav"));
  accumulatorVertex.setProperty(ShuffleExecutorSetProperty.of(new HashSet<>()));
  accumulatorVertex.setProperty(ParallelismProperty.of(5));

  // Insert it on the shuffle edge; the resulting DAG must pass the checks.
  irdag.insert(accumulatorVertex, shuffleEdge);
  mustPass();
}

private MessageAggregatorVertex insertNewTriggerVertex(final IRDAG dag, final IREdge edgeToGetStatisticsOf) {
final MessageGeneratorVertex mb = new MessageGeneratorVertex<>((l, r) -> null);
final MessageAggregatorVertex ma = new MessageAggregatorVertex<>(() -> new Object(), (l, r) -> null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ private CommunicationPatternProperty.Value getCommPattern(final IRVertex src, fi
}
// If GBKTransform represents a partial CombinePerKey transformation, we do NOT need to shuffle its input,
// since its output will be shuffled before going through a final CombinePerKey transformation.
if ((dstTransform instanceof GBKTransform && !((GBKTransform) dstTransform).getIsPartialCombining())
if ((dstTransform instanceof CombineTransform && !((CombineTransform) dstTransform).getIsPartialCombining())
|| dstTransform instanceof GroupByKeyTransform) {
return CommunicationPatternProperty.Value.SHUFFLE;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,10 +358,12 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato
}

final CombineFnBase.GlobalCombineFn combineFn = perKey.getFn();
final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline());

final PCollection<?> mainInput = (PCollection<?>) Iterables.getOnlyElement(
TransformInputs.nonAdditionalInputs(beamNode.toAppliedPTransform(ctx.getPipeline())));
TransformInputs.nonAdditionalInputs(pTransform));
final PCollection inputs = (PCollection) Iterables.getOnlyElement(
TransformInputs.nonAdditionalInputs(beamNode.toAppliedPTransform(ctx.getPipeline())));
TransformInputs.nonAdditionalInputs(pTransform));
final KvCoder inputCoder = (KvCoder) inputs.getCoder();
final Coder accumulatorCoder;

Expand All @@ -386,48 +388,52 @@ private static Pipeline.PipelineVisitor.CompositeBehavior combinePerKeyTranslato
finalCombine = new OperatorVertex(new CombineFnFinalTransform<>(combineFn));
} else {
// Stream data processing, using GBKTransform
final AppliedPTransform pTransform = beamNode.toAppliedPTransform(ctx.getPipeline());
final CombineFnBase.GlobalCombineFn partialCombineFn = new PartialCombineFn(
(Combine.CombineFn) combineFn, accumulatorCoder);
final CombineFnBase.GlobalCombineFn intermediateCombineFn = new IntermediateCombineFn(
(Combine.CombineFn) combineFn, accumulatorCoder);
final CombineFnBase.GlobalCombineFn finalCombineFn = new FinalCombineFn(
(Combine.CombineFn) combineFn, accumulatorCoder);

final SystemReduceFn partialSystemReduceFn =
SystemReduceFn.combining(
inputCoder.getKeyCoder(),
AppliedCombineFn.withInputCoder(partialCombineFn,
ctx.getPipeline().getCoderRegistry(), inputCoder,
null,
mainInput.getWindowingStrategy()));
ctx.getPipeline().getCoderRegistry(),
inputCoder,
null, mainInput.getWindowingStrategy()));
final SystemReduceFn intermediateSystemReduceFn =
SystemReduceFn.combining(
inputCoder.getKeyCoder(),
AppliedCombineFn.withInputCoder(intermediateCombineFn,
ctx.getPipeline().getCoderRegistry(),
KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder),
null, mainInput.getWindowingStrategy()));
final SystemReduceFn finalSystemReduceFn =
SystemReduceFn.combining(
inputCoder.getKeyCoder(),
AppliedCombineFn.withInputCoder(finalCombineFn,
ctx.getPipeline().getCoderRegistry(),
KvCoder.of(inputCoder.getKeyCoder(),
accumulatorCoder),
KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder),
null, mainInput.getWindowingStrategy()));
final TupleTag<?> partialMainOutputTag = new TupleTag<>();
final GBKTransform partialCombineStreamTransform =
new GBKTransform(inputCoder,
Collections.singletonMap(partialMainOutputTag, KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder)),
partialMainOutputTag,
mainInput.getWindowingStrategy(),
ctx.getPipelineOptions(),
partialSystemReduceFn,
DoFnSchemaInformation.create(),
DisplayData.from(beamNode.getTransform()),
true);

final GBKTransform finalCombineStreamTransform =
new GBKTransform(KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder),
final CombineTransformFactory combineTransformFactory =
new CombineTransformFactory(inputCoder,
partialMainOutputTag,
KvCoder.of(inputCoder.getKeyCoder(), accumulatorCoder),
getOutputCoders(pTransform),
Iterables.getOnlyElement(beamNode.getOutputs().keySet()),
mainInput.getWindowingStrategy(),
ctx.getPipelineOptions(),
partialSystemReduceFn,
intermediateSystemReduceFn,
finalSystemReduceFn,
DoFnSchemaInformation.create(),
DisplayData.from(beamNode.getTransform()),
false);
DisplayData.from(beamNode.getTransform()));

final CombineTransform partialCombineStreamTransform = combineTransformFactory.getPartialCombineTransform();
final CombineTransform finalCombineStreamTransform = combineTransformFactory.getFinalCombineTransform();

partialCombine = new OperatorVertex(partialCombineStreamTransform);
finalCombine = new OperatorVertex(finalCombineStreamTransform);
Expand Down Expand Up @@ -564,7 +570,7 @@ private static Transform createGBKTransform(
return new GroupByKeyTransform();
} else {
// GroupByKey Transform when using a non-global windowing strategy.
return new GBKTransform<>(
return new CombineTransform<>(
(KvCoder) mainInput.getCoder(),
getOutputCoders(pTransform),
mainOutputTag,
Expand Down
Loading