/*
 * Decompiled with CFR 0.152.
 */
package org.apache.uniffle.shuffle.manager;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import org.apache.spark.shuffle.handle.MutableShuffleHandleInfo;
import org.apache.spark.shuffle.handle.StageAttemptShuffleHandleInfo;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.util.JavaUtils;
import org.apache.uniffle.common.util.RssUtils;
import org.apache.uniffle.proto.RssProtos;
import org.apache.uniffle.proto.ShuffleManagerGrpc;
import org.apache.uniffle.shaded.com.google.protobuf.UnsafeByteOperations;
import org.apache.uniffle.shaded.io.grpc.stub.StreamObserver;
import org.apache.uniffle.shaded.org.roaringbitmap.longlong.Roaring64NavigableMap;
import org.apache.uniffle.shuffle.BlockIdManager;
import org.apache.uniffle.shuffle.manager.RssShuffleManagerInterface;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * gRPC service implementation backed by a {@link RssShuffleManagerInterface}.
 *
 * <p>It keeps per-shuffle bookkeeping of fetch failures (per partition) and write failures (per
 * shuffle server) to decide when a whole stage must be resubmitted, and otherwise delegates
 * shuffle-handle lookups, server reassignment and block-id (shuffle result) requests to the
 * shuffle manager. Every RPC always answers with exactly one reply followed by
 * {@code onCompleted()}.
 */
public class ShuffleManagerGrpcService
extends ShuffleManagerGrpc.ShuffleManagerImplBase {
    private static final Logger LOG = LoggerFactory.getLogger(ShuffleManagerGrpcService.class);
    /** Fetch-failure counters, keyed by shuffle id. */
    private final Map<Integer, RssShuffleStatus> shuffleStatus = JavaUtils.newConcurrentMap();
    /** Write-failure counters per shuffle server, keyed by shuffle id. */
    private final Map<Integer, ShuffleServerFailureRecord> shuffleWriteStatus = JavaUtils.newConcurrentMap();
    private final RssShuffleManagerInterface shuffleManager;

    public ShuffleManagerGrpcService(RssShuffleManagerInterface shuffleManager) {
        this.shuffleManager = shuffleManager;
    }

    /**
     * Records a shuffle write failure reported against one or more shuffle servers.
     *
     * <p>Reports from a foreign app id, or from a stage attempt older than the one already tracked
     * for the shuffle, are answered with {@code INVALID_REQUEST}. Otherwise the failure counters of
     * the reported servers are incremented, and {@code reSubmitWholeStage} is set once the most
     * failing server exceeds {@code shuffleManager.getMaxFetchFailures()}.
     */
    @Override
    public void reportShuffleWriteFailure(RssProtos.ReportShuffleWriteFailureRequest request, StreamObserver<RssProtos.ReportShuffleWriteFailureResponse> responseObserver) {
        boolean reSubmitWholeStage;
        RssProtos.StatusCode code;
        String msg;
        String appId = request.getAppId();
        int shuffleId = request.getShuffleId();
        int stageAttemptNumber = request.getStageAttemptNumber();
        List<RssProtos.ShuffleServerId> shuffleServerIdsList = request.getShuffleServerIdsList();
        if (!appId.equals(this.shuffleManager.getAppId())) {
            msg = String.format("got a wrong shuffle write failure report from appId: %s, expected appId: %s", appId, this.shuffleManager.getAppId());
            LOG.warn(msg);
            code = RssProtos.StatusCode.INVALID_REQUEST;
            reSubmitWholeStage = false;
        } else {
            List<ShuffleServerInfo> shuffleServerInfos = ShuffleServerInfo.fromProto(shuffleServerIdsList);
            // Only the first report of a shuffle creates its record; it is seeded with the reported
            // servers at zero failures. The seed map is built lazily so later reports don't pay for it.
            ShuffleServerFailureRecord shuffleServerFailureRecord = this.shuffleWriteStatus.computeIfAbsent(shuffleId, key -> {
                ConcurrentHashMap<String, AtomicInteger> initialCounts = JavaUtils.newConcurrentMap();
                shuffleServerInfos.forEach(info -> initialCounts.put(info.getId(), new AtomicInteger(0)));
                return new ShuffleServerFailureRecord(initialCounts, stageAttemptNumber);
            });
            // true means the report comes from an older stage attempt than the one tracked.
            boolean isStaleReport = shuffleServerFailureRecord.resetStageAttemptIfNecessary(stageAttemptNumber);
            if (isStaleReport) {
                msg = String.format("got an old stage(%d vs %d) shuffle write failure report, which should be impossible.", shuffleServerFailureRecord.getStageAttempt(), stageAttemptNumber);
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
                reSubmitWholeStage = false;
            } else {
                code = RssProtos.StatusCode.SUCCESS;
                boolean fetchFailureflag = shuffleServerFailureRecord.incPartitionWriteFailure(stageAttemptNumber, shuffleServerInfos, this.shuffleManager);
                if (fetchFailureflag) {
                    reSubmitWholeStage = true;
                    msg = String.format("report shuffle write failure as maximum number(%d) of shuffle write is occurred", this.shuffleManager.getMaxFetchFailures());
                } else {
                    reSubmitWholeStage = false;
                    msg = "don't report shuffle write failure";
                }
            }
        }
        RssProtos.ReportShuffleWriteFailureResponse reply = RssProtos.ReportShuffleWriteFailureResponse.newBuilder().setStatus(code).setReSubmitWholeStage(reSubmitWholeStage).setMsg(msg).build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Records a shuffle fetch failure for one partition.
     *
     * <p>Reports from a foreign app id, or from a stage attempt older than the tracked one, are
     * answered with {@code INVALID_REQUEST}. A newer stage attempt resets all partition counters.
     * {@code reSubmitWholeStage} is set once the partition's failure count reaches
     * {@code shuffleManager.getMaxFetchFailures()}.
     */
    @Override
    public void reportShuffleFetchFailure(RssProtos.ReportShuffleFetchFailureRequest request, StreamObserver<RssProtos.ReportShuffleFetchFailureResponse> responseObserver) {
        boolean reSubmitWholeStage;
        RssProtos.StatusCode code;
        String msg;
        String appId = request.getAppId();
        int stageAttempt = request.getStageAttemptId();
        int partitionId = request.getPartitionId();
        if (!appId.equals(this.shuffleManager.getAppId())) {
            msg = String.format("got a wrong shuffle fetch failure report from appId: %s, expected appId: %s", appId, this.shuffleManager.getAppId());
            LOG.warn(msg);
            code = RssProtos.StatusCode.INVALID_REQUEST;
            reSubmitWholeStage = false;
        } else {
            RssShuffleStatus status = this.shuffleStatus.computeIfAbsent(request.getShuffleId(), key -> {
                int partitionNum = this.shuffleManager.getPartitionNum(key);
                return new RssShuffleStatus(partitionNum, stageAttempt);
            });
            // -1: stale attempt, 0: same attempt, 1: newer attempt (counters were reset).
            int c = status.resetStageAttemptIfNecessary(stageAttempt);
            if (c < 0) {
                msg = String.format("got an old stage(%d vs %d) shuffle fetch failure report, which should be impossible.", status.getStageAttempt(), stageAttempt);
                LOG.warn(msg);
                code = RssProtos.StatusCode.INVALID_REQUEST;
                reSubmitWholeStage = false;
            } else {
                code = RssProtos.StatusCode.SUCCESS;
                status.incPartitionFetchFailure(stageAttempt, partitionId);
                int fetchFailureNum = status.getPartitionFetchFailureNum(stageAttempt, partitionId);
                if (fetchFailureNum >= this.shuffleManager.getMaxFetchFailures()) {
                    reSubmitWholeStage = true;
                    msg = String.format("report shuffle fetch failure as maximum number(%d) of shuffle fetch is occurred", this.shuffleManager.getMaxFetchFailures());
                } else {
                    reSubmitWholeStage = false;
                    msg = "don't report shuffle fetch failure";
                }
            }
        }
        RssProtos.ReportShuffleFetchFailureResponse reply = RssProtos.ReportShuffleFetchFailureResponse.newBuilder().setStatus(code).setReSubmitWholeStage(reSubmitWholeStage).setMsg(msg).build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Returns the stage-retry shuffle handle of a shuffle, or {@code INVALID_REQUEST} when the
     * shuffle is unknown or its handle is not a {@link StageAttemptShuffleHandleInfo}.
     */
    @Override
    public void getPartitionToShufflerServerWithStageRetry(RssProtos.PartitionToShuffleServerRequest request, StreamObserver<RssProtos.ReassignOnStageRetryResponse> responseObserver) {
        RssProtos.ReassignOnStageRetryResponse reply;
        int shuffleId = request.getShuffleId();
        Object handleInfo = this.shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
        // Guard the downcast: a handle of another kind would otherwise throw ClassCastException and
        // leave the call without a reply.
        if (handleInfo instanceof StageAttemptShuffleHandleInfo) {
            StageAttemptShuffleHandleInfo shuffleHandle = (StageAttemptShuffleHandleInfo) handleInfo;
            RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
            reply = RssProtos.ReassignOnStageRetryResponse.newBuilder().setStatus(code).setShuffleHandleInfo(StageAttemptShuffleHandleInfo.toProto(shuffleHandle)).build();
        } else {
            if (handleInfo != null) {
                LOG.warn("Unexpected shuffle handle type {} for shuffleId: {}", handleInfo.getClass().getName(), shuffleId);
            }
            RssProtos.StatusCode code = RssProtos.StatusCode.INVALID_REQUEST;
            reply = RssProtos.ReassignOnStageRetryResponse.newBuilder().setStatus(code).build();
        }
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Returns the block-retry (mutable) shuffle handle of a shuffle, or {@code INVALID_REQUEST}
     * when the shuffle is unknown or its handle is not a {@link MutableShuffleHandleInfo}.
     */
    @Override
    public void getPartitionToShufflerServerWithBlockRetry(RssProtos.PartitionToShuffleServerRequest request, StreamObserver<RssProtos.ReassignOnBlockSendFailureResponse> responseObserver) {
        RssProtos.ReassignOnBlockSendFailureResponse reply;
        int shuffleId = request.getShuffleId();
        Object handleInfo = this.shuffleManager.getShuffleHandleInfoByShuffleId(shuffleId);
        // Guard the downcast for the same reason as the stage-retry variant.
        if (handleInfo instanceof MutableShuffleHandleInfo) {
            MutableShuffleHandleInfo shuffleHandle = (MutableShuffleHandleInfo) handleInfo;
            RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
            reply = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).setHandle(MutableShuffleHandleInfo.toProto(shuffleHandle)).build();
        } else {
            if (handleInfo != null) {
                LOG.warn("Unexpected shuffle handle type {} for shuffleId: {}", handleInfo.getClass().getName(), shuffleId);
            }
            RssProtos.StatusCode code = RssProtos.StatusCode.INVALID_REQUEST;
            reply = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).build();
        }
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Asks the shuffle manager whether servers must be reassigned for a resubmitted stage and
     * returns its decision in {@code needReassign}.
     */
    @Override
    public void reassignOnStageResubmit(RssProtos.ReassignServersRequest request, StreamObserver<RssProtos.ReassignServersResponse> responseObserver) {
        int stageId = request.getStageId();
        int stageAttemptNumber = request.getStageAttemptNumber();
        int shuffleId = request.getShuffleId();
        int numPartitions = request.getNumPartitions();
        boolean needReassign = this.shuffleManager.reassignOnStageResubmit(stageId, stageAttemptNumber, shuffleId, numPartitions);
        RssProtos.StatusCode code = RssProtos.StatusCode.SUCCESS;
        RssProtos.ReassignServersResponse reply = RssProtos.ReassignServersResponse.newBuilder().setStatus(code).setNeedReassign(needReassign).build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Reassigns servers for partitions whose block sends failed and returns the updated handle.
     * Any exception from the shuffle manager is turned into an {@code INTERNAL_ERROR} reply.
     */
    @Override
    public void reassignOnBlockSendFailure(RssProtos.RssReassignOnBlockSendFailureRequest request, StreamObserver<RssProtos.ReassignOnBlockSendFailureResponse> responseObserver) {
        RssProtos.ReassignOnBlockSendFailureResponse reply;
        RssProtos.StatusCode code = RssProtos.StatusCode.INTERNAL_ERROR;
        try {
            LOG.info("Accepted reassign request on block sent failure for shuffleId: {}, stageId: {}, stageAttemptNumber: {} from taskAttemptId: {} on executorId: {} while partition split:{}", request.getShuffleId(), request.getStageId(), request.getStageAttemptNumber(), request.getTaskAttemptId(), request.getExecutorId(), request.getPartitionSplit());
            MutableShuffleHandleInfo handle = this.shuffleManager.reassignOnBlockSendFailure(request.getStageId(), request.getStageAttemptNumber(), request.getShuffleId(), request.getFailurePartitionToServerIdsMap().entrySet().stream().collect(Collectors.toMap(Map.Entry::getKey, x -> ReceivingFailureServer.fromProto((RssProtos.ReceivingFailureServers)x.getValue()))), request.getPartitionSplit());
            code = RssProtos.StatusCode.SUCCESS;
            reply = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).setHandle(MutableShuffleHandleInfo.toProto(handle)).build();
        }
        catch (Exception e) {
            LOG.error("Errors on reassigning when block send failure.", e);
            // Protobuf string setters reject null, and many exceptions (e.g. a bare NPE) carry no
            // message; fall back to toString() so a reply is always sent.
            String errorMsg = e.getMessage() != null ? e.getMessage() : e.toString();
            reply = RssProtos.ReassignOnBlockSendFailureResponse.newBuilder().setStatus(code).setMsg(errorMsg).build();
        }
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Drops all failure bookkeeping (fetch and write) kept for the given shuffle.
     *
     * @param shuffleId the shuffle being unregistered
     */
    public void unregisterShuffle(int shuffleId) {
        this.shuffleStatus.remove(shuffleId);
        // Write-failure records are per shuffle too; without this they would accumulate forever.
        this.shuffleWriteStatus.remove(shuffleId);
    }

    /**
     * Returns the serialized block-id bitmap recorded for one partition of a shuffle.
     * Requests from a foreign app id get {@code ACCESS_DENIED}; lookup or serialization failures get
     * {@code INTERNAL_ERROR}.
     */
    @Override
    public void getShuffleResult(RssProtos.GetShuffleResultRequest request, StreamObserver<RssProtos.GetShuffleResultResponse> responseObserver) {
        RssProtos.GetShuffleResultResponse reply;
        String appId = request.getAppId();
        if (!appId.equals(this.shuffleManager.getAppId())) {
            RssProtos.GetShuffleResultResponse deniedReply = RssProtos.GetShuffleResultResponse.newBuilder().setStatus(RssProtos.StatusCode.ACCESS_DENIED).setRetMsg("Illegal appId: " + appId).build();
            responseObserver.onNext(deniedReply);
            responseObserver.onCompleted();
            return;
        }
        int shuffleId = request.getShuffleId();
        int partitionId = request.getPartitionId();
        BlockIdManager blockIdManager = this.shuffleManager.getBlockIdManager();
        try {
            // The lookup is inside the try so that a failure still yields an INTERNAL_ERROR reply.
            Roaring64NavigableMap blockIdBitmap = blockIdManager.get(shuffleId, partitionId);
            byte[] serializeBitmap = RssUtils.serializeBitMap(blockIdBitmap);
            reply = RssProtos.GetShuffleResultResponse.newBuilder().setStatus(RssProtos.StatusCode.SUCCESS).setSerializedBitmap(UnsafeByteOperations.unsafeWrap(serializeBitmap)).build();
        }
        catch (Exception exception) {
            LOG.error("Errors on getting the blockId bitmap.", exception);
            reply = RssProtos.GetShuffleResultResponse.newBuilder().setStatus(RssProtos.StatusCode.INTERNAL_ERROR).build();
        }
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Returns the union of the block-id bitmaps of several partitions of a shuffle, serialized.
     * Requests from a foreign app id get {@code ACCESS_DENIED}; lookup or serialization failures get
     * {@code INTERNAL_ERROR}.
     */
    @Override
    public void getShuffleResultForMultiPart(RssProtos.GetShuffleResultForMultiPartRequest request, StreamObserver<RssProtos.GetShuffleResultForMultiPartResponse> responseObserver) {
        RssProtos.GetShuffleResultForMultiPartResponse reply;
        String appId = request.getAppId();
        if (!appId.equals(this.shuffleManager.getAppId())) {
            RssProtos.GetShuffleResultForMultiPartResponse deniedReply = RssProtos.GetShuffleResultForMultiPartResponse.newBuilder().setStatus(RssProtos.StatusCode.ACCESS_DENIED).setRetMsg("Illegal appId: " + appId).build();
            responseObserver.onNext(deniedReply);
            responseObserver.onCompleted();
            return;
        }
        BlockIdManager blockIdManager = this.shuffleManager.getBlockIdManager();
        int shuffleId = request.getShuffleId();
        List<Integer> partitionIds = request.getPartitionsList();
        try {
            // A fresh bitmap is used so the per-partition bitmaps returned by the manager are not mutated.
            Roaring64NavigableMap blockIdBitmapCollection = Roaring64NavigableMap.bitmapOf(new long[0]);
            for (int partitionId : partitionIds) {
                Roaring64NavigableMap blockIds = blockIdManager.get(shuffleId, partitionId);
                blockIds.forEach(x -> blockIdBitmapCollection.add(x));
            }
            byte[] serializeBitmap = RssUtils.serializeBitMap(blockIdBitmapCollection);
            reply = RssProtos.GetShuffleResultForMultiPartResponse.newBuilder().setStatus(RssProtos.StatusCode.SUCCESS).setSerializedBitmap(UnsafeByteOperations.unsafeWrap(serializeBitmap)).build();
        }
        catch (Exception exception) {
            LOG.error("Errors on getting the blockId bitmap.", exception);
            reply = RssProtos.GetShuffleResultForMultiPartResponse.newBuilder().setStatus(RssProtos.StatusCode.INTERNAL_ERROR).build();
        }
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Adds the reported block ids of each partition to the shuffle manager's {@link BlockIdManager}.
     * Requests from a foreign app id get {@code ACCESS_DENIED}.
     */
    @Override
    public void reportShuffleResult(RssProtos.ReportShuffleResultRequest request, StreamObserver<RssProtos.ReportShuffleResultResponse> responseObserver) {
        String appId = request.getAppId();
        if (!appId.equals(this.shuffleManager.getAppId())) {
            RssProtos.ReportShuffleResultResponse reply = RssProtos.ReportShuffleResultResponse.newBuilder().setStatus(RssProtos.StatusCode.ACCESS_DENIED).setRetMsg("Illegal appId: " + appId).build();
            responseObserver.onNext(reply);
            responseObserver.onCompleted();
            return;
        }
        BlockIdManager blockIdManager = this.shuffleManager.getBlockIdManager();
        int shuffleId = request.getShuffleId();
        for (RssProtos.PartitionToBlockIds partitionToBlockIds : request.getPartitionToBlockIdsList()) {
            int partitionId = partitionToBlockIds.getPartitionId();
            List<Long> blockIds = partitionToBlockIds.getBlockIdsList();
            blockIdManager.add(shuffleId, partitionId, blockIds);
        }
        RssProtos.ReportShuffleResultResponse reply = RssProtos.ReportShuffleResultResponse.newBuilder().setStatus(RssProtos.StatusCode.SUCCESS).build();
        responseObserver.onNext(reply);
        responseObserver.onCompleted();
    }

    /**
     * Per-shuffle fetch-failure counters, one slot per partition, scoped to a single stage attempt.
     * All state is guarded by a read/write lock.
     */
    private static class RssShuffleStatus {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = this.lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = this.lock.writeLock();
        /** Fetch-failure count per partition index. */
        private final int[] partitions;
        private int stageAttempt;

        private RssShuffleStatus(int partitionNum, int stageAttempt) {
            this.stageAttempt = stageAttempt;
            this.partitions = new int[partitionNum];
        }

        /** Runs {@code fn} while holding the read lock. */
        private <T> T withReadLock(Supplier<T> fn) {
            this.readLock.lock();
            try {
                return fn.get();
            }
            finally {
                this.readLock.unlock();
            }
        }

        /** Runs {@code fn} while holding the write lock. */
        private <T> T withWriteLock(Supplier<T> fn) {
            this.writeLock.lock();
            try {
                return fn.get();
            }
            finally {
                this.writeLock.unlock();
            }
        }

        public int getStageAttempt() {
            return this.withReadLock(() -> this.stageAttempt);
        }

        /**
         * Moves to a newer stage attempt, if needed.
         *
         * @return {@code 1} if {@code stageAttempt} is newer (all counters are reset), {@code -1} if
         *     it is older than the tracked attempt, {@code 0} if it is the same
         */
        public int resetStageAttemptIfNecessary(int stageAttempt) {
            return this.withWriteLock(() -> {
                if (this.stageAttempt < stageAttempt) {
                    Arrays.fill(this.partitions, 0);
                    this.stageAttempt = stageAttempt;
                    return 1;
                }
                if (this.stageAttempt > stageAttempt) {
                    return -1;
                }
                return 0;
            });
        }

        /** Increments the partition's counter; a no-op when {@code stageAttempt} is not current. */
        public void incPartitionFetchFailure(int stageAttempt, int partition) {
            this.withWriteLock(() -> {
                if (this.stageAttempt == stageAttempt) {
                    this.partitions[partition] = this.partitions[partition] + 1;
                }
                return null;
            });
        }

        /** Returns the partition's counter, or {@code 0} when {@code stageAttempt} is not current. */
        public int getPartitionFetchFailureNum(int stageAttempt, int partition) {
            return this.withReadLock(() -> {
                if (this.stageAttempt != stageAttempt) {
                    return 0;
                }
                return this.partitions[partition];
            });
        }
    }

    /**
     * Per-shuffle write-failure counters keyed by shuffle server id, scoped to a single stage
     * attempt. All state is guarded by a read/write lock.
     */
    private static class ShuffleServerFailureRecord {
        private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock();
        private final ReentrantReadWriteLock.ReadLock readLock = this.lock.readLock();
        private final ReentrantReadWriteLock.WriteLock writeLock = this.lock.writeLock();
        private final Map<String, AtomicInteger> shuffleServerFailureRecordCount;
        private int stageAttemptNumber;

        private ShuffleServerFailureRecord(Map<String, AtomicInteger> shuffleServerFailureRecordCount, int stageAttemptNumber) {
            this.shuffleServerFailureRecordCount = shuffleServerFailureRecordCount;
            this.stageAttemptNumber = stageAttemptNumber;
        }

        /** Runs {@code fn} while holding the read lock. */
        private <T> T withReadLock(Supplier<T> fn) {
            this.readLock.lock();
            try {
                return fn.get();
            }
            finally {
                this.readLock.unlock();
            }
        }

        /** Runs {@code fn} while holding the write lock. */
        private <T> T withWriteLock(Supplier<T> fn) {
            this.writeLock.lock();
            try {
                return fn.get();
            }
            finally {
                this.writeLock.unlock();
            }
        }

        public int getStageAttempt() {
            return this.withReadLock(() -> this.stageAttemptNumber);
        }

        /**
         * Moves to a newer stage attempt (clearing all counters) if needed.
         *
         * @return {@code true} only when {@code stageAttemptNumber} is older than the tracked one,
         *     i.e. the report is stale; {@code false} otherwise
         */
        public boolean resetStageAttemptIfNecessary(int stageAttemptNumber) {
            return this.withWriteLock(() -> {
                if (this.stageAttemptNumber < stageAttemptNumber) {
                    this.shuffleServerFailureRecordCount.clear();
                    this.stageAttemptNumber = stageAttemptNumber;
                    return false;
                }
                return this.stageAttemptNumber > stageAttemptNumber;
            });
        }

        /**
         * Increments the failure counter of every reported server and checks the threshold.
         *
         * <p>The server with the most recorded failures is compared against
         * {@code shuffleManager.getMaxFetchFailures()}; if it exceeds it, that server is reported to
         * the shuffle manager as failed and {@code true} is returned.
         *
         * @return {@code true} if a server crossed the threshold; {@code false} otherwise, including
         *     when {@code stageAttemptNumber} is not the current attempt (nothing is recorded then)
         */
        public boolean incPartitionWriteFailure(int stageAttemptNumber, List<ShuffleServerInfo> shuffleServerInfos, RssShuffleManagerInterface shuffleManager) {
            return this.withWriteLock(() -> {
                if (this.stageAttemptNumber != stageAttemptNumber) {
                    return false;
                }
                shuffleServerInfos.forEach(info -> this.shuffleServerFailureRecordCount.computeIfAbsent(info.getId(), k -> new AtomicInteger()).incrementAndGet());
                if (this.shuffleServerFailureRecordCount.isEmpty()) {
                    return false;
                }
                // Select the MOST failing server; an ascending sort + first element would pick the
                // least failing one. Integer.compare avoids the overflow of int subtraction.
                Map.Entry<String, AtomicInteger> mostFailing = Collections.max(this.shuffleServerFailureRecordCount.entrySet(), (o1, o2) -> Integer.compare(o1.getValue().get(), o2.getValue().get()));
                // NOTE(review): strictly greater than, whereas the fetch-failure path uses >=; confirm intended.
                if (mostFailing.getValue().get() > shuffleManager.getMaxFetchFailures()) {
                    shuffleManager.addFailuresShuffleServerInfos(mostFailing.getKey());
                    return true;
                }
                return false;
            });
        }
    }
}

