package org.apache.spark.shuffle.writer;

import com.clearspring.analytics.util.Lists;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CancellationException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.serializer.SerializationStream;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.RssSparkConfig;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.com.google.common.annotations.VisibleForTesting;
import org.apache.uniffle.com.google.common.collect.Maps;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
import org.apache.uniffle.common.config.RssConf;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.ChecksumUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.reflect.ClassTag$;
import scala.reflect.ManifestFactory$;

/* loaded from: input_file:org/apache/spark/shuffle/writer/WriteBufferManager.class */
/**
 * Buffers serialized shuffle records per partition, using memory acquired from Spark's
 * {@link TaskMemoryManager}, and converts buffers into {@link ShuffleBlockInfo}s addressed to
 * the servers listed in {@code partitionToServers}.
 *
 * <p>A single partition buffer is converted into a block once its memory use exceeds
 * {@code bufferSize}. All buffers are converted once the memory held by not-yet-sent buffers
 * exceeds {@code spillSize}. Memory is acquired from the executor in chunks of
 * {@code preAllocatedBufferSize} bytes and tracked with three counters:
 * <ul>
 *   <li>{@code allocatedBytes}: memory acquired from the task memory manager;</li>
 *   <li>{@code usedBytes}: memory reserved by partition buffers, including buffers already
 *       converted into blocks that have not been freed yet;</li>
 *   <li>{@code inSendListBytes}: memory of buffers already converted into blocks and not yet
 *       released by {@link #freeAllocatedMemory(long)}.</li>
 * </ul>
 */
public class WriteBufferManager extends MemoryConsumer {
    private static final Logger LOG = LoggerFactory.getLogger(WriteBufferManager.class);

    /** Memory threshold above which a single partition buffer is turned into a block. */
    private final int bufferSize;
    /** Unsent buffered memory above which every buffer is flushed. */
    private final long spillSize;
    private final AtomicLong allocatedBytes = new AtomicLong(0L);
    private final AtomicLong usedBytes = new AtomicLong(0L);
    private final AtomicLong inSendListBytes = new AtomicLong(0L);
    /** Next block sequence number for each partition. */
    private final Map<Integer, Integer> partitionToSeqNo = Maps.newHashMap();
    /** Number of bytes requested from the executor on each acquireMemory call. */
    private final long askExecutorMemory;
    private final int shuffleId;
    private String taskId;
    private final long taskAttemptId;
    /** Serializer instance; only created in row-based mode. */
    private SerializerInstance instance;
    private ShuffleWriteMetrics shuffleWriteMetrics;
    /** Live (not yet converted into blocks) buffer for each partition. */
    private final Map<Integer, WriterBuffer> buffers = Maps.newHashMap();
    private final Map<Integer, List<ShuffleServerInfo>> partitionToServers;
    private final int serializerBufferSize;
    private final int bufferSegmentSize;
    private long copyTime = 0L;
    private long serializeTime = 0L;
    private long compressTime = 0L;
    private long writeTime = 0L;
    private long estimateTime = 0L;
    private long requireMemoryTime = 0L;
    /** Stream over {@link #arrayOutputStream}; only created in row-based mode. */
    private SerializationStream serializeStream;
    private final WrappedByteArrayOutputStream arrayOutputStream;
    private long uncompressedDataLen = 0L;
    private final long requireMemoryInterval;
    private final int requireMemoryRetryMax;
    /** Compression codec, or {@code null} when shuffle compression is disabled. */
    private final Codec codec;
    private Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc;
    private long sendSizeLimit;
    private final boolean memorySpillEnabled;
    private final int memorySpillTimeoutSec;
    private final boolean isRowBased;

    /**
     * Creates a manager with no task id and no spill function. With no spill function,
     * {@link #spill(long, MemoryConsumer)} releases nothing.
     */
    public WriteBufferManager(
            int shuffleId,
            long taskAttemptId,
            BufferManagerOptions bufferManagerOptions,
            Serializer serializer,
            Map<Integer, List<ShuffleServerInfo>> partitionToServers,
            TaskMemoryManager taskMemoryManager,
            ShuffleWriteMetrics shuffleWriteMetrics,
            RssConf rssConf) {
        this(shuffleId, null, taskAttemptId, bufferManagerOptions, serializer, partitionToServers,
                taskMemoryManager, shuffleWriteMetrics, rssConf, null);
    }

    /**
     * Creates a manager registered as an on-heap {@link MemoryConsumer} of the given task.
     *
     * @param shuffleId id of the shuffle being written
     * @param taskId task id used in log messages and block events (may be {@code null})
     * @param taskAttemptId attempt id, encoded into every block id
     * @param bufferManagerOptions buffer/spill thresholds and memory-request settings
     * @param serializer serializer used by {@link #addRecord} (row-based mode only)
     * @param partitionToServers target servers for each partition
     * @param taskMemoryManager memory manager that memory is acquired from
     * @param shuffleWriteMetrics metrics updated with bytes and records written
     * @param rssConf client configuration
     * @param spillFunc sends the given blocks and returns futures of released memory sizes;
     *     may be {@code null}, in which case spilling releases nothing
     */
    public WriteBufferManager(
            int shuffleId,
            String taskId,
            long taskAttemptId,
            BufferManagerOptions bufferManagerOptions,
            Serializer serializer,
            Map<Integer, List<ShuffleServerInfo>> partitionToServers,
            TaskMemoryManager taskMemoryManager,
            ShuffleWriteMetrics shuffleWriteMetrics,
            RssConf rssConf,
            Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
        super(taskMemoryManager, taskMemoryManager.pageSizeBytes(), MemoryMode.ON_HEAP);
        this.bufferSize = bufferManagerOptions.getBufferSize();
        this.spillSize = bufferManagerOptions.getBufferSpillThreshold();
        this.shuffleId = shuffleId;
        this.taskId = taskId;
        this.taskAttemptId = taskAttemptId;
        this.partitionToServers = partitionToServers;
        this.shuffleWriteMetrics = shuffleWriteMetrics;
        this.serializerBufferSize = bufferManagerOptions.getSerializerBufferSize();
        this.bufferSegmentSize = bufferManagerOptions.getBufferSegmentSize();
        this.askExecutorMemory = bufferManagerOptions.getPreAllocatedBufferSize();
        this.requireMemoryInterval = bufferManagerOptions.getRequireMemoryInterval();
        this.requireMemoryRetryMax = bufferManagerOptions.getRequireMemoryRetryMax();
        this.arrayOutputStream = new WrappedByteArrayOutputStream(this.serializerBufferSize);
        this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
        if (this.isRowBased) {
            this.instance = serializer.newInstance();
            this.serializeStream = this.instance.serializeStream(this.arrayOutputStream);
        }
        // The Spark compress key is looked up in RssConf with the RSS config prefix stripped
        // (presumably the key starts with that prefix; verify if the constants change).
        boolean compress = rssConf.getBoolean(
                RssSparkConfig.SPARK_SHUFFLE_COMPRESS_KEY.substring(
                        RssSparkConfig.SPARK_RSS_CONFIG_PREFIX.length()),
                true);
        this.codec = compress ? Codec.newInstance(rssConf) : null;
        this.spillFunc = spillFunc;
        this.sendSizeLimit = ((Long) rssConf.get(RssSparkConfig.RSS_CLIENT_SEND_SIZE_LIMITATION)).longValue();
        this.memorySpillTimeoutSec = ((Integer) rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_TIMEOUT)).intValue();
        this.memorySpillEnabled = ((Boolean) rssConf.get(RssSparkConfig.RSS_MEMORY_SPILL_ENABLED)).booleanValue();
    }

    /**
     * Buffers the whole {@code record} array as already-serialized data for a partition.
     *
     * @return blocks that became ready to send (possibly empty, never {@code null})
     */
    public List<ShuffleBlockInfo> addPartitionData(int partitionId, byte[] record) {
        return addPartitionData(partitionId, record, record.length, System.currentTimeMillis());
    }

    /**
     * Buffers the first {@code serializedDataLength} bytes of {@code serializedData} for a
     * partition. If the memory held by unsent buffers then exceeds the spill threshold, every
     * buffer is flushed.
     *
     * @param start timestamp (ms) when this write started; elapsed time is added to writeTime
     * @return blocks that became ready to send (possibly empty, never {@code null})
     */
    public List<ShuffleBlockInfo> addPartitionData(
            int partitionId, byte[] serializedData, int serializedDataLength, long start) {
        List<ShuffleBlockInfo> ready = insertIntoBuffer(partitionId, serializedData, serializedDataLength);
        if (usedBytes.get() - inSendListBytes.get() > spillSize) {
            ready.addAll(clear());
        }
        writeTime += System.currentTimeMillis() - start;
        return ready;
    }

    /**
     * Appends data to the partition's buffer, creating the buffer if needed. Returns the
     * partition's block if the buffer has grown past {@code bufferSize}.
     */
    private List<ShuffleBlockInfo> insertIntoBuffer(int partitionId, byte[] serializedData, int serializedDataLength) {
        List<ShuffleBlockInfo> sentBlocks = new ArrayList<>();
        long required = Math.max(bufferSegmentSize, serializedDataLength);
        boolean hasRequested = false;
        WriterBuffer existing = buffers.get(partitionId);
        if (existing != null && existing.askForMemory(serializedDataLength)) {
            requestMemory(required);
            hasRequested = true;
        }
        // Acquiring memory may trigger spill(), which calls clear() and removes every buffer,
        // so the map must be read again rather than reusing `existing`.
        WriterBuffer current = buffers.get(partitionId);
        if (current != null) {
            if (hasRequested) {
                usedBytes.addAndGet(required);
            }
            current.addRecord(serializedData, serializedDataLength);
            if (current.getMemoryUsed() > bufferSize) {
                sentBlocks.add(createShuffleBlock(partitionId, current));
                copyTime += current.getCopyTime();
                buffers.remove(partitionId);
                LOG.debug("Single buffer is full for shuffleId[{}] partition[{}] with memoryUsed[{}], dataLength[{}]",
                        shuffleId, partitionId, current.getMemoryUsed(), current.getDataLength());
            }
        } else {
            // If memory was already requested above, the old buffer was flushed by a spill
            // during that request. The memory is still available, so do not request it again.
            if (!hasRequested) {
                requestMemory(required);
            }
            usedBytes.addAndGet(required);
            WriterBuffer created = new WriterBuffer(bufferSegmentSize);
            created.addRecord(serializedData, serializedDataLength);
            buffers.put(partitionId, created);
        }
        return sentBlocks;
    }

    /**
     * Serializes a key/value record and buffers it for the partition. Only supported in
     * row-based mode, which is the only mode where the serialization stream exists.
     *
     * @param key record key, may be {@code null}
     * @param value record value, may be {@code null}
     * @return blocks that became ready to send, or {@code null} if serialization produced no bytes
     * @throws IllegalStateException if the manager was not created in row-based mode
     */
    public List<ShuffleBlockInfo> addRecord(int partitionId, Object key, Object value) {
        if (!isRowBased) {
            throw new IllegalStateException(
                    "addRecord is only supported when row-based mode (RSS_ROW_BASED) is enabled");
        }
        final long start = System.currentTimeMillis();
        arrayOutputStream.reset();
        if (key != null) {
            serializeStream.writeKey(key, ClassTag$.MODULE$.apply(key.getClass()));
        } else {
            serializeStream.writeKey((Object) null, ManifestFactory$.MODULE$.Null());
        }
        if (value != null) {
            serializeStream.writeValue(value, ClassTag$.MODULE$.apply(value.getClass()));
        } else {
            serializeStream.writeValue((Object) null, ManifestFactory$.MODULE$.Null());
        }
        serializeStream.flush();
        serializeTime += System.currentTimeMillis() - start;
        byte[] serializedData = arrayOutputStream.getBuf();
        int serializedDataLength = arrayOutputStream.size();
        if (serializedDataLength == 0) {
            return null;
        }
        // `start` is passed on, so writeTime also covers the serialization above.
        List<ShuffleBlockInfo> result = addPartitionData(partitionId, serializedData, serializedDataLength, start);
        shuffleWriteMetrics.incRecordsWritten(1L);
        return result;
    }

    /**
     * Converts every live partition buffer into a block and removes it from the buffer map.
     *
     * @return the created blocks (possibly empty)
     */
    public synchronized List<ShuffleBlockInfo> clear() {
        List<ShuffleBlockInfo> result = new ArrayList<>(buffers.size());
        long dataSize = 0;
        long memoryUsed = 0;
        Iterator<Map.Entry<Integer, WriterBuffer>> iterator = buffers.entrySet().iterator();
        while (iterator.hasNext()) {
            Map.Entry<Integer, WriterBuffer> entry = iterator.next();
            WriterBuffer wb = entry.getValue();
            dataSize += wb.getDataLength();
            memoryUsed += wb.getMemoryUsed();
            result.add(createShuffleBlock(entry.getKey(), wb));
            iterator.remove();
            copyTime += wb.getCopyTime();
        }
        LOG.info("Flush total buffer for shuffleId[{}] with allocated[{}], dataSize[{}], memoryUsed[{}]",
                shuffleId, allocatedBytes.get(), dataSize, memoryUsed);
        return result;
    }

    /**
     * Builds a block from a partition buffer. The data is compressed when a codec is
     * configured, and the CRC32 is computed over the bytes actually sent. Bytes-written metrics
     * and the in-send-list counter are updated.
     */
    protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer wb) {
        byte[] data = wb.getData();
        final int uncompressLength = data.length;
        byte[] payload = data;
        if (codec != null) {
            long start = System.currentTimeMillis();
            payload = codec.compress(data);
            compressTime += System.currentTimeMillis() - start;
        }
        long crc32 = ChecksumUtils.getCrc32(payload);
        long blockId = ClientUtils.getBlockId(partitionId, taskAttemptId, getNextSeqNo(partitionId)).longValue();
        uncompressedDataLen += uncompressLength;
        shuffleWriteMetrics.incBytesWritten(payload.length);
        // The buffer's memory stays counted until freeAllocatedMemory releases it
        // (buildBlockEvents registers that callback).
        inSendListBytes.addAndGet(wb.getMemoryUsed());
        return new ShuffleBlockInfo(shuffleId, partitionId, blockId, payload.length, crc32, payload,
                partitionToServers.get(partitionId), uncompressLength, wb.getMemoryUsed(), taskAttemptId);
    }

    /** Returns the partition's current sequence number (starting at 0) and then increments it. */
    private int getNextSeqNo(int partitionId) {
        int seqNo = partitionToSeqNo.getOrDefault(partitionId, 0);
        partitionToSeqNo.put(partitionId, seqNo + 1);
        return seqNo;
    }

    /** Ensures at least {@code requiredMem} unreserved allocated bytes are available, timing the wait. */
    private void requestMemory(long requiredMem) {
        final long start = System.currentTimeMillis();
        if (allocatedBytes.get() - usedBytes.get() < requiredMem) {
            requestExecutorMemory(requiredMem);
        }
        requireMemoryTime += System.currentTimeMillis() - start;
    }

    /**
     * Acquires memory from the executor in {@code askExecutorMemory} chunks until
     * {@code leastMem} unreserved bytes are available. Sleeps {@code requireMemoryInterval} ms
     * between attempts.
     *
     * @throws RssException if more than {@code requireMemoryRetryMax} retries are needed, or if
     *     the thread is interrupted while waiting (the interrupt flag is restored)
     */
    private void requestExecutorMemory(long leastMem) {
        long gotMem = acquireMemory(askExecutorMemory);
        allocatedBytes.addAndGet(gotMem);
        int retry = 0;
        while (allocatedBytes.get() - usedBytes.get() < leastMem) {
            LOG.info("Can't get memory for now, sleep and try[{}] again, request[{}], got[{}] less than {}",
                    retry, askExecutorMemory, gotMem, leastMem);
            try {
                Thread.sleep(requireMemoryInterval);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new RssException("Interrupted when waiting for memory.", e);
            }
            gotMem = acquireMemory(askExecutorMemory);
            allocatedBytes.addAndGet(gotMem);
            retry++;
            if (retry > requireMemoryRetryMax) {
                String message = "Can't get memory to cache shuffle data, request[" + askExecutorMemory
                        + "], got[" + gotMem + "], WriteBufferManager allocated[" + allocatedBytes.get()
                        + "] task used[" + this.used + "]. It may be caused by shuffle server is full of data"
                        + " or consider to optimize 'spark.executor.memory', 'spark.rss.writer.buffer.spill.size'.";
                LOG.error(message);
                throw new RssException(message);
            }
        }
    }

    /**
     * Groups blocks into send events. An event is closed as soon as its accumulated size
     * exceeds {@code sendSizeLimit}, and any remainder forms a final event. Each event's
     * callback frees the memory of its blocks and releases their data.
     */
    public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> shuffleBlockInfoList) {
        List<AddBlockEvent> events = new ArrayList<>();
        List<ShuffleBlockInfo> batch = new ArrayList<>();
        long batchSize = 0;
        long batchMemory = 0;
        for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
            batchSize += sbi.getSize();
            batchMemory += sbi.getFreeMemory();
            batch.add(sbi);
            if (batchSize > sendSizeLimit) {
                events.add(createAddBlockEvent(batch, batchSize, batchMemory));
                batch = new ArrayList<>();
                batchSize = 0;
                batchMemory = 0;
            }
        }
        if (!batch.isEmpty()) {
            events.add(createAddBlockEvent(batch, batchSize, batchMemory));
        }
        return events;
    }

    /** Builds one event whose callback frees {@code memoryToFree} and releases each block's data. */
    private AddBlockEvent createAddBlockEvent(List<ShuffleBlockInfo> blocks, long totalSize, long memoryToFree) {
        LOG.debug("Build event with {} blocks and {} bytes", blocks.size(), totalSize);
        return new AddBlockEvent(taskId, blocks, () -> {
            freeAllocatedMemory(memoryToFree);
            blocks.forEach(block -> block.getData().release());
        });
    }

    /**
     * Flushes all buffers through the spill function when this manager itself triggered the
     * spill. It waits up to {@code memorySpillTimeoutSec} seconds, as a best effort, for the
     * send futures. The requested {@code size} is not used: every buffer is flushed.
     *
     * @return total memory reported by futures that completed normally; 0 when spilling is
     *     disabled, another consumer triggered it, or no spill function is configured
     */
    @Override
    public long spill(long size, MemoryConsumer trigger) {
        if (!memorySpillEnabled || trigger != this || spillFunc == null) {
            return 0L;
        }
        List<CompletableFuture<Long>> futures = spillFunc.apply(clear());
        try {
            CompletableFuture.allOf(futures.toArray(new CompletableFuture<?>[0]))
                    .get(memorySpillTimeoutSec, TimeUnit.SECONDS);
        } catch (TimeoutException | ExecutionException | CancellationException ignored) {
            // Best effort: count whatever has finished successfully; pending sends are not cancelled.
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        long released = sumCompletedFutures(futures);
        LOG.info("[taskId: {}] Spill triggered by own, released memory size: {}", taskId, released);
        return released;
    }

    /** Sums futures that completed normally; pending, failed, cancelled or null results count as 0. */
    private static long sumCompletedFutures(List<CompletableFuture<Long>> futures) {
        long total = 0L;
        for (CompletableFuture<Long> future : futures) {
            if (future.isDone() && !future.isCompletedExceptionally()) {
                Long released = future.join();
                if (released != null) {
                    total += released;
                }
            }
        }
        return total;
    }

    @VisibleForTesting
    protected long getAllocatedBytes() {
        return this.allocatedBytes.get();
    }

    @VisibleForTesting
    protected long getUsedBytes() {
        return this.usedBytes.get();
    }

    @VisibleForTesting
    protected long getInSendListBytes() {
        return this.inSendListBytes.get();
    }

    /**
     * Returns {@code size} bytes to the task memory manager and subtracts them from the
     * allocated, used, and in-send-list counters.
     */
    public void freeAllocatedMemory(long size) {
        freeMemory(size);
        this.allocatedBytes.addAndGet(-size);
        this.usedBytes.addAndGet(-size);
        this.inSendListBytes.addAndGet(-size);
    }

    /**
     * Releases everything still recorded as allocated back to the task memory manager and
     * resets the counter, so a repeated call does not release the same memory twice.
     */
    public void freeAllMemory() {
        long memory = this.allocatedBytes.getAndSet(0L);
        if (memory > 0) {
            freeMemory(memory);
        }
    }

    @VisibleForTesting
    protected Map<Integer, WriterBuffer> getBuffers() {
        return this.buffers;
    }

    @VisibleForTesting
    protected ShuffleWriteMetrics getShuffleWriteMetrics() {
        return this.shuffleWriteMetrics;
    }

    @VisibleForTesting
    protected void setShuffleWriteMetrics(ShuffleWriteMetrics shuffleWriteMetrics) {
        this.shuffleWriteMetrics = shuffleWriteMetrics;
    }

    /** Accumulated time (ms) spent in addPartitionData/addRecord. */
    public long getWriteTime() {
        return this.writeTime;
    }

    /** Human-readable summary of the accumulated timing counters and uncompressed data length. */
    public String getManagerCostInfo() {
        return "WriteBufferManager cost copyTime[" + this.copyTime + "], writeTime[" + this.writeTime
                + "], serializeTime[" + this.serializeTime + "], compressTime[" + this.compressTime
                + "], estimateTime[" + this.estimateTime + "], requireMemoryTime[" + this.requireMemoryTime
                + "], uncompressedDataLen[" + this.uncompressedDataLen + "]";
    }

    @VisibleForTesting
    public void setTaskId(String taskId) {
        this.taskId = taskId;
    }

    @VisibleForTesting
    public void setSpillFunc(Function<List<ShuffleBlockInfo>, List<CompletableFuture<Long>>> spillFunc) {
        this.spillFunc = spillFunc;
    }

    @VisibleForTesting
    public void setSendSizeLimit(long sendSizeLimit) {
        this.sendSizeLimit = sendSizeLimit;
    }
}
