package org.tensorflow.framework.constraints;

import org.tensorflow.Operand;
import org.tensorflow.framework.utils.CastHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/* loaded from: input_file:org/tensorflow/framework/constraints/UnitNorm.class */
/**
 * Constraint that rescales weights by their norm along a configurable set of axes.
 *
 * <p>{@link #call(Operand)} returns {@code weights / (epsilon + norm(weights, axes))} with
 * {@code epsilon = 1e-7}, which keeps the division finite when the norm is zero. The {@code norm}
 * helper is inherited from {@link Constraint}. NOTE(review): it is presumably an L2 norm that keeps
 * the reduced dimensions so the division broadcasts — confirm against {@code Constraint}.
 */
public class UnitNorm extends Constraint {
    /** Axis used by {@link #UnitNorm(Ops)} when no axis is given. */
    public static final int AXIS_DEFAULT = 0;

    /** Small value added to the norm so the division never divides by exactly zero. */
    private static final float NORM_EPSILON = 1.0E-7f;

    /** Axes along which the norm is computed; privately owned copy of the caller's array. */
    private final int[] axes;

    /**
     * Creates a UnitNorm constraint that computes the norm along {@link #AXIS_DEFAULT}.
     *
     * @param ops the TensorFlow Ops used to build the constraint's operations
     */
    public UnitNorm(Ops ops) {
        this(ops, AXIS_DEFAULT);
    }

    /**
     * Creates a UnitNorm constraint that computes the norm along a single axis.
     *
     * @param ops the TensorFlow Ops used to build the constraint's operations
     * @param axis the axis along which to compute the norm
     */
    public UnitNorm(Ops ops, int axis) {
        this(ops, new int[] {axis});
    }

    /**
     * Creates a UnitNorm constraint that computes the norm along the given axes.
     *
     * @param ops the TensorFlow Ops used to build the constraint's operations
     * @param axes the axes along which to compute the norm; the array is copied, so later changes
     *     to it by the caller do not affect this constraint
     */
    public UnitNorm(Ops ops, int[] axes) {
        super(ops);
        // Defensive copy: the caller must not be able to change the axes after construction.
        // A null array is passed through unchanged to preserve the previous behavior.
        this.axes = axes == null ? null : axes.clone();
    }

    /**
     * Applies the constraint to the given weights.
     *
     * @param operand the weights to constrain
     * @param <T> the numeric data type of the weights
     * @return {@code operand / (1e-7 + norm(operand, getAxes()))}, in the same data type as the input
     */
    @Override
    public <T extends TNumber> Operand<T> call(Operand<T> operand) {
        Class<T> type = operand.type();
        Ops tf = getTF();
        // The epsilon is created as a float constant and cast to the weights' type so the
        // addition below is type-consistent.
        Operand<T> epsilon = CastHelper.cast(tf, tf.constant(NORM_EPSILON), type);
        Operand<T> denominator = tf.math.add(epsilon, norm(operand, getAxes()));
        return tf.math.div(operand, denominator);
    }

    /**
     * Returns the axes along which the norm is computed.
     *
     * @return a copy of the configured axes (or {@code null} if the constraint was built with a
     *     {@code null} array); modifying the returned array does not affect this constraint
     */
    public int[] getAxes() {
        return axes == null ? null : axes.clone();
    }
}
