package org.apache.uniffle.shuffle.manager;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc;
import org.apache.uniffle.shaded.com.google.protobuf.UnsafeByteOperations;
import org.apache.uniffle.shaded.io.grpc.stub.StreamObserver;
import org.apache.uniffle.shaded.org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.shuffle.BlockIdManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/apache/uniffle/shuffle/manager/ShuffleManagerGrpcService.class */
/**
 * gRPC service backed by a {@link RssShuffleManagerInterface}.
 *
 * <p>Responsibilities:
 *
 * <ul>
 *   <li>count shuffle fetch failures (per partition) and shuffle write failures (per shuffle
 *       server) for the current stage attempt, and tell the caller when the whole stage should be
 *       resubmitted;
 *   <li>serve shuffle handle info for stage retries and block retries, and delegate server
 *       reassignment to the shuffle manager;
 *   <li>record and serve block ids through the manager's {@link BlockIdManager}.
 * </ul>
 *
 * <p>Every handler answers exactly once with {@code onNext} followed by {@code onCompleted}.
 */
public class ShuffleManagerGrpcService extends ShuffleManagerGrpc.ShuffleManagerImplBase {
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class);

    /** Fetch failure counters, keyed by shuffle id. */
    private final Map<Integer, RssShuffleStatus> shuffleStatus = JavaUtils.newConcurrentMap();

    /** Write failure counters (per shuffle server), keyed by shuffle id. */
    private final Map<Integer, ShuffleServerFailureRecord> shuffleWriteStatus = JavaUtils.newConcurrentMap();

    private final RssShuffleManagerInterface shuffleManager;

    /**
     * Fetch failure counters of a single shuffle: one counter per partition, valid for a single
     * stage attempt. Observing a newer stage attempt resets every counter to zero.
     */
    private static class RssShuffleStatus {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock();
        private final int[] partitions;
        private int stageAttempt;

        private RssShuffleStatus(int partitionNum, int stageAttempt) {
            this.stageAttempt = stageAttempt;
            this.partitions = new int[partitionNum];
        }

        private <T> T withReadLock(Supplier<T> action) {
            readLock.lock();
            try {
                return action.get();
            } finally {
                readLock.unlock();
            }
        }

        private <T> T withWriteLock(Supplier<T> action) {
            writeLock.lock();
            try {
                return action.get();
            } finally {
                writeLock.unlock();
            }
        }

        /** Returns the number of tracked partitions; the array is fixed at construction, so no lock is needed. */
        public int getPartitionNum() {
            return partitions.length;
        }

        public int getStageAttempt() {
            return withReadLock(() -> stageAttempt);
        }

        /**
         * Moves this status to {@code requestedAttempt} when that attempt is newer than the current one.
         *
         * @return -1 if {@code requestedAttempt} is older than the current attempt, 0 if it is the
         *     current attempt, 1 if the counters were reset for the newer attempt
         */
        public int resetStageAttemptIfNecessary(int requestedAttempt) {
            return withWriteLock(() -> {
                if (stageAttempt > requestedAttempt) {
                    return -1;
                }
                if (stageAttempt == requestedAttempt) {
                    return 0;
                }
                Arrays.fill(partitions, 0);
                stageAttempt = requestedAttempt;
                return 1;
            });
        }

        /**
         * Increments the counter of {@code partitionId} and returns the new value, both under the
         * same lock so a concurrent reset cannot slip in between.
         *
         * @param requestedAttempt stage attempt the report belongs to
         * @param partitionId partition index; callers must ensure {@code 0 <= partitionId < getPartitionNum()}
         * @return the updated counter, or 0 when {@code requestedAttempt} is not the current attempt
         */
        public int incPartitionFetchFailure(int requestedAttempt, int partitionId) {
            return withWriteLock(() -> {
                if (stageAttempt != requestedAttempt) {
                    return 0;
                }
                return ++partitions[partitionId];
            });
        }
    }

    /**
     * Write failure counters of a single shuffle, keyed by shuffle server id and valid for a single
     * stage attempt. Observing a newer stage attempt clears all counters.
     */
    public static class ShuffleServerFailureRecord {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = lock.writeLock();
        private final Map<String, AtomicInteger> shuffleServerFailureRecordCount;
        private int stageAttemptNumber;

        private ShuffleServerFailureRecord(Map<String, AtomicInteger> initialCounts, int stageAttemptNumber) {
            this.shuffleServerFailureRecordCount = initialCounts;
            this.stageAttemptNumber = stageAttemptNumber;
        }

        private <T> T withReadLock(Supplier<T> action) {
            readLock.lock();
            try {
                return action.get();
            } finally {
                readLock.unlock();
            }
        }

        private <T> T withWriteLock(Supplier<T> action) {
            writeLock.lock();
            try {
                return action.get();
            } finally {
                writeLock.unlock();
            }
        }

        public int getStageAttempt() {
            return withReadLock(() -> stageAttemptNumber);
        }

        /**
         * Moves this record to {@code requestedAttempt} when it is newer than the current attempt,
         * clearing all counters.
         *
         * @return true if {@code requestedAttempt} is older than the current attempt (a stale
         *     report), false otherwise
         */
        public boolean resetStageAttemptIfNecessary(int requestedAttempt) {
            return withWriteLock(() -> {
                if (stageAttemptNumber > requestedAttempt) {
                    return true;
                }
                if (stageAttemptNumber < requestedAttempt) {
                    shuffleServerFailureRecordCount.clear();
                    stageAttemptNumber = requestedAttempt;
                }
                return false;
            });
        }

        /**
         * Records one write failure for every server in {@code shuffleServerInfos}. If the server
         * with the most recorded failures exceeds the manager's limit, that server is reported via
         * {@code addFailuresShuffleServerInfos}.
         *
         * @return true if a server was reported as failed (the stage should be resubmitted)
         */
        public boolean incPartitionWriteFailure(
                int requestedAttempt,
                List<ShuffleServerInfo> shuffleServerInfos,
                RssShuffleManagerInterface rssShuffleManager) {
            return withWriteLock(() -> {
                if (stageAttemptNumber != requestedAttempt) {
                    return false;
                }
                for (ShuffleServerInfo info : shuffleServerInfos) {
                    shuffleServerFailureRecordCount
                            .computeIfAbsent(info.getId(), id -> new AtomicInteger())
                            .incrementAndGet();
                }
                // Pick the server with the MOST failures. Sorting ascending and taking the first
                // element would select the least-failing server instead.
                String worstServerId = null;
                int worstCount = 0;
                for (Map.Entry<String, AtomicInteger> entry : shuffleServerFailureRecordCount.entrySet()) {
                    int count = entry.getValue().get();
                    if (worstServerId == null || count > worstCount) {
                        worstServerId = entry.getKey();
                        worstCount = count;
                    }
                }
                // NOTE(review): strictly greater than, whereas the fetch path uses >=; kept as-is.
                if (worstServerId != null && worstCount > rssShuffleManager.getMaxFetchFailures()) {
                    rssShuffleManager.addFailuresShuffleServerInfos(worstServerId);
                    return true;
                }
                return false;
            });
        }
    }

    public ShuffleManagerGrpcService(RssShuffleManagerInterface rssShuffleManagerInterface) {
        this.shuffleManager = rssShuffleManagerInterface;
    }

    /** Returns a non-null description of {@code e}; protobuf setters reject null strings. */
    private static String describe(Throwable e) {
        String message = e.getMessage();
        return message != null ? message : e.getClass().getName();
    }

    /**
     * Records a shuffle write failure against the reported shuffle servers and replies whether
     * the whole stage should be resubmitted.
     */
    @Override
    public void reportShuffleWriteFailure(
            RssProtos.ReportShuffleWriteFailureRequest request,
            StreamObserver<RssProtos.ReportShuffleWriteFailureResponse> responseObserver) {
        String appId = request.getAppId();
        int shuffleId = request.getShuffleId();
        int stageAttemptNumber = request.getStageAttemptNumber();
        RssProtos.StatusCode code;
        boolean reSubmitWholeStage = false;
        String msg;
        if (!appId.equals(shuffleManager.getAppId())) {
            msg = String.format(
                    "got a wrong shuffle write failure report from appId: %s, expected appId: %s",
                    appId, shuffleManager.getAppId());
            LOG.warn(msg);
            code = RssProtos.StatusCode.INVALID_REQUEST;
        } else {
            List<ShuffleServerInfo> shuffleServerInfos =
                    ShuffleServerInfo.fromProto(request.getShuffleServerIdsList());
            ShuffleServerFailureRecord record = shuffleWriteStatus.computeIfAbsent(shuffleId, id -> {
                // Seed every reported server with zero; only built when the record is new.
                Map<String, AtomicInteger> initialCounts = JavaUtils.newConcurrentMap();
                shuffleServerInfos.forEach(info -> initialCounts.put(info.getId(), new AtomicInteger(0)));
                return new ShuffleServerFailureRecord(initialCounts, stageAttemptNumber);
            });
            if (record.resetStageAttemptIfNecessary(stageAttemptNumber)) {
                msg = String.format(
                        "got an old stage(%d vs %d) shuffle write failure report, which should be impossible.",
                        record.getStageAttempt(), stageAttemptNumber);
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
            } else {
                code = RssProtos.StatusCode.SUCCESS;
                if (record.incPartitionWriteFailure(stageAttemptNumber, shuffleServerInfos, shuffleManager)) {
                    reSubmitWholeStage = true;
                    msg = String.format(
                            "report shuffle write failure as maximum number(%d) of shuffle write is occurred",
                            shuffleManager.getMaxFetchFailures());
                } else {
                    msg = "don't report shuffle write failure";
                }
            }
        }
        responseObserver.onNext(RssProtos.ReportShuffleWriteFailureResponse.newBuilder()
                .setStatus(code)
                .setReSubmitWholeStage(reSubmitWholeStage)
                .setMsg(msg)
                .build());
        responseObserver.onCompleted();
    }

    /**
     * Records a shuffle fetch failure for one partition and replies whether the whole stage should
     * be resubmitted (once the partition's failure count reaches the manager's limit).
     */
    @Override
    public void reportShuffleFetchFailure(
            RssProtos.ReportShuffleFetchFailureRequest request,
            StreamObserver<RssProtos.ReportShuffleFetchFailureResponse> responseObserver) {
        String appId = request.getAppId();
        int stageAttemptId = request.getStageAttemptId();
        int partitionId = request.getPartitionId();
        RssProtos.StatusCode code;
        boolean reSubmitWholeStage = false;
        String msg;
        if (!appId.equals(shuffleManager.getAppId())) {
            msg = String.format(
                    "got a wrong shuffle fetch failure report from appId: %s, expected appId: %s",
                    appId, shuffleManager.getAppId());
            LOG.warn(msg);
            code = RssProtos.StatusCode.INVALID_REQUEST;
        } else {
            RssShuffleStatus status = shuffleStatus.computeIfAbsent(
                    request.getShuffleId(),
                    id -> new RssShuffleStatus(shuffleManager.getPartitionNum(id), stageAttemptId));
            if (partitionId < 0 || partitionId >= status.getPartitionNum()) {
                // Reject before touching counters; indexing would otherwise throw.
                msg = String.format(
                        "got an invalid partition id(%d) in shuffle fetch failure report, expected range [0, %d)",
                        partitionId, status.getPartitionNum());
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
            } else if (status.resetStageAttemptIfNecessary(stageAttemptId) < 0) {
                msg = String.format(
                        "got an old stage(%d vs %d) shuffle fetch failure report, which should be impossible.",
                        status.getStageAttempt(), stageAttemptId);
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
            } else {
                code = RssProtos.StatusCode.SUCCESS;
                int failureNum = status.incPartitionFetchFailure(stageAttemptId, partitionId);
                if (failureNum >= shuffleManager.getMaxFetchFailures()) {
                    reSubmitWholeStage = true;
                    msg = String.format(
                            "report shuffle fetch failure as maximum number(%d) of shuffle fetch is occurred",
                            shuffleManager.getMaxFetchFailures());
                } else {
                    msg = "don't report shuffle fetch failure";
                }
            }
        }
        responseObserver.onNext(RssProtos.ReportShuffleFetchFailureResponse.newBuilder()
                .setStatus(code)
                .setReSubmitWholeStage(reSubmitWholeStage)
                .setMsg(msg)
                .build());
        responseObserver.onCompleted();
    }

    /**
     * Replies with the shuffle's {@link StageAttemptShuffleHandleInfo}, or INVALID_REQUEST when
     * the shuffle is unknown or holds a different handle type.
     */
    @Override
    public void getPartitionToShufflerServerWithStageRetry(
            RssProtos.PartitionToShuffleServerRequest request,
            StreamObserver<RssProtos.ReassignOnStageRetryResponse> responseObserver) {
        Object handleInfo = shuffleManager.getShuffleHandleInfoByShuffleId(request.getShuffleId());
        RssProtos.ReassignOnStageRetryResponse response;
        if (handleInfo instanceof StageAttemptShuffleHandleInfo) {
            response = RssProtos.ReassignOnStageRetryResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.SUCCESS)
                    .setShuffleHandleInfo(StageAttemptShuffleHandleInfo.toProto((StageAttemptShuffleHandleInfo) handleInfo))
                    .build();
        } else {
            if (handleInfo != null) {
                LOG.warn("Shuffle {} holds handle info of type {}, expected StageAttemptShuffleHandleInfo",
                        request.getShuffleId(), handleInfo.getClass().getName());
            }
            response = RssProtos.ReassignOnStageRetryResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.INVALID_REQUEST)
                    .build();
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /**
     * Replies with the shuffle's {@link MutableShuffleHandleInfo}, or INVALID_REQUEST when the
     * shuffle is unknown or holds a different handle type.
     */
    @Override
    public void getPartitionToShufflerServerWithBlockRetry(
            RssProtos.PartitionToShuffleServerRequest request,
            StreamObserver<RssProtos.ReassignOnBlockSendFailureResponse> responseObserver) {
        Object handleInfo = shuffleManager.getShuffleHandleInfoByShuffleId(request.getShuffleId());
        RssProtos.ReassignOnBlockSendFailureResponse response;
        if (handleInfo instanceof MutableShuffleHandleInfo) {
            response = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.SUCCESS)
                    .setHandle(MutableShuffleHandleInfo.toProto((MutableShuffleHandleInfo) handleInfo))
                    .build();
        } else {
            if (handleInfo != null) {
                LOG.warn("Shuffle {} holds handle info of type {}, expected MutableShuffleHandleInfo",
                        request.getShuffleId(), handleInfo.getClass().getName());
            }
            response = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.INVALID_REQUEST)
                    .build();
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** Delegates to {@code shuffleManager.reassignOnStageResubmit} and replies with its result. */
    @Override
    public void reassignOnStageResubmit(
            RssProtos.ReassignServersRequest request,
            StreamObserver<RssProtos.ReassignServersResponse> responseObserver) {
        boolean needReassign = shuffleManager.reassignOnStageResubmit(
                request.getStageId(),
                request.getStageAttemptNumber(),
                request.getShuffleId(),
                request.getNumPartitions());
        responseObserver.onNext(RssProtos.ReassignServersResponse.newBuilder()
                .setStatus(RssProtos.StatusCode.SUCCESS)
                .setNeedReassign(needReassign)
                .build());
        responseObserver.onCompleted();
    }

    /**
     * Delegates block-send-failure reassignment to the shuffle manager and replies with the updated
     * handle. Any failure yields INTERNAL_ERROR with a non-null message.
     */
    @Override
    public void reassignOnBlockSendFailure(
            RssProtos.RssReassignOnBlockSendFailureRequest request,
            StreamObserver<RssProtos.ReassignOnBlockSendFailureResponse> responseObserver) {
        RssProtos.ReassignOnBlockSendFailureResponse response;
        try {
            LOG.info(
                    "Accepted reassign request on block sent failure for shuffleId: {}, stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {} while partition split:{}",
                    request.getShuffleId(),
                    request.getStageId(),
                    request.getStageAttemptNumber(),
                    request.getTaskAttemptId(),
                    request.getExecutorId(),
                    request.getPartitionSplit());
            MutableShuffleHandleInfo handleInfo = shuffleManager.reassignOnBlockSendFailure(
                    request.getStageId(),
                    request.getStageAttemptNumber(),
                    request.getShuffleId(),
                    request.getFailurePartitionToServerIdsMap().entrySet().stream()
                            .collect(Collectors.toMap(
                                    Map.Entry::getKey,
                                    entry -> ReceivingFailureServer.fromProto(entry.getValue()))),
                    request.getPartitionSplit());
            response = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.SUCCESS)
                    .setHandle(MutableShuffleHandleInfo.toProto(handleInfo))
                    .build();
        } catch (Exception e) {
            LOG.error("Errors on reassigning when block send failure.", e);
            // Always INTERNAL_ERROR here: a failure in toProto/build must not be reported as SUCCESS.
            response = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
                    .setMsg(describe(e))
                    .build();
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** Drops all fetch and write failure state kept for {@code i}. */
    public void unregisterShuffle(int i) {
        shuffleStatus.remove(i);
        shuffleWriteStatus.remove(i);
    }

    /** Replies with the serialized block-id bitmap of one partition. */
    @Override
    public void getShuffleResult(
            RssProtos.GetShuffleResultRequest request,
            StreamObserver<RssProtos.GetShuffleResultResponse> responseObserver) {
        String appId = request.getAppId();
        if (!appId.equals(shuffleManager.getAppId())) {
            responseObserver.onNext(RssProtos.GetShuffleResultResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
                    .setRetMsg("Illegal appId: " + appId)
                    .build());
            responseObserver.onCompleted();
            return;
        }
        RssProtos.GetShuffleResultResponse response;
        try {
            response = RssProtos.GetShuffleResultResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.SUCCESS)
                    .setSerializedBitmap(UnsafeByteOperations.unsafeWrap(RssUtils.serializeBitMap(
                            shuffleManager.getBlockIdManager().get(request.getShuffleId(), request.getPartitionId()))))
                    .build();
        } catch (Exception e) {
            LOG.error("Errors on getting the blockId bitmap.", e);
            response = RssProtos.GetShuffleResultResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
                    .build();
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /** Replies with the union of the block-id bitmaps of the requested partitions. */
    @Override
    public void getShuffleResultForMultiPart(
            RssProtos.GetShuffleResultForMultiPartRequest request,
            StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
        String appId = request.getAppId();
        if (!appId.equals(shuffleManager.getAppId())) {
            responseObserver.onNext(RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
                    .setRetMsg("Illegal appId: " + appId)
                    .build());
            responseObserver.onCompleted();
            return;
        }
        RssProtos.GetShuffleResultForMultiPartResponse response;
        try {
            // Aggregation is inside the try so lookup failures also produce INTERNAL_ERROR.
            BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
            int shuffleId = request.getShuffleId();
            Roaring64NavigableMap merged = Roaring64NavigableMap.bitmapOf();
            for (int partitionId : request.getPartitionsList()) {
                blockIdManager.get(shuffleId, partitionId).forEach(j -> {
                    merged.add(j);
                });
            }
            response = RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.SUCCESS)
                    .setSerializedBitmap(UnsafeByteOperations.unsafeWrap(RssUtils.serializeBitMap(merged)))
                    .build();
        } catch (Exception e) {
            LOG.error("Errors on getting the blockId bitmap.", e);
            response = RssProtos.GetShuffleResultForMultiPartResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
                    .build();
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }

    /**
     * Adds the reported block ids of each partition to the {@link BlockIdManager}. If adding
     * fails midway, partitions processed before the failure stay recorded.
     */
    @Override
    public void reportShuffleResult(
            RssProtos.ReportShuffleResultRequest request,
            StreamObserver<RssProtos.ReportShuffleResultResponse> responseObserver) {
        String appId = request.getAppId();
        if (!appId.equals(shuffleManager.getAppId())) {
            responseObserver.onNext(RssProtos.ReportShuffleResultResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.ACCESS_DENIED)
                    .setRetMsg("Illegal appId: " + appId)
                    .build());
            responseObserver.onCompleted();
            return;
        }
        RssProtos.ReportShuffleResultResponse response;
        try {
            BlockIdManager blockIdManager = shuffleManager.getBlockIdManager();
            int shuffleId = request.getShuffleId();
            for (RssProtos.PartitionToBlockIds partitionToBlockIds : request.getPartitionToBlockIdsList()) {
                blockIdManager.add(shuffleId, partitionToBlockIds.getPartitionId(), partitionToBlockIds.getBlockIdsList());
            }
            response = RssProtos.ReportShuffleResultResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.SUCCESS)
                    .build();
        } catch (Exception e) {
            LOG.error("Errors on reporting the blockIds.", e);
            response = RssProtos.ReportShuffleResultResponse.newBuilder()
                    .setStatus(RssProtos.StatusCode.INTERNAL_ERROR)
                    .setRetMsg(describe(e))
                    .build();
        }
        responseObserver.onNext(response);
        responseObserver.onCompleted();
    }
}
