/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssShuffleHandle;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.RssStageInfo;
import org.apache.spark.shuffle.RssStageResubmitManager;
import org.apache.spark.shuffle.ShuffleHandleInfoManager;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.ShuffleHandleInfo;
import org.apache.spark.shuffle.handle.SimpleShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.uniffle.client.api.ShuffleManagerClient;
import org.apache.uniffle.client.api.ShuffleWriteClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.factory.ShuffleManagerClientFactory;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.request.RssPartitionToShuffleServerRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReassignOnStageRetryResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.PartitionRange;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleAssignmentsInfo;
import org.apache.uniffle.common.ShuffleDataDistributionType;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.config.ConfigOption;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.common.util.ExpiringCloseableSupplier;
import org.apache.uniffle.common.util.RetryUtils;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.com.google.common.collect.Maps;
import org.apache.uniffle.shaded.com.google.common.collect.Sets;
import org.apache.uniffle.shaded.org.apache.commons.collections4.CollectionUtils;
import org.apache.uniffle.shuffle.BlockIdManager;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

public abstract class RssShuffleManagerBase
implements RssShuffleManagerInterface,
ShuffleManager {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class);
    // Flipped once the reflective MapOutputTrackerMaster lookups in unregisterAllMapOutput have been attempted.
    private final AtomicBoolean isInitialized = new AtomicBoolean(false);
    // Version-specific tracker methods resolved via reflection; null when they could not be resolved.
    private Method unregisterAllMapOutputMethod;
    private Method registerShuffleMethod;
    // Created lazily by getBlockIdManager(); volatile because it is read there outside the lock.
    private volatile BlockIdManager blockIdManager;
    protected ShuffleDataDistributionType dataDistributionType;
    // Monitor for the lazy BlockIdManager creation. Final so every thread always locks the same object.
    private final Object blockIdManagerLock = new Object();
    protected AtomicReference<String> id = new AtomicReference<>();
    protected String appId = "";
    protected ShuffleWriteClient shuffleWriteClient;
    protected boolean dynamicConfEnabled;
    protected int maxConcurrencyPerPartitionToWrite;
    protected String clientType;
    protected SparkConf sparkConf;
    // Lazily created by getOrCreateShuffleManagerClientSupplier(); closed by stop() when it is an ExpiringCloseableSupplier.
    protected Supplier<ShuffleManagerClient> managerClientSupplier;
    // Together with shuffleManagerRpcServiceEnabled, selects how getShuffleHandleInfo obtains handle info.
    protected boolean rssStageRetryEnabled;
    protected boolean rssStageRetryForWriteFailureEnabled;
    protected boolean rssStageRetryForFetchFailureEnabled;
    protected ShuffleHandleInfoManager shuffleHandleInfoManager;
    protected RssStageResubmitManager rssStageResubmitManager;
    // Limit handed to checkPartitionReassignServerNum when reassigning on block send failure.
    protected int partitionReassignMaxServerNum;
    protected boolean blockIdSelfManagedEnabled;
    protected boolean partitionReassignEnabled;
    protected boolean shuffleManagerRpcServiceEnabled;

    /**
     * Logs the concrete shuffle manager implementation together with the Uniffle version/revision.
     */
    public RssShuffleManagerBase() {
        String implementationName = getClass().getName();
        LOG.info("Uniffle {} version: {}", implementationName, Constants.VERSION_AND_REVISION_SHORT);
    }

    /**
     * Returns the shared {@link BlockIdManager}, creating it on first use.
     *
     * <p>Double-checked locking: the volatile field is read once without the lock and checked again
     * while holding {@code blockIdManagerLock}, so concurrent first callers create only one instance.
     *
     * @return the (possibly freshly created) block id manager
     */
    @Override
    public BlockIdManager getBlockIdManager() {
        BlockIdManager existing = this.blockIdManager;
        if (existing != null) {
            return existing;
        }
        synchronized (this.blockIdManagerLock) {
            if (this.blockIdManager == null) {
                this.blockIdManager = new BlockIdManager();
                LOG.info("BlockId manager has been initialized.");
            }
            return this.blockIdManager;
        }
    }

    /**
     * Removes the shuffle from the {@link BlockIdManager}, if that manager has been created.
     *
     * @param shuffleId the shuffle being unregistered
     * @return always {@code true}
     */
    public boolean unregisterShuffle(int shuffleId) {
        BlockIdManager manager = this.blockIdManager;
        if (manager == null) {
            return true;
        }
        manager.remove(shuffleId);
        return true;
    }

    /**
     * Configures the block id bit layout from the given Spark and RSS configurations.
     *
     * <p>NOTE(review): implementations presumably delegate to the static
     * {@code configureBlockIdLayout(SparkConf, RssConf, int, boolean)} with version-specific
     * maxFailures/speculation values — confirm in subclasses.
     */
    public abstract void configureBlockIdLayout(SparkConf var1, RssConf var2);

    /**
     * Writes the block id bit layout into both {@code sparkConf} and {@code rssConf}.
     *
     * <p>If {@code RSS_MAX_PARTITIONS} is explicitly set in {@code sparkConf}, the layout is derived
     * from it. Otherwise the explicit bit-width settings are used, falling back to the default
     * max-partitions value when none are given.
     *
     * @param sparkConf Spark configuration, read and updated in place
     * @param rssConf RSS configuration, updated in place
     * @param maxFailures maximum task failures used to size the attempt-number bits
     * @param speculation whether speculative execution is enabled
     */
    @VisibleForTesting
    protected static void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        boolean maxPartitionsGiven = sparkConf.contains(RssSparkConfig.RSS_MAX_PARTITIONS.key());
        if (!maxPartitionsGiven) {
            configureBlockIdLayoutFromLayoutConfig(sparkConf, rssConf, maxFailures, speculation);
            return;
        }
        configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
    }

    /**
     * Derives the block id bit layout from {@code RSS_MAX_PARTITIONS}.
     *
     * <p>The 63 usable bits are split into sequence-number, partition-id and task-attempt-id bits,
     * where task-attempt-id bits = partition-id bits + attempt-number bits. If more than 31 bits would
     * remain for sequence numbers, the surplus (rounded up to an even count) is moved in equal halves
     * to the partition-id and task-attempt-id fields, and {@code RSS_MAX_PARTITIONS} is raised to match.
     * The result is written to {@code rssConf} and, under a {@code spark.} prefix, to {@code sparkConf}.
     *
     * @throws IllegalArgumentException if max partitions is at most 1, or if the task attempt id
     *     would need more than 31 bits
     */
    private static void configureBlockIdLayoutFromMaxPartitions(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        String maxPartitionsKey = RssSparkConfig.RSS_MAX_PARTITIONS.key();
        int defaultMaxPartitions = (Integer) RssSparkConfig.RSS_MAX_PARTITIONS.defaultValue().get();
        int maxPartitions = sparkConf.getInt(maxPartitionsKey, defaultMaxPartitions);
        if (maxPartitions <= 1) {
            throw new IllegalArgumentException("Value of " + maxPartitionsKey + " must be larger than 1: " + maxPartitions);
        }
        int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
        int attemptIdBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
        int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions - 1);
        int taskAttemptIdBits = partitionIdBits + attemptIdBits;
        int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits;
        if (taskAttemptIdBits > 31) {
            int supportedPartitions = 1 << (31 - attemptIdBits);
            String speculationText = speculation ? "true" : "false";
            throw new IllegalArgumentException("Cannot support " + maxPartitionsKey + "=" + maxPartitions
                    + " partitions, as this would require to reserve more than 31 bits in the block id for task attempt ids."
                    + " With spark.maxFailures=" + maxFailures + " and spark.speculation=" + speculationText
                    + " at most " + supportedPartitions + " partitions can be supported.");
        }
        if (sequenceNoBits > 31) {
            int spareBits = sequenceNoBits - 31;
            // Round an odd surplus up so both fields can grow by the same amount.
            if (spareBits % 2 != 0) {
                spareBits++;
            }
            int growth = spareBits / 2;
            taskAttemptIdBits += growth;
            partitionIdBits += growth;
            maxPartitions = 1 << partitionIdBits;
            if (LOG.isInfoEnabled()) {
                LOG.info("Increasing " + maxPartitionsKey + " to " + maxPartitions + ", otherwise we would have to support 2^"
                        + sequenceNoBits + " (more than 2^31) sequence numbers.");
            }
            sequenceNoBits -= spareBits;
            sparkConf.set(maxPartitionsKey, String.valueOf(maxPartitions));
        }
        rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sequenceNoBits);
        rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, partitionIdBits);
        rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, taskAttemptIdBits);
        String sparkPrefix = "spark.";
        sparkConf.set(sparkPrefix + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(), String.valueOf(sequenceNoBits));
        sparkConf.set(sparkPrefix + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(), String.valueOf(partitionIdBits));
        sparkConf.set(sparkPrefix + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(), String.valueOf(taskAttemptIdBits));
    }

    /**
     * Resolves the block id bit layout from explicit bit-width settings.
     *
     * <p>Precedence: the three {@code spark.}-prefixed keys in {@code sparkConf} (copied into
     * {@code rssConf}), then the three options in {@code rssConf} (copied into {@code sparkConf}),
     * and finally the default {@code RSS_MAX_PARTITIONS} value passed through
     * {@code configureBlockIdLayoutFromMaxPartitions}.
     *
     * @throws IllegalArgumentException if only a subset of the three keys is set in either configuration
     */
    private static void configureBlockIdLayoutFromLayoutConfig(SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        String sparkSeqNoBitsKey = "spark." + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key();
        String sparkPartIdBitsKey = "spark." + RssClientConf.BLOCKID_PARTITION_ID_BITS.key();
        String sparkTaskIdBitsKey = "spark." + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key();
        List<String> sparkKeys = Arrays.asList(sparkSeqNoBitsKey, sparkPartIdBitsKey, sparkTaskIdBitsKey);
        boolean anySparkKey = sparkKeys.stream().anyMatch(sparkConf::contains);
        boolean allSparkKeys = sparkKeys.stream().allMatch(sparkConf::contains);
        if (anySparkKey && !allSparkKeys) {
            Set<String> sparkKeySet = new HashSet<>(sparkKeys);
            String existingKeys = Arrays.stream(sparkConf.getAll())
                    .map(t2 -> (String) t2._1)
                    .filter(sparkKeySet::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided (" + String.join(", ", sparkKeys) + "), not just a sub-set: " + existingKeys);
        }

        List<ConfigOption> rssKeys = Arrays.asList(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, RssClientConf.BLOCKID_PARTITION_ID_BITS, RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS);
        boolean anyRssKey = rssKeys.stream().anyMatch(rssConf::contains);
        boolean allRssKeys = rssKeys.stream().allMatch(rssConf::contains);
        if (anyRssKey && !allRssKeys) {
            Set<String> rssKeyNames = rssKeys.stream().map(ConfigOption::key).collect(Collectors.toSet());
            String allKeys = rssKeys.stream().map(ConfigOption::key).collect(Collectors.joining(", "));
            String existingKeys = rssConf.getKeySet().stream().filter(rssKeyNames::contains).collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided (" + allKeys + "), not just a sub-set: " + existingKeys);
        }

        if (allSparkKeys) {
            rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sparkConf.getInt(sparkSeqNoBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, sparkConf.getInt(sparkPartIdBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, sparkConf.getInt(sparkTaskIdBitsKey, 0));
            return;
        }
        if (allRssKeys) {
            sparkConf.set(sparkSeqNoBitsKey, rssConf.getValue(RssClientConf.BLOCKID_SEQUENCE_NO_BITS));
            sparkConf.set(sparkPartIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_PARTITION_ID_BITS));
            sparkConf.set(sparkTaskIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS));
            return;
        }
        // Nothing configured explicitly: fall back to the default max-partitions based layout.
        sparkConf.set(RssSparkConfig.RSS_MAX_PARTITIONS.key(), RssSparkConfig.RSS_MAX_PARTITIONS.defaultValueString());
        configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
    }

    /**
     * Computes the task attempt id to embed in block ids.
     *
     * <p>NOTE(review): parameters are presumably (mapIndex, attemptNo), mirroring the static
     * {@code getTaskAttemptIdForBlockId} overload — confirm in subclasses.
     */
    public abstract long getTaskAttemptIdForBlockId(int var1, int var2);

    /**
     * Packs a map index and attempt number into a task attempt id.
     *
     * <p>The attempt number occupies the low bits, sized to fit the maximum attempt number for
     * {@code maxFailures}/{@code speculation}; the map index is shifted above it.
     *
     * @throws RssException if {@code attemptNo} exceeds the maximum attempt number, or if the packed
     *     value would need more than {@code maxTaskAttemptIdBits} bits
     */
    protected static long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
        int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
        int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
        if (attemptNo > maxAttemptNo) {
            String speculationSuffix = speculation ? " with speculation enabled" : "";
            throw new RssException("Observing attempt number " + attemptNo + " while maxFailures is set to " + maxFailures + speculationSuffix + ".");
        }
        int requiredBits = ClientUtils.getNumberOfSignificantBits(mapIndex) + attemptBits;
        if (requiredBits > maxTaskAttemptIdBits) {
            throw new RssException("Observing mapIndex[" + mapIndex + "] that would produce a taskAttemptId with " + requiredBits
                    + " bits which is larger than the allowed " + maxTaskAttemptIdBits + " bits (maxFailures[" + maxFailures
                    + "], speculation[" + speculation + "]). Please consider providing more bits for taskAttemptIds.");
        }
        long shiftedMapIndex = ((long) mapIndex) << attemptBits;
        return shiftedMapIndex | attemptNo;
    }

    /**
     * Fetches client configuration from the coordinator quorum and applies it to {@code sparkConf}.
     *
     * <p>The fetched configuration is applied only when the coordinator responds with
     * {@link StatusCode#SUCCESS}. The coordinator client is always closed, including when the current
     * user cannot be resolved or the fetch itself throws.
     *
     * @param sparkConf configuration providing coordinator/retry settings; updated in place on success
     * @throws RssException if the current Hadoop user cannot be determined
     */
    protected static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
        String clientType = (String) sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
        String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
        long retryIntervalMs = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_INTERVAL_MAX);
        int retryTimes = (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_RETRY_MAX);
        int heartbeatThread = (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_HEARTBEAT_THREAD_NUM);
        CoordinatorClientFactory coordinatorClientFactory = CoordinatorClientFactory.getInstance();
        CoordinatorGrpcRetryableClient coordinatorClient = coordinatorClientFactory.createCoordinatorClient(
                ClientType.valueOf(clientType), coordinators, retryIntervalMs, retryTimes, heartbeatThread);
        // Everything after client creation can throw; release the client on every path.
        try {
            int timeoutMs = sparkConf.getInt(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(),
                    ((Integer) RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get()).intValue());
            String user;
            try {
                user = UserGroupInformation.getCurrentUser().getShortUserName();
            } catch (Exception e) {
                throw new RssException("Errors on getting current user.", e);
            }
            RssFetchClientConfRequest request = new RssFetchClientConfRequest(timeoutMs, user, Collections.emptyMap());
            RssFetchClientConfResponse response = coordinatorClient.fetchClientConf(request);
            if (response.getStatusCode() == StatusCode.SUCCESS) {
                RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, response.getClientConf());
            }
        } finally {
            coordinatorClient.close();
        }
    }

    /**
     * Clears all map output registered for {@code shuffleId} in the driver's map output tracker.
     *
     * <p>Does nothing unless {@code RssSparkShuffleUtils.isStageResubmitSupported()}. On the first call,
     * the version-specific tracker methods are resolved reflectively and cached. If no
     * {@code unregisterAllMapOutput}-style method was found, the shuffle is unregistered and
     * re-registered through {@code defaultUnregisterAllMapOutput}.
     *
     * <p>NOTE(review): the cached Method fields are not volatile and are written after the CAS, so a
     * concurrent first call may observe {@code null} and take the default path — confirm callers
     * serialize the first invocation.
     *
     * @throws SparkException if the default path is taken without a driver-side tracker
     * @throws RssException if reflective invocation fails
     */
    @Override
    public void unregisterAllMapOutput(int shuffleId) throws SparkException {
        if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
            return;
        }
        MapOutputTrackerMaster tracker = getMapOutputTrackerMaster();
        boolean firstCall = this.isInitialized.compareAndSet(false, true);
        if (firstCall) {
            this.unregisterAllMapOutputMethod = getUnregisterAllMapOutputMethod(tracker);
            this.registerShuffleMethod = getRegisterShuffleMethod(tracker);
        }
        if (this.unregisterAllMapOutputMethod == null) {
            int numMaps = this.getNumMaps(shuffleId);
            int numReduces = this.getPartitionNum(shuffleId);
            defaultUnregisterAllMapOutput(tracker, this.registerShuffleMethod, shuffleId, numMaps, numReduces);
            return;
        }
        try {
            this.unregisterAllMapOutputMethod.invoke(tracker, shuffleId);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RssException("Invoke unregisterAllMapOutput method failed", e);
        }
    }

    /**
     * Fallback used when no version-specific {@code unregisterAllMapOutput} method was resolved.
     *
     * <p>Unregisters the shuffle from the tracker, registers it again with the same map/reduce counts,
     * and increments the tracker epoch.
     *
     * @param tracker the tracker, or {@code null} (e.g. when SparkEnv's tracker is not a master)
     * @param registerShuffle reflectively resolved {@code registerShuffle}, or {@code null}
     * @param shuffleId shuffle to reset
     * @param numMaps number of map tasks to register
     * @param numReduces number of reduce partitions (only passed on Spark 3.2+ and 4.x)
     * @throws SparkException if {@code tracker} or {@code registerShuffle} is {@code null}
     * @throws RssException if reflectively invoking {@code registerShuffle} fails
     */
    private static void defaultUnregisterAllMapOutput(MapOutputTrackerMaster tracker, Method registerShuffle, int shuffleId, int numMaps, int numReduces) throws SparkException {
        if (tracker == null || registerShuffle == null) {
            throw new SparkException("default unregisterAllMapOutput should only be called on the driver side");
        }
        tracker.unregisterShuffle(shuffleId);
        try {
            // Exactly one call: getRegisterShuffleMethod resolves the 3-int overload under the same
            // version condition, and invoking a Method with the wrong arity throws
            // IllegalArgumentException, which is not caught here.
            if (SparkVersionUtils.MAJOR_VERSION > 3 || SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2) {
                registerShuffle.invoke(tracker, shuffleId, numMaps, numReduces);
            } else {
                registerShuffle.invoke(tracker, shuffleId, numMaps);
            }
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RssException("Invoke registerShuffle method failed", e);
        }
        tracker.incrementEpoch();
    }

    /**
     * Looks up the tracker method that drops all map output for a shuffle on the running Spark version.
     *
     * <p>Spark 2.4 and 3.0–3.1 use {@code unregisterAllMapOutput(int)}; Spark 3.2+ uses
     * {@code unregisterAllMapAndMergeOutput(int)}. Spark ≤ 2.3, unknown versions, and lookup failures
     * yield {@code null}, signalling the default fallback.
     *
     * @param tracker the driver-side tracker, may be {@code null}
     * @return the resolved method, or {@code null}
     */
    private static Method getUnregisterAllMapOutputMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> trackerClass = tracker.getClass();
        try {
            if (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION <= 3) {
                LOG.warn("Spark version <= 2.3, fallback to default method");
                return null;
            }
            boolean legacyName = SparkVersionUtils.isSpark2()
                    || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION <= 1);
            if (legacyName) {
                return trackerClass.getDeclaredMethod("unregisterAllMapOutput", Integer.TYPE);
            }
            if (SparkVersionUtils.isSpark3()) {
                return trackerClass.getDeclaredMethod("unregisterAllMapAndMergeOutput", Integer.TYPE);
            }
            LOG.warn("Unknown spark version({}), fallback to default method", SparkVersionUtils.SPARK_VERSION);
            return null;
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get unregisterAllMapOutput method for spark version({})", SparkVersionUtils.SPARK_VERSION);
            return null;
        }
    }

    /**
     * Looks up {@code registerShuffle} on the tracker class with the arity of the running Spark version.
     *
     * <p>Spark 3.2+ and 4.x take three {@code int}s (shuffleId, numMaps, numReduces); older versions take two.
     *
     * @param tracker the driver-side tracker, may be {@code null}
     * @return the resolved method, or {@code null} if {@code tracker} is null or the lookup fails
     */
    private static Method getRegisterShuffleMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        boolean takesNumReduces = SparkVersionUtils.MAJOR_VERSION > 3
                || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2);
        Class<?>[] parameterTypes = takesNumReduces
                ? new Class<?>[] {Integer.TYPE, Integer.TYPE, Integer.TYPE}
                : new Class<?>[] {Integer.TYPE, Integer.TYPE};
        try {
            return tracker.getClass().getDeclaredMethod("registerShuffle", parameterTypes);
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get registerShuffle method for spark version({})", SparkVersionUtils.SPARK_VERSION);
            return null;
        }
    }

    /**
     * Returns the current SparkEnv's tracker if it is a {@link MapOutputTrackerMaster}, else {@code null}.
     */
    private static MapOutputTrackerMaster getMapOutputTrackerMaster() {
        SparkEnv env = SparkEnv.get();
        if (env == null) {
            return null;
        }
        MapOutputTracker tracker = env.mapOutputTracker();
        if (tracker instanceof MapOutputTrackerMaster) {
            return (MapOutputTrackerMaster) tracker;
        }
        return null;
    }

    /**
     * Copies every key/value pair of a Hadoop {@link Configuration} into a new mutable map.
     *
     * @param conf Hadoop configuration to copy
     * @return a new map holding all entries of {@code conf}
     */
    private static Map<String, String> parseRemoteStorageConf(Configuration conf) {
        Map<String, String> confItems = Maps.newHashMap();
        // Configuration is Iterable<Map.Entry<String, String>>; iterate with the typed entry, not a raw one.
        for (Map.Entry<String, String> entry : conf) {
            confItems.put(entry.getKey(), entry.getValue());
        }
        return confItems;
    }

    /**
     * Builds the default {@link RemoteStorageInfo} from the Spark configuration.
     *
     * <p>The storage path comes from {@code RSS_REMOTE_STORAGE_PATH} (empty when unset). If
     * {@code RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED} is set, storage settings are first seeded
     * from the local Hadoop configuration. Every non-null {@code rss.hadoop.*} RSS entry is then added
     * with the prefix stripped, overriding seeded values with the same key.
     *
     * @param sparkConf source configuration
     * @return remote storage path plus its Hadoop configuration items
     */
    protected static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkConf) {
        final String hadoopConfPrefix = "rss.hadoop.";
        // Declared as Map: parseRemoteStorageConf returns a Map, which a HashMap-typed local cannot hold.
        Map<String, String> confItems = Maps.newHashMap();
        RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
        if (rssConf.getBoolean(RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED)) {
            confItems = parseRemoteStorageConf(new Configuration(true));
        }
        for (String key : rssConf.getKeySet()) {
            if (!key.startsWith(hadoopConfPrefix)) {
                continue;
            }
            String val = rssConf.getString(key, null);
            if (val == null) {
                continue;
            }
            // Strip the literal prefix; replaceFirst would compile it as a regex where '.' matches anything.
            confItems.put(key.substring(hadoopConfPrefix.length()), val);
        }
        return new RemoteStorageInfo(sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), confItems);
    }

    /**
     * Resolves the shuffle handle info for {@code rssHandle}.
     *
     * <p>With the shuffle-manager RPC service enabled, the info is fetched from the driver. Stage retry
     * takes precedence over partition reassignment. Otherwise the info is built directly from the
     * handle's partition-to-server mapping and remote storage.
     */
    public ShuffleHandleInfo getShuffleHandleInfo(RssShuffleHandle<?, ?, ?> rssHandle) {
        int shuffleId = rssHandle.getShuffleId();
        if (this.shuffleManagerRpcServiceEnabled) {
            if (this.rssStageRetryEnabled) {
                return this.getRemoteShuffleHandleInfoWithStageRetry(shuffleId);
            }
            if (this.partitionReassignEnabled) {
                return this.getRemoteShuffleHandleInfoWithBlockRetry(shuffleId);
            }
        }
        return new SimpleShuffleHandleInfo(shuffleId, rssHandle.getPartitionToServers(), rssHandle.getRemoteStorage());
    }

    /**
     * Asks the driver's shuffle manager RPC service for the stage-retry-aware handle info of a shuffle.
     */
    protected synchronized StageAttemptShuffleHandleInfo getRemoteShuffleHandleInfoWithStageRetry(int shuffleId) {
        RssPartitionToShuffleServerRequest request = new RssPartitionToShuffleServerRequest(shuffleId);
        ShuffleManagerClient client = this.getOrCreateShuffleManagerClientSupplier().get();
        RssReassignOnStageRetryResponse response = client.getPartitionToShufflerServerWithStageRetry(request);
        return StageAttemptShuffleHandleInfo.fromProto(response.getShuffleHandleInfoProto());
    }

    /**
     * Asks the driver's shuffle manager RPC service for the block-retry (reassign-aware) handle info.
     */
    protected synchronized MutableShuffleHandleInfo getRemoteShuffleHandleInfoWithBlockRetry(int shuffleId) {
        RssPartitionToShuffleServerRequest request = new RssPartitionToShuffleServerRequest(shuffleId);
        ShuffleManagerClient client = this.getOrCreateShuffleManagerClientSupplier().get();
        RssReassignOnBlockSendFailureResponse response = client.getPartitionToShufflerServerWithBlockRetry(request);
        return MutableShuffleHandleInfo.fromProto(response.getHandle());
    }

    /**
     * Returns the cached supplier of gRPC clients to the driver's shuffle manager, creating it on first use.
     *
     * <p>The driver host is read from the {@code driver.host} RSS key (empty if absent); the port and
     * RPC timeout come from {@code SHUFFLE_MANAGER_GRPC_PORT} and {@code RPC_TIMEOUT_MS}.
     */
    protected synchronized Supplier<ShuffleManagerClient> getOrCreateShuffleManagerClientSupplier() {
        if (this.managerClientSupplier != null) {
            return this.managerClientSupplier;
        }
        RssConf rssConf = RssSparkConfig.toRssConf(this.sparkConf);
        String driverHost = rssConf.getString("driver.host", "");
        int grpcPort = rssConf.get(RssClientConf.SHUFFLE_MANAGER_GRPC_PORT);
        long timeoutMs = rssConf.getLong(RssClientConf.RPC_TIMEOUT_MS);
        this.managerClientSupplier = ExpiringCloseableSupplier.of(
                () -> ShuffleManagerClientFactory.getInstance().createShuffleManagerClient(ClientType.GRPC, driverHost, grpcPort, timeoutMs));
        return this.managerClientSupplier;
    }

    /**
     * Looks up the locally tracked handle info for {@code shuffleId} in the handle info manager.
     */
    @Override
    public ShuffleHandleInfo getShuffleHandleInfoByShuffleId(int shuffleId) {
        ShuffleHandleInfoManager manager = this.shuffleHandleInfoManager;
        return manager.get(shuffleId);
    }

    /**
     * Returns the number of fetch failures tolerated: {@code spark.task.maxFailures} (default 4)
     * minus one, clamped at zero.
     */
    @Override
    public int getMaxFetchFailures() {
        // Single source for the key; previously declared as a local but never used.
        final String taskMaxFailuresKey = "spark.task.maxFailures";
        int taskMaxFailures = this.sparkConf.getInt(taskMaxFailuresKey, 4);
        // Clamp so a misconfigured value (<= 0) never yields a negative count.
        return Math.max(0, taskMaxFailures - 1);
    }

    /**
     * Records {@code shuffleServerId} as a failing shuffle server in the stage resubmit manager.
     */
    @Override
    public void addFailuresShuffleServerInfos(String shuffleServerId) {
        RssStageResubmitManager resubmitManager = this.rssStageResubmitManager;
        resubmitManager.recordFailuresShuffleServer(shuffleServerId);
    }

    /**
     * Reassigns shuffle servers for a resubmitted stage attempt.
     *
     * <p>While holding the {@link RssStageInfo} for this stage/attempt, this method:
     * <ul>
     *   <li>skips everything if {@code isReassigned()} is already true;</li>
     *   <li>otherwise requests a new assignment (passing the resubmit manager's server blacklist);</li>
     *   <li>clears the map output for the shuffle;</li>
     *   <li>swaps the new assignment into the stored {@link StageAttemptShuffleHandleInfo};</li>
     *   <li>records the attempt via {@code recordAndGetServerAssignedInfo(..., true)}.</li>
     * </ul>
     *
     * @return {@code true} if this call performed the reassignment, {@code false} if it was already done
     * @throws RssException if clearing the map output tracker fails
     */
    @Override
    public boolean reassignOnStageResubmit(int stageId, int stageAttemptNumber, int shuffleId, int numPartitions) {
        String stageIdAndAttempt = stageId + "_" + stageAttemptNumber;
        RssStageInfo rssStageInfo = this.rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, stageIdAndAttempt);
        synchronized (rssStageInfo) {
            boolean alreadyReassigned = rssStageInfo.isReassigned();
            if (alreadyReassigned) {
                LOG.info("Do nothing that the stage: {} has been reassigned for attempt{}", stageId, stageAttemptNumber);
                return false;
            }
            int requiredShuffleServerNumber = RssSparkShuffleUtils.getRequiredShuffleServerNumber(this.sparkConf);
            int estimateTaskConcurrency = RssSparkShuffleUtils.estimateTaskConcurrency(this.sparkConf);
            Map<Integer, List<ShuffleServerInfo>> partitionToServers = this.requestShuffleAssignment(
                    shuffleId, numPartitions, 1, requiredShuffleServerNumber, estimateTaskConcurrency,
                    this.rssStageResubmitManager.getServerIdBlackList(), stageId, stageAttemptNumber, false);
            try {
                this.unregisterAllMapOutput(shuffleId);
            } catch (SparkException e) {
                LOG.error("Clear MapoutTracker Meta failed!");
                throw new RssException("Clear MapoutTracker Meta failed!", e);
            }
            MutableShuffleHandleInfo newHandleInfo = new MutableShuffleHandleInfo(shuffleId, partitionToServers, this.getRemoteStorageInfo());
            StageAttemptShuffleHandleInfo stageAttemptHandleInfo = (StageAttemptShuffleHandleInfo) this.shuffleHandleInfoManager.get(shuffleId);
            stageAttemptHandleInfo.replaceCurrentShuffleHandleInfo(newHandleInfo);
            this.rssStageResubmitManager.recordAndGetServerAssignedInfo(shuffleId, stageIdAndAttempt, true);
            LOG.info("The stage retry has been triggered successfully for the stageId: {}, attemptNumber: {}", stageId, stageAttemptNumber);
            return true;
        }
    }

    /**
     * Reassigns replacement shuffle servers for partitions whose block sends failed.
     *
     * <p>Works on the {@link MutableShuffleHandleInfo} for the shuffle: either the stored handle itself,
     * or the current handle of a {@link StageAttemptShuffleHandleInfo}. While holding that handle's
     * monitor, each failed (partition, server) pair is resolved to replacement servers:
     * <ul>
     *   <li>servers already recorded as replacements are reused;</li>
     *   <li>otherwise new ones are requested via {@code requestReassignServer}.</li>
     * </ul>
     * The handle is then updated. Reused replacements are (re-)registered for the affected partitions.
     *
     * <p>NOTE(review): freshly requested replacements are presumably registered by the assignment
     * request itself (it is issued with {@code reassign=true}); confirm in requestShuffleAssignment.
     *
     * @param partitionSplit when true, replacements are tracked per partition instead of per server
     * @return the updated handle
     * @throws RssException if no mutable handle exists for {@code shuffleId}
     */
    @Override
    public MutableShuffleHandleInfo reassignOnBlockSendFailure(int stageId, int stageAttemptNumber, int shuffleId, Map<Integer, List<ReceivingFailureServer>> partitionToFailureServers, boolean partitionSplit) {
        long startTime = System.currentTimeMillis();
        ShuffleHandleInfo handleInfo = this.shuffleHandleInfoManager.get(shuffleId);
        MutableShuffleHandleInfo internalHandle = null;
        if (handleInfo instanceof MutableShuffleHandleInfo) {
            internalHandle = (MutableShuffleHandleInfo) handleInfo;
        } else if (handleInfo instanceof StageAttemptShuffleHandleInfo) {
            internalHandle = (MutableShuffleHandleInfo) ((StageAttemptShuffleHandleInfo) handleInfo).getCurrent();
        }
        if (internalHandle == null) {
            throw new RssException("An unexpected error occurred: internalHandle is null, which should not happen");
        }
        synchronized (internalHandle) {
            if (!partitionSplit) {
                internalHandle.checkPartitionReassignServerNum(partitionToFailureServers.keySet(), this.partitionReassignMaxServerNum);
            }
            HashMap<ShuffleServerInfo, List<PartitionRange>> newServerToPartitions = new HashMap<>();
            // failed serverId -> partitionId -> ids of the servers now assigned instead (for logging)
            Map<String, Map<Integer, Set<String>>> reassignResult = new HashMap<>();
            for (Map.Entry<Integer, List<ReceivingFailureServer>> entry : partitionToFailureServers.entrySet()) {
                int partitionId = entry.getKey();
                for (ReceivingFailureServer receivingFailureServer : entry.getValue()) {
                    String serverId = receivingFailureServer.getServerId();
                    boolean serverHasReplaced = false;
                    Set<ShuffleServerInfo> replacements;
                    Set<ShuffleServerInfo> updatedReassignServers;
                    if (!partitionSplit) {
                        replacements = internalHandle.getReplacements(serverId);
                        if (CollectionUtils.isEmpty(replacements)) {
                            replacements = this.requestReassignServer(stageId, stageAttemptNumber, shuffleId, internalHandle, partitionId, serverId);
                        } else {
                            serverHasReplaced = true;
                        }
                        updatedReassignServers = internalHandle.updateAssignment(partitionId, serverId, replacements);
                    } else {
                        replacements = internalHandle.getReplacementsForPartition(partitionId, serverId);
                        if (CollectionUtils.isEmpty(replacements)) {
                            replacements = this.requestReassignServer(stageId, stageAttemptNumber, shuffleId, internalHandle, partitionId, serverId);
                        } else {
                            serverHasReplaced = true;
                        }
                        updatedReassignServers = internalHandle.updateAssignmentOnPartitionSplit(partitionId, serverId, replacements);
                    }
                    if (updatedReassignServers.isEmpty()) {
                        continue;
                    }
                    Set<String> updatedServerIds = updatedReassignServers.stream().map(x -> x.getId()).collect(Collectors.toSet());
                    reassignResult.computeIfAbsent(serverId, x -> new HashMap<>())
                            .computeIfAbsent(partitionId, x -> new HashSet<>())
                            .addAll(updatedServerIds);
                    if (!serverHasReplaced) {
                        continue;
                    }
                    for (ShuffleServerInfo serverInfo : updatedReassignServers) {
                        newServerToPartitions.computeIfAbsent(serverInfo, x -> new ArrayList<>()).add(new PartitionRange(partitionId, partitionId));
                    }
                }
            }
            if (!newServerToPartitions.isEmpty()) {
                LOG.info("Register the new partition->servers assignment on reassign. {}", newServerToPartitions);
                this.registerShuffleServers(this.id.get(), shuffleId, newServerToPartitions, this.getRemoteStorageInfo());
            }
            LOG.info("Finished reassignOnBlockSendFailure request and cost {}(ms). Reassign result: {}", System.currentTimeMillis() - startTime, reassignResult);
            return internalHandle;
        }
    }

    /**
     * Requests one replacement server for {@code partitionId} to stand in for {@code serverId}.
     *
     * <p>The request excludes:
     * <ul>
     *   <li>the failed server itself;</li>
     *   <li>servers the handle has already excluded globally;</li>
     *   <li>servers excluded for this partition.</li>
     * </ul>
     *
     * @return the replacement servers chosen (empty if none were assigned)
     */
    private Set<ShuffleServerInfo> requestReassignServer(int stageId, int stageAttemptNumber, int shuffleId, MutableShuffleHandleInfo internalHandle, int partitionId, String serverId) {
        // The decompiled body left an unused `boolean requiredServerNum = true` next to a literal 1.
        final int requiredServerNum = 1;
        Set<String> excludedServers = new HashSet<String>(internalHandle.listExcludedServers());
        excludedServers.addAll(internalHandle.listExcludedServersForPartition(partitionId));
        excludedServers.add(serverId);
        return this.reassignServerForTask(stageId, stageAttemptNumber, shuffleId, Sets.newHashSet(partitionId), excludedServers, requiredServerNum, true);
    }

    /**
     * Closes the shuffle-manager client supplier if one was created as an {@link ExpiringCloseableSupplier}.
     */
    public void stop() {
        // instanceof is false for null, so no separate null check is needed.
        if (!(this.managerClientSupplier instanceof ExpiringCloseableSupplier)) {
            return;
        }
        ExpiringCloseableSupplier closeableSupplier = (ExpiringCloseableSupplier) this.managerClientSupplier;
        closeableSupplier.close();
    }

    /**
     * Builds an assignment that places every partition in {@code partitionIds} on all of {@code servers}.
     *
     * <p>Each partition gets its own copy of the server list. All servers share one list of
     * single-partition ranges.
     */
    private ShuffleAssignmentsInfo createShuffleAssignmentsInfo(Set<ShuffleServerInfo> servers, Set<Integer> partitionIds) {
        Map<Integer, List<ShuffleServerInfo>> partitionToServers = new HashMap<>();
        List<PartitionRange> ranges = new ArrayList<>();
        partitionIds.forEach(pid -> {
            partitionToServers.put(pid, new ArrayList<>(servers));
            ranges.add(new PartitionRange(pid, pid));
        });
        Map<ShuffleServerInfo, List<PartitionRange>> serverToRanges = new HashMap<>();
        servers.forEach(server -> serverToRanges.put(server, ranges));
        return new ShuffleAssignmentsInfo(partitionToServers, serverToRanges);
    }

    /**
     * Requests {@code requiredServerNum} fresh servers from the coordinator.
     * It then registers them for exactly the given partitions.
     *
     * @param stageId            stage issuing the request
     * @param stageAttemptNumber attempt number of that stage
     * @param shuffleId          shuffle being reassigned
     * @param partitionIds       partitions that should be served by the new servers
     * @param excludedServers    server ids that must not be chosen; note the callee also adds
     *                           blacklisted server ids to this set
     * @param requiredServerNum  number of servers to request
     * @param reassign           forwarded to the coordinator request as the reassign flag
     * @return the servers chosen by the coordinator; an empty set if the handler never recorded any
     */
    private Set<ShuffleServerInfo> reassignServerForTask(int stageId, int stageAttemptNumber, int shuffleId, Set<Integer> partitionIds, Set<String> excludedServers, int requiredServerNum, boolean reassign) {
        // Typed holder so the handler lambda can hand the chosen servers back to this method.
        AtomicReference<Set<ShuffleServerInfo>> replacementsRef = new AtomicReference<Set<ShuffleServerInfo>>(new HashSet<ShuffleServerInfo>());
        this.requestShuffleAssignment(shuffleId, requiredServerNum, 1, requiredServerNum, 1, excludedServers, shuffleAssignmentsInfo -> {
            if (shuffleAssignmentsInfo == null) {
                return null;
            }
            Set<ShuffleServerInfo> replacements = shuffleAssignmentsInfo.getPartitionToServers().values().stream().flatMap(List::stream).collect(Collectors.toSet());
            replacementsRef.set(replacements);
            // Re-map the chosen servers onto the partitions that actually need them.
            // The registration step then uses this mapping instead of the coordinator's placeholder ranges.
            return this.createShuffleAssignmentsInfo(replacements, partitionIds);
        }, stageId, stageAttemptNumber, reassign);
        return replacementsRef.get();
    }

    /**
     * Fetches a shuffle assignment from the coordinator and optionally post-processes it.
     * It then registers the shuffle on the resulting servers.
     *
     * Side effect: the blacklisted server ids from the stage-resubmit manager are added to
     * {@code faultyServerIds}, which mutates the caller's set.
     *
     * @param reassignmentHandler optional transform applied to the coordinator's response before
     *                            registration; may be {@code null}
     * @return the partition-to-servers view of the (possibly transformed) assignment
     * @throws RssException if the assignment could not be obtained or registration failed
     */
    private Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(int shuffleId, int partitionNum, int partitionNumPerRange, int assignmentShuffleServerNumber, int estimateTaskConcurrency, Set<String> faultyServerIds, Function<ShuffleAssignmentsInfo, ShuffleAssignmentsInfo> reassignmentHandler, int stageId, int stageAttemptNumber, boolean reassign) {
        Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(this.sparkConf);
        ClientUtils.validateClientType(this.clientType);
        assignmentTags.add(this.clientType);
        long retryInterval = (Long)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
        int retryTimes = (Integer)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
        faultyServerIds.addAll(this.rssStageResubmitManager.getServerIdBlackList());
        try {
            ShuffleAssignmentsInfo response = this.shuffleWriteClient.getShuffleAssignments(this.id.get(), shuffleId, partitionNum, partitionNumPerRange, assignmentTags, assignmentShuffleServerNumber, estimateTaskConcurrency, faultyServerIds, stageId, stageAttemptNumber, reassign, retryInterval, retryTimes);
            LOG.info("Finished reassign");
            if (reassignmentHandler != null) {
                response = reassignmentHandler.apply(response);
            }
            // Fail with a descriptive cause instead of an NPE when there is no assignment. This covers
            // a null coordinator response and a handler that maps it to null.
            if (response == null) {
                throw new RssException("No shuffle assignment available for shuffleId " + shuffleId);
            }
            this.registerShuffleServers(this.id.get(), shuffleId, response.getServerToPartitionRanges(), this.getRemoteStorageInfo());
            return response.getPartitionToServers();
        }
        catch (Throwable throwable) {
            throw new RssException("registerShuffle failed!", throwable);
        }
    }

    /**
     * Requests a shuffle assignment from the coordinator and registers the shuffle on every assigned server.
     * The whole fetch-and-register round is retried through {@code RetryUtils.retry}, using the
     * configured assignment retry interval and retry count.
     *
     * Side effect: the stage-resubmit manager's blacklisted server ids are added to
     * {@code faultyServerIds}, which mutates the caller's set.
     *
     * @return the partition-to-servers view of the assignment
     * @throws RssException if every attempt fails
     */
    protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(int shuffleId, int partitionNum, int partitionNumPerRange, int assignmentShuffleServerNumber, int estimateTaskConcurrency, Set<String> faultyServerIds, int stageId, int stageAttemptNumber, boolean reassign) {
        Set<String> tags = RssSparkShuffleUtils.getAssignmentTags(this.sparkConf);
        ClientUtils.validateClientType(this.clientType);
        tags.add(this.clientType);
        long intervalMs = (Long)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_INTERVAL);
        int maxRetries = (Integer)this.sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_RETRY_TIMES);
        faultyServerIds.addAll(this.rssStageResubmitManager.getServerIdBlackList());
        try {
            return RetryUtils.retry(() -> {
                // The client call receives 0L / 0 for its own retry settings. Presumably this leaves
                // retrying to the surrounding RetryUtils.retry — TODO confirm client semantics.
                ShuffleAssignmentsInfo assignment = this.shuffleWriteClient.getShuffleAssignments(this.appId, shuffleId, partitionNum, partitionNumPerRange, tags, assignmentShuffleServerNumber, estimateTaskConcurrency, faultyServerIds, stageId, stageAttemptNumber, reassign, 0L, 0);
                this.registerShuffleServers(this.appId, shuffleId, assignment.getServerToPartitionRanges(), this.getRemoteStorageInfo(), stageAttemptNumber);
                return assignment.getPartitionToServers();
            }, intervalMs, maxRetries);
        }
        catch (Throwable cause) {
            throw new RssException("getShuffleAssignments or registerShuffle failed!", cause);
        }
    }

    /**
     * Convenience overload for a non-reassign assignment request.
     * It delegates to the retrying overload with {@code -1} as a placeholder stage id and the
     * reassign flag set to {@code false}.
     *
     * @return the partition-to-servers view of the assignment
     */
    protected Map<Integer, List<ShuffleServerInfo>> requestShuffleAssignment(int shuffleId, int partitionNum, int partitionNumPerRange, int assignmentShuffleServerNumber, int estimateTaskConcurrency, Set<String> faultyServerIds, int stageAttemptNumber) {
        final int placeholderStageId = -1;
        final boolean isReassign = false;
        return requestShuffleAssignment(shuffleId, partitionNum, partitionNumPerRange, assignmentShuffleServerNumber, estimateTaskConcurrency, faultyServerIds, placeholderStageId, stageAttemptNumber, isReassign);
    }

    /**
     * Registers the shuffle on every server in {@code serverToPartitionRanges}, tagged with the given stage attempt.
     * Registration always uses {@link ShuffleDataDistributionType#NORMAL} and passes the current
     * Spark configuration as a string map. A {@code null} or empty map is a no-op.
     *
     * NOTE(review): unlike the 4-argument overload, this one ignores the configured
     * {@code dataDistributionType}. Confirm whether that is intentional.
     */
    protected void registerShuffleServers(String appId, int shuffleId, Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges, RemoteStorageInfo remoteStorage, int stageAttemptNumber) {
        boolean nothingToRegister = serverToPartitionRanges == null || serverToPartitionRanges.isEmpty();
        if (nothingToRegister) {
            return;
        }
        LOG.info("Start to register shuffleId {}", (Object)shuffleId);
        long startMs = System.currentTimeMillis();
        Map<String, String> confSnapshot = this.sparkConfToMap(this.getSparkConf());
        for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> serverAndRanges : serverToPartitionRanges.entrySet()) {
            this.shuffleWriteClient.registerShuffle(serverAndRanges.getKey(), appId, shuffleId, serverAndRanges.getValue(), remoteStorage, ShuffleDataDistributionType.NORMAL, this.maxConcurrencyPerPartitionToWrite, stageAttemptNumber, null, confSnapshot);
        }
        long elapsedMs = System.currentTimeMillis() - startMs;
        LOG.info("Finish register shuffleId {} with {} ms", (Object)shuffleId, (Object)elapsedMs);
    }

    /**
     * Registers the shuffle on every server in {@code serverToPartitionRanges}.
     * It uses this manager's configured data distribution type and passes the current Spark
     * configuration as a string map. A {@code null} or empty map is a no-op.
     */
    @VisibleForTesting
    protected void registerShuffleServers(String appId, int shuffleId, Map<ShuffleServerInfo, List<PartitionRange>> serverToPartitionRanges, RemoteStorageInfo remoteStorage) {
        boolean nothingToRegister = serverToPartitionRanges == null || serverToPartitionRanges.isEmpty();
        if (nothingToRegister) {
            return;
        }
        LOG.info("Start to register shuffleId[{}]", (Object)shuffleId);
        long startMs = System.currentTimeMillis();
        Map<String, String> confSnapshot = this.sparkConfToMap(this.getSparkConf());
        for (Map.Entry<ShuffleServerInfo, List<PartitionRange>> serverAndRanges : serverToPartitionRanges.entrySet()) {
            this.shuffleWriteClient.registerShuffle(serverAndRanges.getKey(), appId, shuffleId, serverAndRanges.getValue(), remoteStorage, this.dataDistributionType, this.maxConcurrencyPerPartitionToWrite, confSnapshot);
        }
        long elapsedMs = System.currentTimeMillis() - startMs;
        LOG.info("Finish register shuffleId[{}] with {} ms", (Object)shuffleId, (Object)elapsedMs);
    }

    /**
     * Resolves the remote storage for this application via {@code ClientUtils.fetchRemoteStorage}.
     * It reads the storage type and the default remote storage path from the Spark conf.
     * The path defaults to the empty string when unset.
     */
    protected RemoteStorageInfo getRemoteStorageInfo() {
        SparkConf conf = this.sparkConf;
        // Read order is kept: the storage type first, then the default path.
        String storageType = conf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
        String defaultPath = conf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), "");
        RemoteStorageInfo fallback = new RemoteStorageInfo(defaultPath);
        return ClientUtils.fetchRemoteStorage(this.appId, fallback, this.dynamicConfEnabled, storageType, this.shuffleWriteClient);
    }

    /** Returns the stored {@code rssStageRetryEnabled} flag. */
    public boolean isRssStageRetryEnabled() {
        final boolean enabled = rssStageRetryEnabled;
        return enabled;
    }

    /** Returns the stored {@code rssStageRetryForWriteFailureEnabled} flag. */
    public boolean isRssStageRetryForWriteFailureEnabled() {
        final boolean enabled = rssStageRetryForWriteFailureEnabled;
        return enabled;
    }

    /** Returns the stored {@code rssStageRetryForFetchFailureEnabled} flag. */
    public boolean isRssStageRetryForFetchFailureEnabled() {
        final boolean enabled = rssStageRetryForFetchFailureEnabled;
        return enabled;
    }

    /** Returns the {@link SparkConf} this manager holds; the same instance, not a copy. */
    @VisibleForTesting
    public SparkConf getSparkConf() {
        final SparkConf conf = sparkConf;
        return conf;
    }

    /**
     * Copies every key/value pair of {@code sparkConf} into a new mutable {@link HashMap}.
     * If {@code getAll()} ever yields a duplicate key, the later pair wins.
     *
     * @param sparkConf configuration to snapshot
     * @return a fresh map of all configuration entries
     */
    public Map<String, String> sparkConfToMap(SparkConf sparkConf) {
        Map<String, String> map = new HashMap<String, String>();
        // Typed Tuple2 avoids raw-type iteration and the unchecked casts on _1/_2.
        for (Tuple2<String, String> entry : sparkConf.getAll()) {
            map.put(entry._1(), entry._2());
        }
        return map;
    }
}

