/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.rl.agent;

import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDList;
import ai.djl.training.tracker.Tracker;
import ai.djl.util.RandomUtils;

/**
 * An {@link RlAgent} wrapper that, while training, sometimes substitutes a random action from
 * the environment's action space for the action the wrapped agent would have chosen.
 *
 * <p>How often the substitution happens is governed by a {@link Tracker}, which is queried with
 * a step index that this class increments once per training-mode call to
 * {@link #chooseAction(RlEnv, boolean)}.
 *
 * <p>NOTE(review): the step counter is a plain {@code int} with no synchronization; confirm
 * whether instances are ever shared across threads.
 */
public class EpsilonGreedy implements RlAgent {
    /** Agent that supplies the action whenever no random action is taken, and does all training. */
    private final RlAgent baseAgent;

    /** Supplies the exploration threshold for a given step index. */
    private final Tracker exploreRate;

    /** Number of training-mode action choices made so far; used as the tracker's step index. */
    private int counter;

    /**
     * Creates an agent that wraps {@code baseAgent}.
     *
     * @param baseAgent the agent to delegate to for non-random actions and for training
     * @param exploreRate the tracker whose value is compared against a random draw to decide
     *     whether a random action is taken
     */
    public EpsilonGreedy(RlAgent baseAgent, Tracker exploreRate) {
        this.baseAgent = baseAgent;
        this.exploreRate = exploreRate;
    }

    /**
     * Chooses the next action for {@code env}.
     *
     * <p>When {@code training} is {@code true}, the step counter is advanced and a value from
     * {@link RandomUtils#random()} is compared with the tracker's value for the pre-increment
     * step; if the draw is strictly smaller, a random action from the environment's action space
     * is returned. In every other case the wrapped agent chooses the action.
     *
     * @param env the environment to act in
     * @param training whether the call happens during training; exploration is only considered
     *     when this is {@code true}
     * @return the chosen action
     */
    @Override
    public NDList chooseAction(RlEnv env, boolean training) {
        boolean explore = false;
        if (training) {
            // Use the current count as the step index, then advance it for the next call.
            int step = counter;
            counter = step + 1;
            // Draw first, then query the tracker, matching the left-to-right order of the
            // comparison's operands.
            double draw = RandomUtils.random();
            double threshold = exploreRate.getNewValue(step);
            explore = draw < threshold;
        }
        return explore
                ? env.getActionSpace().randomAction()
                : baseAgent.chooseAction(env, training);
    }

    /**
     * Forwards the batch to the wrapped agent unchanged; this class performs no training itself.
     *
     * @param batchSteps the steps to train on
     */
    @Override
    public void trainBatch(RlEnv.Step[] batchSteps) {
        baseAgent.trainBatch(batchSteps);
    }
}

