/******************************************************************************
 * Copyright (C) 2015 Sebastiaan R. Hogenbirk                                 *
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 ******************************************************************************/

package thorwin.math.accelerator;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;


/**
 * Interface to native BLAS functions.
 *
 * <p>When this class is initialised, the bundled native library that best
 * matches the running CPU's micro-architecture is extracted to a temporary
 * file and loaded with {@link System#load(String)}. If no specialised build
 * is bundled for the detected micro-architecture, the generic build is used.
 */
public final class Blas {

    /** Resource name prefix of the bundled native libraries. */
    private static final String LIBRARY_PREFIX = "Blas-";

    /** Micro-architecture tag of the portable fallback build. */
    private static final String GENERIC = "GENERIC";

    static {
        loadNativeLibrary();
    }

    /**
     * Extracts the bundled native BLAS library matching the running CPU to a
     * temporary file (deleted on JVM exit) and loads it.
     *
     * @throws IllegalStateException if no library resource is bundled or it
     *                               cannot be copied to a temporary file
     */
    private static void loadNativeLibrary() {
        // NOTE(review): only macOS (".jnilib") and everything else (".so")
        // are distinguished here; other platforms would need their own suffix.
        String extension = System.getProperty("os.name", "").contains("Mac") ? ".jnilib" : ".so";
        String microarchitecture = getSupportedMicroarchitecture();

        try {
            File file = File.createTempFile("lib", extension);
            file.deleteOnExit();

            try (InputStream in = openLibraryResource(microarchitecture, extension);
                 FileOutputStream out = new FileOutputStream(file)) {
                byte[] buffer = new byte[8192];
                int len;
                while ((len = in.read(buffer)) != -1) {
                    out.write(buffer, 0, len);
                }
            }
            System.load(file.getAbsolutePath());
        } catch (IOException e) {
            throw new IllegalStateException(e);
        }
    }

    /**
     * Opens the bundled native library for the given micro-architecture,
     * falling back to the generic build when no specialised build is bundled.
     *
     * @param microarchitecture micro-architecture tag, as returned by
     *                          {@link #getSupportedMicroarchitecture()}
     * @param extension         platform library file extension
     * @return open stream over the library resource, never {@code null}
     * @throws IllegalStateException if neither the specialised nor the generic
     *                               build is available
     */
    private static InputStream openLibraryResource(String microarchitecture, String extension) {
        String name = LIBRARY_PREFIX + microarchitecture + extension;
        InputStream in = Blas.class.getResourceAsStream(name);
        if (in == null && !GENERIC.equals(microarchitecture)) {
            Logger.getLogger(Blas.class.getName()).log(Level.WARNING,
                    "native library {0} not found, falling back to generic build", name);
            name = LIBRARY_PREFIX + GENERIC + extension;
            in = Blas.class.getResourceAsStream(name);
        }
        if (in == null) {
            throw new IllegalStateException("native BLAS library resource not found: " + name);
        }
        return in;
    }

    /**
     * Returns the micro-architecture whose native build should be used on the
     * running CPU. Several related micro-architectures share one build.
     *
     * @return build tag: HASWELL, SANDYBRIDGE, BULLDOZER, PILEDRIVER or GENERIC
     */
    private static String getSupportedMicroarchitecture() {
        switch (getMicroarchitecture()) {
            case "SKYLAKE":
            case "BROADWELL":
            case "HASWELL":
                return "HASWELL";
            case "SANDYBRIDGE":
            case "IVYBRIDGE":
                return "SANDYBRIDGE";
            case "BULLDOZER":
                return "BULLDOZER";
            case "PILEDRIVER":
            case "STEAMROLLER":
            case "EXCAVATOR":
                return "PILEDRIVER";
            default:
                return GENERIC;
        }
    }

    /**
     * Returns the actual micro-architecture of the running CPU.
     * Only a limited number of Intel and AMD x64 processors are
     * detected; everything else (or a missing CPUID) yields GENERIC.
     *
     * @return micro-architecture
     */
    private static String getMicroarchitecture() {
        Optional<CpuID> cpuid = CpuID.getCpuID();
        if (!cpuid.isPresent()) {
            return GENERIC;
        }

        // Assumes getFamily()/getModel() report the combined (display) family
        // and model, as matching AMD family 0x15 requires. TODO confirm in CpuID.
        int family = cpuid.get().getFamily();
        int model = cpuid.get().getModel();

        switch (family) {
            case 0x6: // Intel
                // NOTE(review): detection is by model number only. Low-end SKUs
                // sharing these model numbers may lack AVX/AVX2; consider also
                // checking CPUID feature flags before picking an optimised build.
                switch (model) {
                    case 0x2A: // Sandy Bridge
                    case 0x2D: // Sandy Bridge-E
                        return "SANDYBRIDGE";
                    case 0x3A: // Ivy Bridge
                    case 0x3E: // Ivy Bridge-E
                        return "IVYBRIDGE";
                    case 0x3C: // Haswell
                    case 0x3F: // Haswell-E
                    case 0x45:
                    case 0x46:
                        return "HASWELL";
                    case 0x3D: // Broadwell
                    case 0x47:
                    case 0x4F:
                    case 0x56:
                        return "BROADWELL";
                    case 0x4E: // Skylake
                    case 0x5E:
                        return "SKYLAKE";
                    default:
                        break;
                }
                break;

            case 0x15: // AMD family 15h
                // NOTE(review): some Piledriver parts are reported to use models
                // inside 0x00-0x0F and are therefore classified as BULLDOZER here.
                if (model <= 0x0f) {
                    return "BULLDOZER";
                }
                if (model >= 0x10 && model <= 0x2f) {
                    return "PILEDRIVER";
                }
                if (model >= 0x30 && model <= 0x4f) {
                    return "STEAMROLLER";
                }
                if (model >= 0x60 && model <= 0x7f) {
                    return "EXCAVATOR";
                }
                break;

            default:
                break;
        }
        return GENERIC;
    }


    /**
     * BLAS DGEMM function call. Computes
     * <i>c = alpha&times;op(a)&times;op(b) + beta&times;c</i>, where
     * <i>op(x)</i> is <i>x</i> or its transpose according to the transpose
     * flag.
     *
     * <p>The matrices must be densely packed; the leading dimensions passed to
     * BLAS are derived from {@code m}, {@code n} and {@code k} (at least 1, as
     * BLAS requires). NOTE(review): the derived leading dimensions ({@code k}
     * or {@code m} for a, {@code n} or {@code k} for b, {@code n} for c) are
     * those of row-major storage and are used for every {@code order}; confirm
     * that column-major callers are not affected.
     *
     * @param order  array pack order
     * @param transa transpose state of matrix a
     * @param transb transpose state of matrix b
     * @param m      number of rows of op(a) and of c
     * @param n      number of columns of op(b) and of c
     * @param k      number of columns of op(a) and rows of op(b)
     * @param alpha  scalar alpha
     * @param a      matrix a, at least {@code m*k} elements
     * @param b      matrix b, at least {@code k*n} elements
     * @param beta   scalar beta
     * @param c      matrix c, at least {@code m*n} elements, overwritten
     * @throws IllegalArgumentException if a dimension is negative or an array
     *                                  is too short for the given dimensions
     * @throws NullPointerException     if a matrix or enum argument is null
     */
    public static void dgemm(Order order,
                             Transpose transa,
                             Transpose transb,
                             int m,
                             int n,
                             int k,
                             double alpha,
                             double[] a,
                             double[] b,
                             double beta,
                             double[] c) {
        if (m < 0 || n < 0 || k < 0) {
            throw new IllegalArgumentException(
                    "matrix dimensions must be non-negative: m=" + m + ", n=" + n + ", k=" + k);
        }
        // Validate sizes before entering native code, where an out-of-bounds
        // access would corrupt memory instead of throwing. Use long to avoid
        // int overflow of the products.
        requireLength("a", a, (long) m * k);
        requireLength("b", b, (long) k * n);
        requireLength("c", c, (long) m * n);

        int lda = Math.max(1, transa == Transpose.NO_TRANSPOSE ? k : m);
        int ldb = Math.max(1, transb == Transpose.NO_TRANSPOSE ? n : k);
        int ldc = Math.max(1, n);

        dgemm(order.value(),
                transa.value(),
                transb.value(),
                m,
                n,
                k,
                alpha,
                a,
                lda,
                b,
                ldb,
                beta,
                c,
                ldc);
    }

    /**
     * Checks that {@code array} holds at least {@code required} elements.
     *
     * @param name     parameter name used in the error message
     * @param array    array to check
     * @param required minimum number of elements
     * @throws IllegalArgumentException if the array is too short
     * @throws NullPointerException     if the array is null
     */
    private static void requireLength(String name, double[] array, long required) {
        if (array.length < required) {
            throw new IllegalArgumentException("array " + name + " has length " + array.length
                    + ", expected at least " + required);
        }
    }


    /**
     * BLAS DGEMM function call.
     *
     * @param order  array pack order
     * @param transa transpose state of matrix a
     * @param transb transpose state of matrix b
     * @param m      number of rows of op(a) and of c
     * @param n      number of columns of op(b) and of c
     * @param k      number of columns of op(a) and rows of op(b)
     * @param alpha  scalar alpha
     * @param a      matrix a
     * @param lda    leading dimension of a
     * @param b      matrix b
     * @param ldb    leading dimension of b
     * @param beta   scalar beta
     * @param c      matrix c, overwritten
     * @param ldc    leading dimension of c
     */
    private static native void dgemm(int order,
                                     int transa,
                                     int transb,
                                     int m,
                                     int n,
                                     int k,
                                     double alpha,
                                     double[] a,
                                     int lda,
                                     double[] b,
                                     int ldb,
                                     double beta,
                                     double[] c,
                                     int ldc);


    /**
     * BLAS DAXPY function call. Calculates <i>Y<sub>out</sub> = A &times; X +
     * Y</i>
     *
     * @param da scale A
     * @param dx input vector X
     * @param dy input/output vector Y, overwritten
     * @throws IllegalArgumentException if the vectors differ in length
     */
    public static void daxpy(double da, double[] dx, double[] dy) {
        if (dx.length != dy.length) {
            throw new IllegalArgumentException("arrays should have equal lengths");
        }
        daxpy(dx.length, da, dx, 1, dy, 1);
    }


    /**
     * BLAS DAXPY function call
     *
     * @param n    dimension
     * @param da   scale A
     * @param dx   vector X
     * @param incx increment of X
     * @param dy   vector Y
     * @param incy increment of Y
     */
    private static native void daxpy(int n,
                                     double da,
                                     double[] dx,
                                     int incx,
                                     double[] dy,
                                     int incy);
}
