/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.shuffle;

import java.util.HashMap;
import java.util.Set;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.TaskContext;
import org.apache.spark.shuffle.RssShuffleManager;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.spark.shuffle.RssSparkShuffleUtils;
import org.apache.spark.shuffle.ShuffleBlockResolver;
import org.apache.spark.shuffle.ShuffleHandle;
import org.apache.spark.shuffle.ShuffleManager;
import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
import org.apache.spark.shuffle.ShuffleReader;
import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.uniffle.client.impl.grpc.CoordinatorGrpcRetryableClient;
import org.apache.uniffle.client.request.RssAccessClusterRequest;
import org.apache.uniffle.client.response.RssAccessClusterResponse;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.rpc.StatusCode;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.shaded.com.google.common.collect.Maps;
import org.apache.uniffle.shaded.org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ShuffleManager} that chooses, at construction time, between Uniffle's
 * {@link RssShuffleManager} and Spark's built-in SortShuffleManager, then forwards every
 * call to the chosen instance.
 *
 * <p>On the driver, it asks the Uniffle coordinator for access (see {@link #tryAccessCluster()}).
 * On success it creates an {@link RssShuffleManager}; otherwise, or if that creation fails, it
 * falls back to SortShuffleManager. The driver records its choice in
 * {@code RssSparkConfig.RSS_ENABLED} and {@code spark.shuffle.manager}.
 *
 * <p>On executors, the choice is made by reading {@code RssSparkConfig.RSS_ENABLED} from the
 * executor's {@link SparkConf}. NOTE(review): this presumably relies on the driver-side conf
 * update being propagated to executors; confirm.
 */
public class DelegationRssShuffleManager implements ShuffleManager {
    private static final Logger LOG = LoggerFactory.getLogger(DelegationRssShuffleManager.class);
    /** Fully-qualified name of Spark's sort-based shuffle manager used as the fallback. */
    private static final String SORT_SHUFFLE_MANAGER_CLASS = "org.apache.spark.shuffle.sort.SortShuffleManager";

    private final ShuffleManager delegate;
    /** Coordinator client; only created on the driver, {@code null} on executors. */
    private final CoordinatorGrpcRetryableClient coordinatorClient;
    private final int accessTimeoutMs;
    private final SparkConf sparkConf;
    /** Driver only: the current UGI short user name, or "user" if it could not be resolved. */
    private String user;
    /** Driver only: uuid returned by the coordinator, or a timestamp-based fallback. */
    private String uuid;

    /**
     * Creates the delegating manager and its underlying delegate.
     *
     * @param sparkConf the Spark configuration; on the driver it is mutated to record the choice
     * @param isDriver {@code true} when constructed in the driver, {@code false} in an executor
     * @throws RssException if no delegate shuffle manager could be created
     */
    public DelegationRssShuffleManager(SparkConf sparkConf, boolean isDriver) throws Exception {
        LOG.info("Uniffle {} version: {}", getClass().getName(), Constants.VERSION_AND_REVISION_SHORT);
        this.sparkConf = sparkConf;
        this.accessTimeoutMs = (Integer) sparkConf.get(RssSparkConfig.RSS_ACCESS_TIMEOUT_MS);
        if (isDriver) {
            this.coordinatorClient = RssSparkShuffleUtils.createCoordinatorClients(sparkConf);
            this.delegate = createShuffleManagerInDriver();
        } else {
            this.coordinatorClient = null;
            this.delegate = createShuffleManagerInExecutor();
        }
        if (this.delegate == null) {
            throw new RssException("Fail to create shuffle manager!");
        }
    }

    /**
     * Driver-side selection. Resolves the user and tries to access the cluster. On success it
     * builds an {@link RssShuffleManager}; otherwise it loads SortShuffleManager. The choice is
     * written back into {@link #sparkConf}.
     *
     * @throws RssException if the SortShuffleManager fallback cannot be loaded
     */
    private ShuffleManager createShuffleManagerInDriver() throws RssException {
        this.user = "user";
        try {
            this.user = UserGroupInformation.getCurrentUser().getShortUserName();
        } catch (Exception e) {
            // Best effort: keep the "user" placeholder but record the full stack trace.
            LOG.error("Error on getting user from ugi.", e);
        }
        boolean canAccess = tryAccessCluster();
        // tryAccessCluster only sets uuid on success; otherwise use a timestamp-based id.
        if (StringUtils.isEmpty(this.uuid)) {
            this.uuid = String.valueOf(System.currentTimeMillis());
        }
        if (canAccess) {
            try {
                sparkConf.set("spark.rss.quota.user", this.user);
                sparkConf.set("spark.rss.quota.uuid", this.uuid);
                RssShuffleManager rssShuffleManager = new RssShuffleManager(sparkConf, true);
                sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "true");
                sparkConf.set("spark.shuffle.manager", RssShuffleManager.class.getCanonicalName());
                LOG.info("Use RssShuffleManager");
                return rssShuffleManager;
            } catch (Exception e) {
                LOG.warn("Fail to create RssShuffleManager, fallback to SortShuffleManager {}", e.getMessage(), e);
            }
        }
        try {
            ShuffleManager sortShuffleManager =
                RssSparkShuffleUtils.loadShuffleManager(SORT_SHUFFLE_MANAGER_CLASS, sparkConf, true);
            sparkConf.set(RssSparkConfig.RSS_ENABLED.key(), "false");
            sparkConf.set("spark.shuffle.manager", "sort");
            LOG.info("Use SortShuffleManager");
            return sortShuffleManager;
        } catch (Exception e) {
            // Keep the original message but also preserve the cause for diagnosis.
            throw new RssException(e.getMessage(), e);
        }
    }

    /**
     * Asks the coordinator whether this application may use the Uniffle cluster.
     *
     * <p>Returns {@code false} in any of these cases:
     * <ul>
     *   <li>the access id is blank;</li>
     *   <li>there is no coordinator client (executor side);</li>
     *   <li>access is denied or the coordinator is unreachable;</li>
     *   <li>any other failure occurs (it is logged).</li>
     * </ul>
     * On success, stores the coordinator-provided uuid in {@link #uuid}.
     *
     * @return {@code true} only if the coordinator answered {@link StatusCode#SUCCESS}
     */
    private boolean tryAccessCluster() {
        String accessId = sparkConf.get(RssSparkConfig.RSS_ACCESS_ID.key(), "").trim();
        if (StringUtils.isEmpty(accessId)) {
            LOG.warn("Access id key is empty");
            return false;
        }
        long retryInterval = (Long) sparkConf.get(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_INTERVAL_MS);
        int retryTimes = (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_ACCESS_RETRY_TIMES);
        int assignmentShuffleNodesNum =
            (Integer) sparkConf.get(RssSparkConfig.RSS_CLIENT_ASSIGNMENT_SHUFFLE_SERVER_NUMBER);
        HashMap<String, String> extraProperties = Maps.newHashMap();
        extraProperties.put("access_info_required_shuffle_nodes_num", String.valueOf(assignmentShuffleNodesNum));
        Set<String> assignmentTags = RssSparkShuffleUtils.getAssignmentTags(sparkConf);
        if (coordinatorClient == null) {
            return false;
        }
        try {
            RssAccessClusterResponse response = coordinatorClient.accessCluster(
                new RssAccessClusterRequest(
                    accessId, assignmentTags, accessTimeoutMs, extraProperties, user, retryInterval, retryTimes));
            if (response.getStatusCode() == StatusCode.SUCCESS) {
                LOG.info("Success to access cluster using {}", accessId);
                this.uuid = response.getUuid();
                return true;
            }
            if (response.getStatusCode() == StatusCode.ACCESS_DENIED) {
                throw new RssException("Request to access cluster is denied using " + accessId + " for " + response.getMessage());
            }
            throw new RssException("Fail to reach cluster for " + response.getMessage());
        } catch (Throwable e) {
            // Deliberately broad: any failure here means "fall back", never crash the driver.
            LOG.warn("Fail to access cluster using {} for ", accessId, e);
        }
        return false;
    }

    /**
     * Executor-side selection, driven by {@code RssSparkConfig.RSS_ENABLED}.
     *
     * @throws RssException if the SortShuffleManager cannot be loaded
     */
    private ShuffleManager createShuffleManagerInExecutor() throws RssException {
        boolean useRss = (Boolean) sparkConf.get(RssSparkConfig.RSS_ENABLED);
        if (useRss) {
            ShuffleManager rssShuffleManager = new RssShuffleManager(sparkConf, false);
            LOG.info("Use RssShuffleManager");
            return rssShuffleManager;
        }
        try {
            ShuffleManager sortShuffleManager =
                RssSparkShuffleUtils.loadShuffleManager(SORT_SHUFFLE_MANAGER_CLASS, sparkConf, false);
            LOG.info("Use SortShuffleManager");
            return sortShuffleManager;
        } catch (Exception e) {
            throw new RssException(e.getMessage(), e);
        }
    }

    /** Returns the shuffle manager all calls are forwarded to. */
    public ShuffleManager getDelegate() {
        return this.delegate;
    }

    @Override
    public <K, V, C> ShuffleHandle registerShuffle(int shuffleId, ShuffleDependency<K, V, C> dependency) {
        return delegate.registerShuffle(shuffleId, dependency);
    }

    @Override
    public <K, V> ShuffleWriter<K, V> getWriter(
            ShuffleHandle handle, long mapId, TaskContext context, ShuffleWriteMetricsReporter metrics) {
        return delegate.getWriter(handle, mapId, context, metrics);
    }

    @Override
    public <K, C> ShuffleReader<K, C> getReader(
            ShuffleHandle handle,
            int startPartition,
            int endPartition,
            TaskContext context,
            ShuffleReadMetricsReporter metrics) {
        return delegate.getReader(handle, startPartition, endPartition, context, metrics);
    }

    /**
     * Map-range reader. Dispatched reflectively because this signature is not part of every
     * Spark 3.x {@code ShuffleManager} version this class is built against.
     *
     * @throws RssException wrapping any reflection or invocation failure
     */
    public <K, C> ShuffleReader<K, C> getReader(
            ShuffleHandle handle,
            int startMapIndex,
            int endMapIndex,
            int startPartition,
            int endPartition,
            TaskContext context,
            ShuffleReadMetricsReporter metrics) {
        return invokeDelegateReader(
            "getReader", handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics);
    }

    /**
     * Map-range reader under its older name. It is dispatched reflectively, like the map-range
     * {@code getReader} overload.
     *
     * @throws RssException wrapping any reflection or invocation failure
     */
    public <K, C> ShuffleReader<K, C> getReaderForRange(
            ShuffleHandle handle,
            int startMapIndex,
            int endMapIndex,
            int startPartition,
            int endPartition,
            TaskContext context,
            ShuffleReadMetricsReporter metrics) {
        return invokeDelegateReader(
            "getReaderForRange", handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics);
    }

    /**
     * Looks up {@code methodName} with the map-range reader signature on the delegate's class
     * and invokes it on the delegate. The delegate must be the receiver, not the handle.
     */
    @SuppressWarnings("unchecked")
    private <K, C> ShuffleReader<K, C> invokeDelegateReader(
            String methodName,
            ShuffleHandle handle,
            int startMapIndex,
            int endMapIndex,
            int startPartition,
            int endPartition,
            TaskContext context,
            ShuffleReadMetricsReporter metrics) {
        try {
            return (ShuffleReader<K, C>) delegate.getClass()
                .getDeclaredMethod(
                    methodName,
                    ShuffleHandle.class,
                    int.class,
                    int.class,
                    int.class,
                    int.class,
                    TaskContext.class,
                    ShuffleReadMetricsReporter.class)
                .invoke(delegate, handle, startMapIndex, endMapIndex, startPartition, endPartition, context, metrics);
        } catch (Exception e) {
            throw new RssException(e);
        }
    }

    @Override
    public boolean unregisterShuffle(int shuffleId) {
        return delegate.unregisterShuffle(shuffleId);
    }

    /** Stops the delegate and always releases the coordinator client, even if stopping fails. */
    @Override
    public void stop() {
        try {
            delegate.stop();
        } finally {
            if (coordinatorClient != null) {
                coordinatorClient.close();
            }
        }
    }

    @Override
    public ShuffleBlockResolver shuffleBlockResolver() {
        return delegate.shuffleBlockResolver();
    }
}

