package org.apache.mahout.math.stats;

import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.RandomWrapper;
import org.apache.mahout.math.stats.TDigest;
import org.junit.Assert;
import org.junit.Test;

/* loaded from: input_file:org/apache/mahout/math/stats/GroupTreeTest.class */
/**
 * Unit tests for {@link GroupTree}: group counts and weight sums, balancing, ordered iteration,
 * tail views, floor/ceiling lookups, head counts/sums, and removal followed by re-insertion.
 */
public class GroupTreeTest {
    /** Number of groups inserted by {@link #testBalancing()} and {@link #halfMeanTree()}. */
    private static final int GROUP_COUNT = 101;

    @Test
    public void testSimpleAdds() {
        GroupTree tree = new GroupTree();
        // An empty tree has no neighbours for any probe and nothing to count.
        Assert.assertNull(tree.floor(new TDigest.Group(34.0)));
        Assert.assertNull(tree.ceiling(new TDigest.Group(34.0)));
        Assert.assertEquals(0L, tree.size());
        Assert.assertEquals(0L, tree.sum());

        tree.add(new TDigest.Group(1.0));
        // A second group that absorbs two more points before insertion.
        TDigest.Group heavy = new TDigest.Group(2.0);
        heavy.add(3.0, 1);
        heavy.add(4.0, 1);
        tree.add(heavy);

        // size() counts groups; sum() totals their weights (1 + 3).
        Assert.assertEquals(2L, tree.size());
        Assert.assertEquals(4L, tree.sum());
    }

    @Test
    public void testBalancing() {
        GroupTree tree = new GroupTree();
        // Sorted insertion order is what would degrade a naive, unbalanced binary search tree.
        for (int i = 0; i < GROUP_COUNT; i++) {
            tree.add(new TDigest.Group(i));
        }
        Assert.assertEquals(101L, tree.sum());
        Assert.assertEquals(101L, tree.size());
        tree.checkBalance();
    }

    @Test
    public void testIterators() {
        GroupTree tree = halfMeanTree();
        Assert.assertEquals(0.0, tree.first().mean(), 0.0);
        Assert.assertEquals(50.0, tree.last().mean(), 0.0);

        assertHalfMeans(tree.iterator(), 0, GROUP_COUNT);

        // The expected start indexes encode that a Group(m, 0) probe sorts before both stored
        // groups of mean m (indexes 2m and 2m + 1).
        Iterable<TDigest.Group> tail = tree.tailSet(new TDigest.Group(34.0, 0));
        assertHalfMeans(tail.iterator(), 68, GROUP_COUNT);
        // A tail view must yield the same sequence when iterated a second time.
        assertHalfMeans(tail.iterator(), 68, GROUP_COUNT);

        assertHalfMeans(tree.tailSet(new TDigest.Group(33.0, 0)).iterator(), 66, GROUP_COUNT);
        // ceiling() of the probe is the first mean-34 group; floor() is the last mean-33 group.
        assertHalfMeans(tree.tailSet(tree.ceiling(new TDigest.Group(34.0, 0))).iterator(), 68, GROUP_COUNT);
        assertHalfMeans(tree.tailSet(tree.floor(new TDigest.Group(34.0, 0))).iterator(), 67, GROUP_COUNT);
    }

    @Test
    public void testFloor() {
        GroupTree tree = halfMeanTree();
        // Every stored mean is >= 0, so nothing lies at or below -30.
        Assert.assertNull(tree.floor(new TDigest.Group(-30.0)));
    }

    @Test
    public void testRemoveAndSums() {
        GroupTree tree = halfMeanTree();

        // Fold an extra point at 3 into the first group of mean 2, giving it mean 2.5 and
        // weight 2. Its mean determines its position in the tree, so it is removed before
        // being modified and re-added afterwards.
        TDigest.Group merged = tree.ceiling(new TDigest.Group(2.0, 0));
        tree.remove(merged);
        merged.add(3.0, 1);
        tree.add(merged);

        // Nothing lies before the smallest mean.
        Assert.assertEquals(0L, tree.headCount(new TDigest.Group(-1.0)));
        Assert.assertEquals(0L, tree.headSum(new TDigest.Group(-1.0)));
        Assert.assertEquals(0L, tree.headCount(new TDigest.Group(0.0, 0)));
        Assert.assertEquals(0L, tree.headSum(new TDigest.Group(0.0, 0)));
        Assert.assertEquals(0L, tree.headCount(tree.ceiling(new TDigest.Group(0.0, 0))));
        Assert.assertEquals(0L, tree.headSum(tree.ceiling(new TDigest.Group(0.0, 0))));
        // Only the two mean-0 groups precede the mean-1 probe.
        Assert.assertEquals(2L, tree.headCount(new TDigest.Group(1.0, 0)));
        Assert.assertEquals(2L, tree.headSum(new TDigest.Group(1.0, 0)));

        Iterator<TDigest.Group> aboveTwoPointOne = tree.tailSet(new TDigest.Group(2.1)).iterator();
        Assert.assertEquals(2.5, aboveTwoPointOne.next().mean(), 1.0e-9);

        // Below 2.1: means 0, 0, 1, 1 and the remaining single-point group of mean 2.
        Assert.assertEquals(5L, tree.headCount(new TDigest.Group(2.1, 0)));
        Assert.assertEquals(5L, tree.headSum(new TDigest.Group(2.1, 0)));
        // Below 2.7 the merged group is included too, adding one group but two units of weight.
        Assert.assertEquals(6L, tree.headCount(new TDigest.Group(2.7, 0)));
        Assert.assertEquals(7L, tree.headSum(new TDigest.Group(2.7, 0)));
        // Every group lies below 200; the merged group makes total weight one more than the count.
        Assert.assertEquals(101L, tree.headCount(new TDigest.Group(200.0)));
        Assert.assertEquals(102L, tree.headSum(new TDigest.Group(200.0)));
    }

    @Test
    public void testRandomRebalance() {
        // Presumably pins the random seed so every run sees the same sequence.
        RandomUtils.useTestSeed();
        RandomWrapper random = RandomUtils.getRandom();
        GroupTree tree = new GroupTree();
        List<Double> means = Lists.newArrayList();

        // Insert 1000 random single-point groups, checking balance after each insertion.
        for (int i = 0; i < 1000; i++) {
            double x = random.nextDouble();
            tree.add(new TDigest.Group(x));
            means.add(x);
            tree.checkBalance();
        }
        Collections.sort(means);
        assertMeansEqual(means, tree);

        // Remove 100 randomly chosen groups. Removing from a sorted list keeps it sorted.
        for (int i = 0; i < 100; i++) {
            double x = means.get(random.nextInt(means.size()));
            // Double.valueOf forces List.remove(Object) rather than remove(int index).
            means.remove(Double.valueOf(x));
            tree.remove(tree.floor(new TDigest.Group(x)));
        }
        assertMeansEqual(means, tree);

        // Re-key every group in ascending order: absorb a point at mean + 20 with weight 1,
        // which the test expects to move the group's mean to x + 10. Already re-keyed groups
        // sit well above every remaining x (nextDouble() values are expected in [0, 1)), so
        // floor(x) still finds the group whose mean is x.
        for (int i = 0; i < means.size(); i++) {
            double x = means.get(i);
            means.set(i, x + 10.0);
            TDigest.Group group = tree.floor(new TDigest.Group(x));
            tree.remove(group);
            tree.checkBalance();
            group.add(group.mean() + 20.0, 1);
            tree.add(group);
            tree.checkBalance();
        }
        assertMeansEqual(means, tree);
    }

    /**
     * Builds a tree of {@link #GROUP_COUNT} single-point groups whose means are {@code i / 2}
     * (integer division), i.e. 0, 0, 1, 1, ..., 49, 49, 50.
     */
    private static GroupTree halfMeanTree() {
        GroupTree tree = new GroupTree();
        for (int i = 0; i < GROUP_COUNT; i++) {
            tree.add(new TDigest.Group(i / 2));
        }
        return tree;
    }

    /**
     * Asserts that {@code groups} yields exactly the groups with means {@code i / 2} for
     * {@code i} in {@code [from, to)}, and nothing after them.
     */
    private static void assertHalfMeans(Iterator<TDigest.Group> groups, int from, int to) {
        for (int i = from; i < to; i++) {
            Assert.assertTrue(groups.hasNext());
            Assert.assertEquals(i / 2, groups.next().mean(), 0.0);
        }
        Assert.assertFalse(groups.hasNext());
    }

    /**
     * Asserts that {@code tree} holds exactly as many groups as {@code expectedMeans} has
     * entries, and that in-order iteration yields those means exactly, in order.
     */
    private static void assertMeansEqual(List<Double> expectedMeans, GroupTree tree) {
        Assert.assertEquals(expectedMeans.size(), tree.size());
        Iterator<TDigest.Group> actual = tree.iterator();
        for (double expected : expectedMeans) {
            Assert.assertTrue(actual.hasNext());
            Assert.assertEquals(expected, actual.next().mean(), 0.0);
        }
        // Fails if the tree yields more groups than expected.
        Assert.assertFalse(actual.hasNext());
    }
}
