package org.apache.flink.runtime.operators;

import java.util.ArrayList;
import java.util.Iterator;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.functions.RichGroupReduceFunction;
import org.apache.flink.api.common.typeutils.TypeComparator;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.IntComparator;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.runtime.TupleComparator;
import org.apache.flink.api.java.typeutils.runtime.TupleSerializer;
import org.apache.flink.runtime.operators.testutils.DelayingIterator;
import org.apache.flink.runtime.operators.testutils.DiscardingOutputCollector;
import org.apache.flink.runtime.operators.testutils.ExpectedTestException;
import org.apache.flink.runtime.operators.testutils.InfiniteIntTupleIterator;
import org.apache.flink.runtime.operators.testutils.UnaryOperatorTestBase;
import org.apache.flink.runtime.operators.testutils.UniformIntTupleGenerator;
import org.apache.flink.util.Collector;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/flink/runtime/operators/CombineTaskTest.class */
public class CombineTaskTest extends UnaryOperatorTestBase<RichGroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>>, Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
    /**
     * Memory budget passed to the test base constructor: 3145728 = 3 * 1024 * 1024
     * (presumably bytes -- TODO confirm against UnaryOperatorTestBase). The same value is
     * the numerator from which {@link #combine_frac} is computed.
     */
    private static final long COMBINE_MEM = 3145728;
    /**
     * Ratio of the combine memory budget to the memory manager's total size; computed once in
     * the constructor and handed to {@code setRelativeMemoryDriver} by every test.
     */
    private final double combine_frac;
    /** Sink registered via {@code setOutput} in {@code testCombineTask}; holds the driver's output records. */
    private final ArrayList<Tuple2<Integer, Integer>> outList;
    /** Serializer for the (Integer, Integer) tuples used as driver input and output. */
    private final TypeSerializer<Tuple2<Integer, Integer>> serializer;
    /**
     * Comparator registered (twice) as driver comparator in every test; built over position 0
     * of the tuple (presumably the grouping key -- confirm against TupleComparator).
     */
    private final TypeComparator<Tuple2<Integer, Integer>> comparator;

    /**
     * Combinable group-reduce stub used by {@code testCombineTask}.
     *
     * <p>Each invocation emits exactly one tuple: its first field is the first field of the last
     * record of the group (0 for an empty group) and its second field is the sum of the second
     * fields of all records. {@link #combine} applies the very same aggregation as
     * {@link #reduce}.
     */
    @RichGroupReduceFunction.Combinable
    public static class MockCombiningReduceStub extends RichGroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = 1;

        /**
         * Sums the second fields of the group and emits one (key, sum) tuple.
         *
         * @param iterable the records of one group
         * @param collector receives the single aggregated tuple
         */
        public void reduce(Iterable<Tuple2<Integer, Integer>> iterable, Collector<Tuple2<Integer, Integer>> collector) {
            int key = 0;
            int sum = 0;
            for (Tuple2<Integer, Integer> record : iterable) {
                // The key is overwritten on every record, so the last record's key wins.
                key = record.f0;
                sum += record.f1;
            }
            collector.collect(new Tuple2<Integer, Integer>(key, sum));
        }

        /**
         * Pre-aggregates a group; identical to {@link #reduce}.
         *
         * @param iterable the records of one (partial) group
         * @param collector receives the single aggregated tuple
         */
        public void combine(Iterable<Tuple2<Integer, Integer>> iterable, Collector<Tuple2<Integer, Integer>> collector) {
            reduce(iterable, collector);
        }
    }

    /**
     * Combinable group-reduce stub whose combiner fails after a fixed number of calls.
     *
     * <p>Both {@link #reduce} and {@link #combine} compute, for one group, the tuple
     * (key, sum - key), where key is the first field of the last record (0 for an empty group)
     * and sum is the total of all second fields. {@link #combine} additionally counts its
     * invocations and throws an {@link ExpectedTestException} on the 10th and every later call
     * instead of emitting a record.
     */
    @RichGroupReduceFunction.Combinable
    public static final class MockFailingCombiningReduceStub extends RichGroupReduceFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = 1;

        /** Number of combine invocations at which the combiner starts to fail. */
        private static final int FAILING_COMBINE_CALL = 10;

        /** Number of {@link #combine} invocations seen so far on this instance. */
        private int cnt;

        /**
         * Emits the aggregated (key, sum - key) tuple for the group.
         *
         * @param iterable the records of one group
         * @param collector receives the single aggregated tuple
         */
        public void reduce(Iterable<Tuple2<Integer, Integer>> iterable, Collector<Tuple2<Integer, Integer>> collector) {
            collector.collect(aggregate(iterable));
        }

        /**
         * Consumes the whole group, then either emits the aggregated tuple or fails.
         *
         * @param iterable the records of one (partial) group
         * @param collector receives the single aggregated tuple unless the call fails
         * @throws ExpectedTestException on the 10th and every subsequent invocation
         */
        public void combine(Iterable<Tuple2<Integer, Integer>> iterable, Collector<Tuple2<Integer, Integer>> collector) {
            // Drain the input first so the failure happens only after the group was fully read.
            Tuple2<Integer, Integer> aggregated = aggregate(iterable);
            this.cnt++;
            if (this.cnt >= FAILING_COMBINE_CALL) {
                throw new ExpectedTestException();
            }
            collector.collect(aggregated);
        }

        /** Folds one group into the tuple (last key, sum of second fields minus last key). */
        private static Tuple2<Integer, Integer> aggregate(Iterable<Tuple2<Integer, Integer>> group) {
            int key = 0;
            int sum = 0;
            for (Tuple2<Integer, Integer> record : group) {
                key = record.f0;
                sum += record.f1;
            }
            return new Tuple2<Integer, Integer>(key, sum - key);
        }
    }

    /**
     * Sets up the shared fixtures: the output list, the tuple serializer, the tuple comparator
     * over position 0, and the memory fraction reserved for the combine driver.
     *
     * @param executionConfig execution configuration forwarded to the test base
     */
    // Raw construction is kept because the generic constructor signatures of TupleSerializer and
    // TupleComparator are not visible here; both are built for (Integer, Integer) tuples, which
    // matches the declared field types.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public CombineTaskTest(ExecutionConfig executionConfig) {
        // NOTE(review): the meaning of the third argument (0) is defined by UnaryOperatorTestBase.
        super(executionConfig, COMBINE_MEM, 0);
        this.outList = new ArrayList<>();
        this.serializer = new TupleSerializer(
                Tuple2.class,
                new TypeSerializer[]{IntSerializer.INSTANCE, IntSerializer.INSTANCE});
        this.comparator = new TupleComparator(
                new int[]{0},
                new TypeComparator[]{new IntComparator(true)},
                new TypeSerializer[]{IntSerializer.INSTANCE});
        // Derive the fraction from the constant instead of repeating its value as a literal.
        this.combine_frac = (double) COMBINE_MEM / getMemoryManager().getMemorySize();
    }

    /**
     * Runs the {@link GroupReduceCombineDriver} with the {@code SORTED_GROUP_COMBINE} strategy
     * and {@link MockCombiningReduceStub} over generated input, then checks that exactly
     * {@code keyCnt} records come out and that the second field of each equals
     * 1 + 2 + ... + ({@code valCnt} - 1).
     *
     * <p>Any exception raised while configuring or running the driver fails the test with that
     * exception attached as the cause.
     */
    @Test
    public void testCombineTask() {
        // Generator arguments; presumably the number of keys and the number of values per key
        // (confirm against UniformIntTupleGenerator).
        final int keyCnt = 100;
        final int valCnt = 20;
        try {
            setInput(new UniformIntTupleGenerator(keyCnt, valCnt, false), this.serializer);
            // The same comparator is registered twice, once per comparator slot the driver reads.
            addDriverComparator(this.comparator);
            addDriverComparator(this.comparator);
            setOutput(this.outList, this.serializer);
            getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            getTaskConfig().setRelativeMemoryDriver(this.combine_frac);
            getTaskConfig().setFilehandlesDriver(2);
            testDriver(new GroupReduceCombineDriver(), MockCombiningReduceStub.class);

            int expectedSum = 0;
            for (int value = 1; value < valCnt; value++) {
                expectedSum += value;
            }

            Assert.assertEquals("Unexpected number of combined records.", keyCnt, this.outList.size());
            for (Tuple2<Integer, Integer> record : this.outList) {
                Assert.assertEquals("Unexpected combined sum for key " + record.f0 + ".",
                        expectedSum, record.f1.intValue());
            }
            this.outList.clear();
        } catch (Exception e) {
            // Keep the original exception as the cause so the failure report shows its stack trace.
            throw new AssertionError("Combine task failed with an unexpected exception: " + e, e);
        }
    }

    /**
     * Runs the {@link GroupReduceCombineDriver} with {@link MockFailingCombiningReduceStub},
     * whose combiner throws an {@link ExpectedTestException} from its 10th call on, and checks
     * that this exception is propagated out of {@code testDriver}.
     *
     * <p>The test fails if the driver completes normally or if any other exception is raised;
     * in the latter case the exception is attached as the cause.
     */
    @Test
    public void testFailingCombineTask() {
        final int keyCnt = 100;
        final int valCnt = 20;
        try {
            setInput(new UniformIntTupleGenerator(keyCnt, valCnt, false), this.serializer);
            // The same comparator is registered twice, once per comparator slot the driver reads.
            addDriverComparator(this.comparator);
            addDriverComparator(this.comparator);
            setOutput(new DiscardingOutputCollector());
            getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            getTaskConfig().setRelativeMemoryDriver(this.combine_frac);
            getTaskConfig().setFilehandlesDriver(2);
            try {
                testDriver(new GroupReduceCombineDriver(), MockFailingCombiningReduceStub.class);
                Assert.fail("Exception not forwarded.");
            } catch (ExpectedTestException expected) {
                // Desired outcome: the stub's exception surfaced through the driver.
            }
        } catch (Exception e) {
            // Keep the original exception as the cause so the failure report shows its stack trace.
            throw new AssertionError("Failing combine task raised an unexpected exception: " + e, e);
        }
    }

    /**
     * Starts the {@link GroupReduceCombineDriver} on a background thread over an infinite,
     * delayed input, cancels the driver after half a second, and checks that the background
     * thread terminates within roughly ten seconds.
     *
     * <p>While waiting, the background thread is interrupted repeatedly so that a task stuck in
     * a blocking call gets a chance to notice the cancellation.
     */
    @Test
    public void testCancelCombineTaskSorting() {
        final long cancelDelayMillis = 500L;
        final long joinSliceMillis = 5000L;
        final long timeoutMillis = 10000L;
        try {
            setInput(new DelayingIterator(new InfiniteIntTupleIterator(), 1), this.serializer);
            // The same comparator is registered twice, once per comparator slot the driver reads.
            addDriverComparator(this.comparator);
            addDriverComparator(this.comparator);
            setOutput(new DiscardingOutputCollector());
            getTaskConfig().setDriverStrategy(DriverStrategy.SORTED_GROUP_COMBINE);
            getTaskConfig().setRelativeMemoryDriver(this.combine_frac);
            getTaskConfig().setFilehandlesDriver(2);

            final GroupReduceCombineDriver driver = new GroupReduceCombineDriver();
            Thread taskRunner = new Thread() {
                @Override
                public void run() {
                    try {
                        CombineTaskTest.this.testDriver(driver, MockFailingCombiningReduceStub.class);
                    } catch (Exception ignored) {
                        // Deliberately ignored: this test only verifies that the thread ends
                        // after cancellation, not how the driver run terminates.
                    }
                }
            };
            taskRunner.start();

            // Give the task some time to get going before cancelling it.
            Thread.sleep(cancelDelayMillis);
            driver.cancel();

            // Use the monotonic clock so wall-clock adjustments cannot shorten or extend the wait.
            final long deadlineNanos = System.nanoTime() + timeoutMillis * 1000000L;
            do {
                taskRunner.interrupt();
                taskRunner.join(joinSliceMillis);
            } while (taskRunner.isAlive() && System.nanoTime() - deadlineNanos < 0);

            Assert.assertFalse("Task did not cancel properly within in 10 seconds.", taskRunner.isAlive());
        } catch (InterruptedException e) {
            // Restore the interrupt status for the test runner before reporting the failure.
            Thread.currentThread().interrupt();
            throw new AssertionError("Test thread was interrupted while waiting for cancellation: " + e, e);
        } catch (Exception e) {
            // Keep the original exception as the cause so the failure report shows its stack trace.
            throw new AssertionError("Cancelling the combine task raised an unexpected exception: " + e, e);
        }
    }
}
