package ie.dcu.stats; import java.awt.*; import java.awt.geom.*; import javax.swing.*; import ie.dcu.array.Arrays; import static java.awt.RenderingHints.*; /** * Implementation of the inversion method for simulating random variables. * * @author Kevin McGuinness */ public final class InversionMethod { /** * The probability distribution function */ private final double[] pdf; /** * The cumulative distribution function */ private final double[] cdf; /** * Initialize the inversion method with the given discrete probability * distribution function. The distribution function will be normalized the * sum of all it's elements is one. * * The values of the pdf function must be all greater than 0 and contain at * least one positive value. * * @param pdf * A probability distribution function. */ public InversionMethod(double[] pdf) { this.pdf = normalize(check(pdf)); this.cdf = cumsum(this.pdf); } /** * Returns the distribution size */ public final int size() { return this.pdf.length; } /** * Extract a random variable from to the distribution. * * The returned value is in the range [0..n) where n is the size of the * distribution. * * Complexity is O(log(n)). * * @return A random variable */ public final int random() { return lbound(cdf, Math.random()); } /** * Returns the arg min{x : array[x] >= value} * * Complexity is O(log(n)). */ private static int lbound(double[] array, double value) { int lo = 0; int hi = array.length - 1; while (lo <= hi) { int mid = (lo + hi) >> 1; double x = array[mid]; if (x < value) { lo = mid + 1; } else if (x > value) { hi = mid - 1; } else { return mid; } } return lo; } /** * Check the values of the pdf function are ok. */ private static double[] check(double[] pdf) { double min = Arrays.min(pdf); if (min < 0) { throw new IllegalArgumentException("pdf contains values < 0"); } double max = Arrays.max(pdf); if (max == 0) { throw new IllegalArgumentException("pdf contains no nonzero values"); } return pdf; } /** * Normalize the pdf function so that it sums to 1. */ private static double[] normalize(double[] pdf) { double[] result = new double[pdf.length]; double sum = Arrays.sum(pdf); for (int i = 0; i < pdf.length; i++) { result[i] = pdf[i] / sum; } return result; } /** * Calculate the cumulative sum of the array. */ private static double[] cumsum(double[] array) { double[] result = new double[array.length]; double sum = 0.0; for (int i = 0; i < array.length; i++) { sum += array[i]; result[i] = sum; } return result; } public static void main(String[] args) { double[] pdf = normal(5); //{1,2,3,4,5,4,5,3,2,1}; InversionMethod var = new InversionMethod(pdf); int[] variables = new int[10000]; for (int i = 0; i < variables.length; i++) { variables[i] = var.random(); } plothist("Normal Distribution", variables); } public static double[] normal(int n) { double[] result = new double[2*n+1]; double sigma = n / 3.0; System.out.println(sigma); double scalef = 1.0 / (Math.sqrt(2.0*Math.PI) * sigma); double denom = 2*sigma*sigma; for (int i = -n; i <= n; i++) { result[i+n] = scalef * Math.exp(-(i*i)/denom); } return result; } private static JFrame plothist(String title, int[] values) { final int[] hist = histogram(values); final Color[] colors = new Color[hist.length]; for (int i = 0; i < colors.length; i++) { float r = (float) Math.random(); float g = (float) Math.random(); float b = (float) Math.random(); colors[i] = new Color(r,g,b); } JFrame frame = new JFrame(title); JPanel panel = new JPanel() { private static final long serialVersionUID = 1L; public void paint(Graphics graphics) { Graphics2D g = (Graphics2D) graphics; FontMetrics fontMetrics = g.getFontMetrics(); double w = getWidth(); double h = getHeight(); double spaceAtBottom = fontMetrics.getHeight() + 4; double spaceAtTop = 3; double spaceBetweenBars = 5; g.setRenderingHint(KEY_ANTIALIASING, VALUE_ANTIALIAS_ON); g.setRenderingHint(KEY_TEXT_ANTIALIASING, VALUE_TEXT_ANTIALIAS_ON); g.setRenderingHint(KEY_FRACTIONALMETRICS, VALUE_FRACTIONALMETRICS_ON); g.setColor(Color.white); g.fill(new Rectangle2D.Double(0,0,w,h)); g.setColor(Color.black); g.draw(new Line2D.Double(0,h-spaceAtBottom,w,h-spaceAtBottom)); double hmax = Arrays.max(hist); double hcount = hist.length; double space = w / hcount; Rectangle2D.Double bar = new Rectangle2D.Double(0,0,0,0); for (int i = 0; i < hist.length; i++) { bar.x = i * space + spaceBetweenBars / 2; bar.width = space - spaceBetweenBars; bar.height = (h-(spaceAtBottom+spaceAtTop)) * (hist[i] / hmax); bar.y = h - bar.height - spaceAtBottom; g.setColor(colors[i]); g.fill(bar); g.setColor(Color.black); g.draw(bar); String text = String.valueOf(hist[i]); Rectangle2D bounds = fontMetrics.getStringBounds(text, g); double tx = bar.getCenterX() - bounds.getWidth() / 2.0; double ty = bar.getMinY() + bounds.getHeight() + 1; if (ty > h - spaceAtBottom - 1) { // place over the bar ty = bar.getMinY() - 2; } g.drawString(text, (int) tx, (int) ty); text = String.valueOf(i); bounds = fontMetrics.getStringBounds(text, g); tx = bar.getCenterX() - bounds.getWidth() / 2.0; ty = bar.getMaxY() + bounds.getHeight(); g.drawString(text, (int) tx, (int) ty); } } public Dimension getPreferredSize() { Dimension screenSize = Toolkit.getDefaultToolkit().getScreenSize(); int w = 50 * hist.length; int h = 400; w = Math.min(Math.max(w, 100), screenSize.width); h = Math.min(h, screenSize.height); return new Dimension(w,h); } }; frame.setLayout(new BorderLayout()); frame.add(panel); frame.pack(); frame.setLocationRelativeTo(frame); frame.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE); frame.setVisible(true); return frame; } private static int[] histogram(int[] values) { int[] bins = new int[Arrays.max(values)+1]; java.util.Arrays.fill(bins, 0); for (int i = 0; i < values.length; i++) { bins[values[i]]++; } return bins; } }