package org.apache.uniffle.shuffle.manager;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.MapOutputTracker;
import org.apache.spark.MapOutputTrackerMaster;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.SparkException;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.SparkVersionUtils;
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.request.RssFetchClientConfRequest;
import org.apache.uniffle.client.response.RssFetchClientConfResponse;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.config.RssClientConf;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.shaded.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.shaded.com.google.common.collect.Maps;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/uniffle/shuffle/manager/RssShuffleManagerBase.class */
/**
 * Common base class for Uniffle's Spark shuffle managers.
 *
 * <p>Provides the block id bit-layout configuration, fetching of the dynamic client configuration
 * from the coordinators, reflective access to Spark's {@link MapOutputTrackerMaster} used to drop
 * all map outputs of a shuffle, and resolution of the default remote storage.
 */
public abstract class RssShuffleManagerBase implements RssShuffleManagerInterface, ShuffleManager {
    private static final Logger LOG = LoggerFactory.getLogger(RssShuffleManagerBase.class);

    /** Becomes true only after the two reflective tracker methods below have been assigned. */
    private final AtomicBoolean isInitialized = new AtomicBoolean(false);

    /** Serializes the one-time resolution of the reflective tracker methods. */
    private final Object trackerMethodsLock = new Object();

    // Written once under trackerMethodsLock before isInitialized is set to true; the volatile
    // semantics of the AtomicBoolean make these writes visible to any caller that observes true.
    private Method unregisterAllMapOutputMethod;
    private Method registerShuffleMethod;

    /**
     * Configures the block id bit layout in the given Spark and RSS configurations. Implemented by
     * subclasses.
     *
     * @param sparkConf the Spark configuration
     * @param rssConf the RSS configuration
     */
    public abstract void configureBlockIdLayout(SparkConf sparkConf, RssConf rssConf);

    /**
     * Configures the block id bit layout (sequence number, partition id and task attempt id bits)
     * in both {@code sparkConf} and {@code rssConf}.
     *
     * <p>If {@code RSS_MAX_PARTITIONS} is set in {@code sparkConf}, the layout is derived from it;
     * otherwise explicitly configured bit widths are used, falling back to the default maximum
     * number of partitions when none are configured.
     *
     * @param sparkConf the Spark configuration, read and updated
     * @param rssConf the RSS configuration, read and updated
     * @param maxFailures Spark's task max failures, used to size the attempt number bits
     * @param speculation whether speculative execution is enabled
     * @throws IllegalArgumentException if the configuration is invalid or cannot be supported
     */
    @VisibleForTesting
    public static void configureBlockIdLayout(
            SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        if (sparkConf.contains(RssSparkConfig.RSS_MAX_PARTITIONS.key())) {
            configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
        } else {
            configureBlockIdLayoutFromLayoutConfig(sparkConf, rssConf, maxFailures, speculation);
        }
    }

    /**
     * Derives the block id layout from {@code RSS_MAX_PARTITIONS} and writes the three bit widths
     * into {@code rssConf} and (with the Spark RSS prefix) into {@code sparkConf}.
     *
     * <p>If the remaining sequence number bits would exceed 31, the surplus is split evenly between
     * partition id and task attempt id bits and {@code RSS_MAX_PARTITIONS} is raised in
     * {@code sparkConf} accordingly.
     */
    private static void configureBlockIdLayoutFromMaxPartitions(
            SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        String maxPartitionsKey = RssSparkConfig.RSS_MAX_PARTITIONS.key();
        int maxPartitions = sparkConf.getInt(
                maxPartitionsKey,
                ((Integer) RssSparkConfig.RSS_MAX_PARTITIONS.defaultValue().get()).intValue());
        if (maxPartitions <= 1) {
            throw new IllegalArgumentException(
                    "Value of " + maxPartitionsKey + " must be larger than 1: " + maxPartitions);
        }

        int attemptNoBits = ClientUtils.getNumberOfSignificantBits(
                ClientUtils.getMaxAttemptNo(maxFailures, speculation));
        int partitionIdBits = ClientUtils.getNumberOfSignificantBits(maxPartitions - 1);
        // Task attempt ids are built as (mapIndex << attemptNoBits) | attemptNo
        // (see getTaskAttemptIdForBlockId), so they need the partition bits plus the attempt bits.
        int taskAttemptIdBits = partitionIdBits + attemptNoBits;
        int sequenceNoBits = 63 - partitionIdBits - taskAttemptIdBits;

        if (taskAttemptIdBits > 31) {
            throw new IllegalArgumentException("Cannot support " + maxPartitionsKey
                    + Constants.EQUAL_SPLIT_CHAR + maxPartitions
                    + " partitions, as this would require to reserve more than 31 bits in the block id"
                    + " for task attempt ids. With spark.maxFailures=" + maxFailures
                    + " and spark.speculation=" + (speculation ? "true" : "false")
                    + " at most " + (1 << (31 - attemptNoBits)) + " partitions can be supported.");
        }

        if (sequenceNoBits > 31) {
            // Cap sequence numbers at 31 bits: round the surplus up to an even number and hand it
            // in equal halves to the partition id and task attempt id fields.
            int surplusBits = sequenceNoBits - 31;
            int evenSurplusBits = surplusBits + (surplusBits % 2);
            partitionIdBits += evenSurplusBits / 2;
            taskAttemptIdBits += evenSurplusBits / 2;
            int increasedMaxPartitions = 1 << partitionIdBits;
            LOG.info(
                    "Increasing {} to {}, otherwise we would have to support 2^{} (more than 2^31) sequence numbers.",
                    maxPartitionsKey, increasedMaxPartitions, sequenceNoBits);
            sequenceNoBits -= evenSurplusBits;
            sparkConf.set(maxPartitionsKey, String.valueOf(increasedMaxPartitions));
        }

        rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sequenceNoBits);
        rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, partitionIdBits);
        rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, taskAttemptIdBits);
        String prefix = RssSparkConfig.SPARK_RSS_CONFIG_PREFIX;
        sparkConf.set(prefix + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(), String.valueOf(sequenceNoBits));
        sparkConf.set(prefix + RssClientConf.BLOCKID_PARTITION_ID_BITS.key(), String.valueOf(partitionIdBits));
        sparkConf.set(prefix + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key(), String.valueOf(taskAttemptIdBits));
    }

    /**
     * Takes the block id layout from explicitly configured bit widths.
     *
     * <p>The three widths may be given either in {@code sparkConf} (with the Spark RSS prefix),
     * which takes precedence, or in {@code rssConf}; the chosen source is copied into the other.
     * If neither configures them, the default {@code RSS_MAX_PARTITIONS} is used instead.
     *
     * @throws IllegalArgumentException if only a subset of the three widths is configured in
     *     either configuration
     */
    private static void configureBlockIdLayoutFromLayoutConfig(
            SparkConf sparkConf, RssConf rssConf, int maxFailures, boolean speculation) {
        String prefix = RssSparkConfig.SPARK_RSS_CONFIG_PREFIX;
        String sparkSeqNoBitsKey = prefix + RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key();
        String sparkPartIdBitsKey = prefix + RssClientConf.BLOCKID_PARTITION_ID_BITS.key();
        String sparkTaskIdBitsKey = prefix + RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key();

        // If one bit field is configured in sparkConf, all three must be.
        List<String> sparkKeys = Arrays.asList(sparkSeqNoBitsKey, sparkPartIdBitsKey, sparkTaskIdBitsKey);
        boolean allSparkKeysSet = sparkKeys.stream().allMatch(sparkConf::contains);
        if (!allSparkKeysSet && sparkKeys.stream().anyMatch(sparkConf::contains)) {
            Set<String> sparkKeySet = sparkKeys.stream().collect(Collectors.toSet());
            String presentKeys = Arrays.stream(sparkConf.getAll())
                    .map(entry -> entry._1())
                    .filter(sparkKeySet::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided ("
                    + String.join(", ", sparkKeys) + "), not just a sub-set: " + presentKeys);
        }

        // If one bit field is configured in rssConf, all three must be.
        List<String> rssKeys = Arrays.asList(
                RssClientConf.BLOCKID_SEQUENCE_NO_BITS.key(),
                RssClientConf.BLOCKID_PARTITION_ID_BITS.key(),
                RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS.key());
        boolean allRssKeysSet = rssConf.contains(RssClientConf.BLOCKID_SEQUENCE_NO_BITS)
                && rssConf.contains(RssClientConf.BLOCKID_PARTITION_ID_BITS)
                && rssConf.contains(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS);
        boolean anyRssKeySet = rssConf.contains(RssClientConf.BLOCKID_SEQUENCE_NO_BITS)
                || rssConf.contains(RssClientConf.BLOCKID_PARTITION_ID_BITS)
                || rssConf.contains(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS);
        if (anyRssKeySet && !allRssKeysSet) {
            Set<String> rssKeySet = rssKeys.stream().collect(Collectors.toSet());
            String presentKeys = rssConf.getKeySet().stream()
                    .filter(rssKeySet::contains)
                    .collect(Collectors.joining(", "));
            throw new IllegalArgumentException("All block id bit config keys must be provided ("
                    + String.join(", ", rssKeys) + "), not just a sub-set: " + presentKeys);
        }

        if (allSparkKeysSet) {
            rssConf.set(RssClientConf.BLOCKID_SEQUENCE_NO_BITS, sparkConf.getInt(sparkSeqNoBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_PARTITION_ID_BITS, sparkConf.getInt(sparkPartIdBitsKey, 0));
            rssConf.set(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS, sparkConf.getInt(sparkTaskIdBitsKey, 0));
        } else if (allRssKeysSet) {
            sparkConf.set(sparkSeqNoBitsKey, rssConf.getValue(RssClientConf.BLOCKID_SEQUENCE_NO_BITS));
            sparkConf.set(sparkPartIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_PARTITION_ID_BITS));
            sparkConf.set(sparkTaskIdBitsKey, rssConf.getValue(RssClientConf.BLOCKID_TASK_ATTEMPT_ID_BITS));
        } else {
            // Nothing configured: derive the layout from the default maximum number of partitions.
            sparkConf.set(
                    RssSparkConfig.RSS_MAX_PARTITIONS.key(),
                    RssSparkConfig.RSS_MAX_PARTITIONS.defaultValueString());
            configureBlockIdLayoutFromMaxPartitions(sparkConf, rssConf, maxFailures, speculation);
        }
    }

    /**
     * Returns the task attempt id to encode into block ids for the given map task attempt.
     * Implemented by subclasses.
     *
     * @param mapIndex the map task index
     * @param attemptNo the attempt number of the map task
     * @return the task attempt id
     */
    public abstract long getTaskAttemptIdForBlockId(int mapIndex, int attemptNo);

    /**
     * Builds a task attempt id as {@code (mapIndex << attemptBits) | attemptNo}, where
     * {@code attemptBits} is the number of bits needed for the largest attempt number allowed by
     * {@code maxFailures} and {@code speculation}.
     *
     * @param mapIndex the map task index
     * @param attemptNo the attempt number of the map task
     * @param maxFailures Spark's task max failures
     * @param speculation whether speculative execution is enabled
     * @param maxTaskAttemptIdBits the number of bits reserved for task attempt ids
     * @return the task attempt id
     * @throws RssException if {@code attemptNo} exceeds the maximum attempt number, or the
     *     resulting id would need more than {@code maxTaskAttemptIdBits} bits
     */
    public static long getTaskAttemptIdForBlockId(
            int mapIndex, int attemptNo, int maxFailures, boolean speculation, int maxTaskAttemptIdBits) {
        int maxAttemptNo = ClientUtils.getMaxAttemptNo(maxFailures, speculation);
        int attemptBits = ClientUtils.getNumberOfSignificantBits(maxAttemptNo);
        if (attemptNo > maxAttemptNo) {
            throw new RssException("Observing attempt number " + attemptNo
                    + " while maxFailures is set to " + maxFailures
                    + (speculation ? " with speculation enabled" : "") + ".");
        }
        int mapIndexBits = ClientUtils.getNumberOfSignificantBits(mapIndex);
        int requiredBits = mapIndexBits + attemptBits;
        if (requiredBits > maxTaskAttemptIdBits) {
            throw new RssException("Observing mapIndex[" + mapIndex
                    + "] that would produce a taskAttemptId with " + requiredBits
                    + " bits which is larger than the allowed " + maxTaskAttemptIdBits
                    + " bits (maxFailures[" + maxFailures + "], speculation[" + speculation
                    + "]). Please consider providing more bits for taskAttemptIds.");
        }
        // Shift in long arithmetic so the result never depends on an overflowed int shift.
        return ((long) mapIndex << attemptBits) | attemptNo;
    }

    /**
     * Asks the configured coordinators, in order, for the dynamic client configuration and applies
     * the first successful response to {@code sparkConf}. All coordinator clients are closed
     * afterwards, even if fetching or applying the configuration throws.
     *
     * @param sparkConf the Spark configuration to read coordinator settings from and update
     */
    public static void fetchAndApplyDynamicConf(SparkConf sparkConf) {
        String clientType = (String) sparkConf.get(RssSparkConfig.RSS_CLIENT_TYPE);
        String coordinators = sparkConf.get(RssSparkConfig.RSS_COORDINATOR_QUORUM.key());
        int timeoutMs = sparkConf.getInt(
                RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.key(),
                ((Integer) RssSparkConfig.RSS_ACCESS_TIMEOUT_MS.defaultValue().get()).intValue());
        List<CoordinatorClient> coordinatorClients = CoordinatorClientFactory.getInstance()
                .createCoordinatorClient(ClientType.valueOf(clientType), coordinators);
        try {
            for (CoordinatorClient client : coordinatorClients) {
                RssFetchClientConfResponse response =
                        client.fetchClientConf(new RssFetchClientConfRequest(timeoutMs));
                if (response.getStatusCode() == StatusCode.SUCCESS) {
                    LOG.info("Success to get conf from {}", client.getDesc());
                    RssSparkShuffleUtils.applyDynamicClientConf(sparkConf, response.getClientConf());
                    break;
                }
                LOG.warn("Fail to get conf from {}", client.getDesc());
            }
        } finally {
            coordinatorClients.forEach(CoordinatorClient::close);
        }
    }

    /**
     * Drops all registered map outputs of the given shuffle from Spark's
     * {@link MapOutputTrackerMaster}. Does nothing unless stage resubmission is supported.
     *
     * <p>Uses Spark's own {@code unregisterAllMapOutput} / {@code unregisterAllMapAndMergeOutput}
     * (looked up reflectively) when available; otherwise unregisters the shuffle, registers it
     * again and increments the tracker epoch.
     *
     * @param shuffleId the shuffle whose map outputs are dropped
     * @throws SparkException if the fallback is needed but no tracker master or
     *     {@code registerShuffle} method is available (e.g. when not called on the driver)
     */
    @Override
    public void unregisterAllMapOutput(int shuffleId) throws SparkException {
        if (!RssSparkShuffleUtils.isStageResubmitSupported()) {
            return;
        }
        MapOutputTrackerMaster tracker = getMapOutputTrackerMaster();
        resolveTrackerMethodsOnce(tracker);
        Method unregisterMethod = this.unregisterAllMapOutputMethod;
        if (unregisterMethod == null) {
            defaultUnregisterAllMapOutput(
                    tracker, this.registerShuffleMethod, shuffleId, getNumMaps(shuffleId), getPartitionNum(shuffleId));
            return;
        }
        try {
            unregisterMethod.invoke(tracker, shuffleId);
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RssException("Invoke unregisterAllMapOutput method failed", e);
        }
    }

    /**
     * Resolves the reflective tracker methods exactly once. Unlike a bare compareAndSet on the
     * flag, concurrent callers wait for the first resolution instead of reading the methods while
     * they are still unassigned.
     */
    private void resolveTrackerMethodsOnce(MapOutputTrackerMaster tracker) {
        if (isInitialized.get()) {
            return;
        }
        synchronized (trackerMethodsLock) {
            if (!isInitialized.get()) {
                unregisterAllMapOutputMethod = getUnregisterAllMapOutputMethod(tracker);
                registerShuffleMethod = getRegisterShuffleMethod(tracker);
                // Set last so that observing true implies the assignments above are visible.
                isInitialized.set(true);
            }
        }
    }

    /**
     * Fallback used when Spark offers no {@code unregisterAllMapOutput}-style method: unregisters
     * the shuffle, registers it again via {@code registerShuffleMethod} and increments the epoch.
     */
    private static void defaultUnregisterAllMapOutput(
            MapOutputTrackerMaster tracker, Method registerShuffleMethod, int shuffleId, int numMaps, int numReduces)
            throws SparkException {
        if (tracker == null || registerShuffleMethod == null) {
            throw new SparkException("default unregisterAllMapOutput should only be called on the driver side");
        }
        tracker.unregisterShuffle(shuffleId);
        try {
            if (isSpark32OrLater()) {
                registerShuffleMethod.invoke(tracker, shuffleId, numMaps, numReduces);
            } else {
                registerShuffleMethod.invoke(tracker, shuffleId, numMaps);
            }
            tracker.incrementEpoch();
        } catch (IllegalAccessException | InvocationTargetException e) {
            throw new RssException("Invoke registerShuffle method failed", e);
        }
    }

    /** Returns true for Spark 3.2 and later, whose {@code registerShuffle} takes three ints. */
    private static boolean isSpark32OrLater() {
        return SparkVersionUtils.MAJOR_VERSION > 3
                || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION >= 2);
    }

    /**
     * Looks up the Spark-version-specific method that drops all map outputs of a shuffle, or
     * returns {@code null} if the tracker is absent or no such method applies or exists.
     */
    private static Method getUnregisterAllMapOutputMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> trackerClass = tracker.getClass();
        try {
            if (SparkVersionUtils.isSpark2() && SparkVersionUtils.MINOR_VERSION <= 3) {
                LOG.warn("Spark version <= 2.3, fallback to default method");
            } else if (SparkVersionUtils.isSpark2()
                    || (SparkVersionUtils.isSpark3() && SparkVersionUtils.MINOR_VERSION <= 1)) {
                return trackerClass.getDeclaredMethod("unregisterAllMapOutput", int.class);
            } else if (SparkVersionUtils.isSpark3()) {
                return trackerClass.getDeclaredMethod("unregisterAllMapAndMergeOutput", int.class);
            } else {
                LOG.warn("Unknown spark version({}), fallback to default method", SparkVersionUtils.SPARK_VERSION);
            }
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get unregisterAllMapOutput method for spark version({})",
                    SparkVersionUtils.SPARK_VERSION, e);
        }
        return null;
    }

    /**
     * Looks up the tracker's {@code registerShuffle} method matching the running Spark version, or
     * returns {@code null} if the tracker is absent or the method does not exist.
     */
    private static Method getRegisterShuffleMethod(MapOutputTrackerMaster tracker) {
        if (tracker == null) {
            return null;
        }
        Class<?> trackerClass = tracker.getClass();
        try {
            if (isSpark32OrLater()) {
                return trackerClass.getDeclaredMethod("registerShuffle", int.class, int.class, int.class);
            }
            return trackerClass.getDeclaredMethod("registerShuffle", int.class, int.class);
        } catch (NoSuchMethodException e) {
            LOG.warn("Got no such method error when get registerShuffle method for spark version({})",
                    SparkVersionUtils.SPARK_VERSION, e);
            return null;
        }
    }

    /**
     * Returns the current {@link SparkEnv}'s map output tracker if it is a
     * {@link MapOutputTrackerMaster}, otherwise (no SparkEnv, or a different tracker type)
     * {@code null}.
     */
    private static MapOutputTrackerMaster getMapOutputTrackerMaster() {
        SparkEnv env = SparkEnv.get();
        MapOutputTracker tracker = env == null ? null : env.mapOutputTracker();
        return tracker instanceof MapOutputTrackerMaster ? (MapOutputTrackerMaster) tracker : null;
    }

    /** Copies every key/value entry of the Hadoop {@link Configuration} into a new map. */
    private static Map<String, String> parseRemoteStorageConf(Configuration configuration) {
        Map<String, String> confItems = Maps.newHashMap();
        for (Map.Entry<String, String> entry : configuration) {
            confItems.put(entry.getKey(), entry.getValue());
        }
        return confItems;
    }

    /**
     * Builds the default {@link RemoteStorageInfo} from {@code RSS_REMOTE_STORAGE_PATH} (empty if
     * unset) and a Hadoop configuration map.
     *
     * <p>The map starts from the local Hadoop configuration (with defaults loaded) when
     * {@code RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED} is set. Then every non-null RSS
     * setting under {@code HADOOP_CONFIG_KEY_PREFIX} is added with that prefix stripped,
     * overriding local values.
     *
     * @param sparkConf the Spark configuration
     * @return the default remote storage info
     */
    public static RemoteStorageInfo getDefaultRemoteStorageInfo(SparkConf sparkConf) {
        RssConf rssConf = RssSparkConfig.toRssConf(sparkConf);
        Map<String, String> remoteStorageConf;
        if (rssConf.getBoolean(RssClientConf.RSS_CLIENT_REMOTE_STORAGE_USE_LOCAL_CONF_ENABLED)) {
            remoteStorageConf = parseRemoteStorageConf(new Configuration(true));
        } else {
            remoteStorageConf = Maps.newHashMap();
        }
        String hadoopPrefix = RssClientConf.HADOOP_CONFIG_KEY_PREFIX;
        for (String key : rssConf.getKeySet()) {
            if (!key.startsWith(hadoopPrefix)) {
                continue;
            }
            String value = rssConf.getString(key, (String) null);
            if (value != null) {
                // Strip the literal prefix; replaceFirst would interpret it as a regex.
                remoteStorageConf.put(key.substring(hadoopPrefix.length()), value);
            }
        }
        return new RemoteStorageInfo(
                sparkConf.get(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), ""), remoteStorageConf);
    }
}
