diff --git a/service/src/main/scala/com/bytedance/css/service/deploy/master/ShuffleTaskManager.scala b/service/src/main/scala/com/bytedance/css/service/deploy/master/ShuffleTaskManager.scala index 5dbfac4..4f43df8 100644 --- a/service/src/main/scala/com/bytedance/css/service/deploy/master/ShuffleTaskManager.scala +++ b/service/src/main/scala/com/bytedance/css/service/deploy/master/ShuffleTaskManager.scala @@ -39,7 +39,7 @@ class ShuffleTaskManager extends Logging { // key appId-shuffleId // store all reducerId-EpochId for a shuffleKey private[deploy] val shuffleEpochSetMap = - new ConcurrentHashMap[String, ConcurrentHashMap[Int, util.List[PartitionInfo]]]() + new ConcurrentHashMap[String, util.HashSet[PartitionInfo]]() private[deploy] val batchBlacklistMap = new ConcurrentHashMap[String, ConcurrentHashMap[Int, util.List[FailedPartitionInfoBatch]]]() @@ -48,7 +48,7 @@ class ShuffleTaskManager extends Logging { numMappers: Int, numPartitions: Int): Unit = { shuffleMapperAttempts.putIfAbsent(shuffleKey, Array.fill(numMappers)(-1)) - shuffleEpochSetMap.putIfAbsent(shuffleKey, new ConcurrentHashMap[Int, util.List[PartitionInfo]]()) + shuffleEpochSetMap.putIfAbsent(shuffleKey, new util.HashSet[PartitionInfo]())) reducerFileGroupsMap.putIfAbsent(shuffleKey, new Array[Array[CommittedPartitionInfo]](numPartitions)) } @@ -71,9 +71,9 @@ class ShuffleTaskManager extends Logging { // epochList not be null. val epochSet = shuffleEpochSetMap.computeWhenAbsent(shuffleKey, _ => { - new ConcurrentHashMap[Int, util.List[PartitionInfo]]() + new util.HashSet[PartitionInfo]() }) - epochSet.put(mapId, epochList) + epochSet.addAll(epochList) if (batchBlacklist != null) { val blacklist = batchBlacklistMap.computeWhenAbsent(shuffleKey, _ => { @@ -112,10 +112,7 @@ class ShuffleTaskManager extends Logging { commitPieces: ConcurrentHashMap[String, ConcurrentSet[CommittedPartitionInfo]]): Boolean = { // check for data lost - val allEpochSets = new util.HashSet[PartitionInfo]() - allEpochSets.addAll( - shuffleEpochSetMap.getOrDefault(shuffleKey, new ConcurrentHashMap[Int, util.List[PartitionInfo]]()) - .values().asScala.flatMap(x => x.asScala.toSet[PartitionInfo]).toSet.asJava) + val allEpochSets = shuffleEpochSetMap.getOrDefault(shuffleKey, new util.HashSet[PartitionInfo]()) var dataLost = false val validCommitted = allEpochSets.asScala