package fft;

/**
 * Recursive radix-2 Fourier transform over a fixed, power-of-two number of
 * complex samples, plus a small benchmark driver.
 *
 * <p>The twiddle factors are built with the positive exponent
 * {@code e^(+2*pi*i*k/size)} and the result is not scaled by {@code 1/size}.
 *
 * <p>Typical use: construct with the sample count, load samples with
 * {@link #setX(double[], double[])}, call {@link #fft()}, then read the result
 * with {@link #getRe()} / {@link #getIm()}.
 */
public class Fft {

    /** Number of samples; the constructor guarantees it is a power of two. */
    public final int size;

    /** Input samples; the elements stay {@code null} until {@link #setX} fills them. */
    public final Complex[] input;

    /** Result of the most recent {@link #fft()} call; {@code null} before the first call. */
    public Complex[] output;

    /**
     * Creates a transform for {@code n} samples.
     *
     * @param n number of samples
     * @throws IllegalArgumentException if {@code n} is not a power of two
     */
    public Fft(int n) {
        if (!checkSize(n)) {
            throw new IllegalArgumentException("The argument must be power of 2");
        }
        size = n;
        input = new Complex[size];
    }

    /**
     * Checks whether a number is a power of two.
     *
     * @param a value to test
     * @return {@code true} for 1, 2, 4, 8, ...; {@code false} for zero, negative
     *         values and everything else
     */
    public static boolean checkSize(int a) {
        // A positive power of two has exactly one bit set, so clearing its
        // lowest set bit (a & (a - 1)) must leave zero.
        return a > 0 && (a & (a - 1)) == 0;
    }

    /**
     * Loads the input samples from separate real and imaginary arrays.
     *
     * @param re real parts, exactly {@link #size} elements
     * @param im imaginary parts, exactly {@link #size} elements
     * @throws IllegalArgumentException if either array has a different length
     */
    public void setX(double[] re, double[] im) {
        if (re.length != size || im.length != size) {
            throw new IllegalArgumentException("Size of both arguments must be equal to " + this.size);
        }
        for (int i = 0; i < size; i++) {
            input[i] = new Complex(re[i], im[i]);
        }
    }

    /**
     * Returns the real parts of the transform result.
     *
     * @return a new array of {@link #size} real parts
     * @throws IllegalStateException if {@link #fft()} has not produced an output yet
     */
    public double[] getRe() {
        Complex[] result = requireOutput();
        double[] parts = new double[size];
        for (int i = 0; i < size; i++) {
            parts[i] = result[i].getRe();
        }
        return parts;
    }

    /**
     * Returns the imaginary parts of the transform result.
     *
     * @return a new array of {@link #size} imaginary parts
     * @throws IllegalStateException if {@link #fft()} has not produced an output yet
     */
    public double[] getIm() {
        Complex[] result = requireOutput();
        double[] parts = new double[size];
        for (int i = 0; i < size; i++) {
            parts[i] = result[i].getIm();
        }
        return parts;
    }

    /** Returns {@link #output}, failing with a clear message instead of a bare NPE when it is unset. */
    private Complex[] requireOutput() {
        if (output == null) {
            throw new IllegalStateException("fft() must be called before reading the output");
        }
        return output;
    }

    /**
     * Recursive step of the transform: splits {@code in} into its even- and
     * odd-indexed halves, transforms each half, and combines them.
     *
     * @param in    samples to transform; an array of length 0 or 1 is returned as is
     *              (the very same array, not a copy)
     * @param coefs twiddle table as built by {@link #fft()}: {@code coefs[k]} is
     *              {@code e^(2*pi*i*k/size)} for {@code 0 <= k < size / 2}
     * @return the transformed samples; a newly allocated array when
     *         {@code in.length > 1}
     */
    public Complex[] fftRec(Complex[] in, Complex[] coefs) {
        int len = in.length;
        if (len <= 1) {
            return in;
        }

        int half = len / 2;
        Complex[] even = new Complex[half];
        Complex[] odd = new Complex[half];
        for (int i = 0; i < half; i++) {
            even[i] = in[2 * i];
            odd[i] = in[2 * i + 1];
        }

        Complex[] evenOut = fftRec(even, coefs);
        Complex[] oddOut = fftRec(odd, coefs);

        Complex[] out = new Complex[len];
        for (int r = 0; r < half; r++) {
            // coefs[size * r / len] is e^(2*pi*i*r/len), the twiddle factor for this
            // level. The product is formed in long because size * r exceeds the int
            // range once size reaches 2^17; the quotient is always below size, so
            // narrowing it back to int is safe.
            int index = (int) ((long) size * r / len);
            Complex twiddled = oddOut[r].mul(coefs[index]);
            out[r] = evenOut[r].add(twiddled);
            out[r + half] = evenOut[r].sub(twiddled);
        }
        return out;
    }

    /**
     * Transforms {@link #input} and stores the result in {@link #output}.
     *
     * <p>The input must have been filled (for example via {@link #setX}) first.
     * When {@link #size} is 1 the recursion returns its argument unchanged, so
     * {@code output} then refers to the same array as {@code input}.
     */
    public void fft() {
        int half = size / 2;
        // Only the first half of the unit-circle points is needed: the second
        // half of every combine step reuses the same factor with a subtraction.
        Complex[] coefs = new Complex[half];
        for (int k = 0; k < half; k++) {
            coefs[k] = new Complex((double) k / (double) size);
        }
        output = fftRec(input, coefs);
    }

    /************* to be supercompiled *************/

    /**
     * Runs one 4-point transform on interleaved data.
     *
     * @param arg eight values laid out as {@code re0, im0, re1, im1, ...}
     * @return the transform result in the same interleaved layout
     * @throws IllegalArgumentException if {@code arg} does not hold exactly eight values
     */
    public static double[] test(double[] arg) {
        int size = 4;
        if (arg.length != 2 * size) {
            throw new IllegalArgumentException();
        }

        // De-interleave into the separate arrays that setX expects.
        double[] re = new double[size];
        double[] im = new double[size];
        for (int i = 0; i < size; i++) {
            re[i] = arg[2 * i];
            im[i] = arg[2 * i + 1];
        }

        Fft fft = new Fft(size);
        fft.setX(re, im);
        fft.fft();

        // Re-interleave the result.
        double[] outRe = fft.getRe();
        double[] outIm = fft.getIm();
        double[] res = new double[2 * size];
        for (int i = 0; i < size; i++) {
            res[2 * i] = outRe[i];
            res[2 * i + 1] = outIm[i];
        }
        return res;
    }

    /***********************************************/

    /**
     * Benchmark driver: times {@link #test(double[])} and prints the average
     * time per call.
     *
     * @param args {@code args[0]} is the number of iterations per measurement;
     *             the optional {@code args[1]} is the number of measurements
     *             (defaults to 1)
     */
    public static void main(String[] args) {
        if (args.length < 1) {
            System.out.println("Usage:  java Fft <number of iterations> [<number of tests>]");
            System.out.println("Sample: java Fft 1000 5");
            return;
        }
        int iters = Integer.parseInt(args[0]);
        int tests = 1;
        if (args.length > 1) {
            tests = Integer.parseInt(args[1]);
        }

        final double[] arg = {1, 0, 0, 1, 2, 1, 1, 2}; // size = 4
        // Alternative input for size = 16; test() hard-codes size = 4, so that
        // constant has to be changed as well before using it:
        //final double[] arg = {1,0, 0,1, 2,1, 1,2, 1,0, 0,-1, 2,1, 1,2,-1,0, 0,1, -2,1, 1,2,-1,0, 0,-1, 2,1, 1,2}; // size = 16

        double[] res1 = null;
        for (int i = 0; i < 100; i++) { // to warm up JVM
            res1 = test(arg);
        }
        for (int t = 0; t < tests; t++) {
            long start = System.currentTimeMillis();
            for (int i = 0; i < iters; i++) {
                res1 = test(arg);
            }
            long end = System.currentTimeMillis();
            // Elapsed milliseconds divided by iterations, times 1000, gives microseconds per call.
            System.out.println("Time of one iteration = " + (end - start) / (float) iters * 1000 + " microsec");
        }
    }
}

class Complex {

    public final double re, im;

    Complex(double re, double im) {
        this.re = re;
        this.im = im;
    }

    Complex(double r) {
        this(Math.cos(2 * Math.PI * r), Math.sin(2 * Math.PI * r));
    }

    Complex add(Complex that) {
        return new Complex(this.re + that.re, this.im + that.im);
    }

    Complex sub(Complex that) {
        return new Complex(this.re - that.re, this.im - that.im);
    }

    Complex mul(Complex that) {
        return new Complex(this.re * that.re - this.im * that.im, this.re * that.im + this.im * that.re);
    }

    public String toString() {
        return "" + re + "+i*" + im;
    }

    double getRe() {
        return this.re;
    }

    double getIm() {
        return this.im;
    }
}