package ie.dcu.eval; import static java.lang.Math.sqrt; /** * Computes a variety of statistical measures and indices to evaluate binary * classifiers based on a confusion matrix computed against a reference. * * @author Kevin McGuinness */ public class BinaryClassifierEvaluation { public final double tp; public final double fp; public final double tn; public final double fn; public final double total; public BinaryClassifierEvaluation(int[][] c) { this(c[0][0], c[0][1], c[1][0], c[1][1]); } public BinaryClassifierEvaluation(long[][] c) { this(c[0][0], c[0][1], c[1][0], c[1][1]); } public BinaryClassifierEvaluation(long tp, long fn, long fp, long tn) { check(this.tp = tp); check(this.fp = fp); check(this.tn = tn); check(this.fn = fn); this.total = tp + fp + tn + fn; } public final double getTotal() { return total; } public final double getPositiveReference() { return tp + fn; } public final double getNegativeReference() { return tn + fp; } public final double getPositiveResponse() { return tp + fp; } public final double getNegativeResponse() { return tn + fn; } public final double getCorrectResponse() { return tp + tn; } public final double getIncorrectResponse() { return fp + fn; } public final double getReferenceLikelihood() { return (tp + fn) / total; } public final double getResponseLikelihood() { return (tp + fp) / total; } public final double getAccuracy() { return (tp + tn) / total; } public final double getRecall() { return tp / (tp + fn); } public final double getPrecision() { return tp / (tp + fp); } public final double getRejectionRecall() { return tn / (tn + fp); } public final double getRejectionPrecision() { return tn / (tn + fn); } public final double getFMeasure() { return getFMeasure(1.0); } public final double getFMeasure(double beta) { double b = beta * beta, p = getPrecision(), r = getRecall(); return (1.0 + b) * p * r / (r + (b * p)); } public final double getRandIndex() { // Same as accuracy in the binary case return getAccuracy(); } public final double getJaccardIndex() { // Jaccard similarity coefficient J return tp / (tp + fn + fp); } public final double getJaccardDistance() { // Jaccard distance J' return (fp + fn) / (fp + fn + tp); } public final double getFowlkesMallowsIndex() { // Fowlkes-Mallows 83 - A method for comparing 2 hierarchical clusterings double tk = sq(tp) + sq(fn) + sq(fp) + sq(tn) - total; double pk = sq(tp + fp) + sq(fn + tn) - total; double qk = sq(tp + fn) + sq(fp + tn) - total; return tk / sqrt(pk*qk); } public final double getChiSquared() { double d = tp * tn - fp * fn; return total * d * d / ((tp + fn) * (fp + tn) * (tp + fp) * (fn + tn)); } public final double getPhiSquared() { return getChiSquared() / total; } public final double getYulesQ() { return (tp * tn - fp * fn) / (tp * tn + fp * fn); } public final double getYulesY() { return (sqrt(tp * tn) - sqrt(fp * fn)) / (sqrt(tp * tn) + sqrt(fp * fn)); } public final double getAccuracyDeviation() { double a = getAccuracy(); return sqrt(a * (1.0 - a) / total); } public final double getRandomAccuracy() { double ref = (tp + fn) / total; double res = (tp + fp) / total; return ref * res + (1.0 - ref) * (1.0 - res); } public final double getRandomAccuracyUnbiased() { double avg = tp / total + (fn + fp) / (2.0 * total); return avg * avg + (1.0 - avg) * (1.0 - avg); } public final double getKappa() { return kappa(getAccuracy(), getRandomAccuracy()); } public final double getKappaUnbiased() { return kappa(getAccuracy(), getRandomAccuracyUnbiased()); } public final double getKappaNoPrevalence() { return 2.0 * getAccuracy() - 1.0; } private static double kappa(double observed, double expected) { return (observed - expected) / (1 - expected); } private static double sq(double x) { return x*x; } private static void check(double v) { if (v < 0) { throw new IllegalArgumentException(); } } }