/*
  This matrix multiply driver was originally written by Jason Riedy.
  Most of the code (and comments) are his, but there are several
  things I've hacked up.  It seemed to work for him, so any errors
  are probably mine.

  Ideally, you won't need to change this file.  You may want to change
  a few settings to speed debugging runs, but remember to change back
  to the original settings during final testing.

  The output format: "Size: %u\tmflop/s: %g\n"
*/


#include <stdlib.h>
#include <stdio.h>
#include <string.h>

#include <float.h>
#include <math.h>

#include <sys/types.h>
#include <sys/resource.h>

#include <unistd.h>
#include <time.h>

/*
  Timing policy: a measurement starts by running the multiply MIN_RUNS
  times in one batch.  If that batch consumed less than MIN_CPU_SECS
  seconds of CPU time, the batch size is doubled and the measurement is
  repeated.

  Lowering these speeds up debugging runs; restore them before any
  final measurement.
*/
#define MIN_RUNS     (4)
#define MIN_CPU_SECS (1.0)

/*
  Your function _MUST_ have the following signature:

  M is the matrix dimension.  A, B and C are M-by-M matrices stored
  column-major with leading dimension M: element (row i, column j) of a
  matrix X lives at X[j*M + i].  This is how validate_dgemm indexes them.

  The driver zeroes C before each validation call and before each timing
  batch.  NOTE(review): whether C should be overwritten (C = A*B) or
  accumulated into (C += A*B) is not pinned down by this file.  Both pass
  validation because C starts at zero, so confirm against the assignment
  spec.
*/
extern void square_dgemm (const unsigned M, 
                          const double *A, const double *B, double *C);

/*
  Note the strange sizes...  You'll see some interesting effects
  around some of the powers-of-two.

  Compile with -DFASTTEST to use an abbreviated list while debugging.
  Every entry must be <= MAX_SIZE, because the operand arrays are
  allocated as MAX_SIZE * MAX_SIZE doubles.
*/
#ifndef FASTTEST
const unsigned test_sizes[] = {
     20,
     24,
     31,
     32,
     48,
     64,
     73,
     96,
     97,
     127,
     128,
     129,
     163,
     191,
     192,
     229,
     255,
     256,
     257,
     319,
     320,
     321,
     417,
     479,
     480,
     511,
     512,
};
#else
const unsigned test_sizes[] = {
     20,
     24,
     31,
     32,
     48,
     127,
     128,
     129,
     255,
     256,
     257,
     319,
     320,
     321,
     511,
     512,
};
#endif

/* Largest entry in either list; fixes the size of the operand arrays. */
#define MAX_SIZE 512u

/*
  Number of entries in test_sizes, typed unsigned to match the loop index
  in main.  The cast wraps the whole quotient.  Casting only the dividend
  would leave the result as size_t.  Dividing by the element size keeps
  this correct if the element type ever changes.
*/
#define N_SIZES ((unsigned) (sizeof (test_sizes) / sizeof (test_sizes[0])))

/*
  Operand (A, B) and result (C) storage, each sized for the largest test.
  The driver's own indexing for a run of size M touches only the leading
  M*M entries of each array.
*/
double A[MAX_SIZE * MAX_SIZE], B[MAX_SIZE * MAX_SIZE], C[MAX_SIZE * MAX_SIZE];

/*
  Fill all MAX_SIZE*MAX_SIZE entries of the array with values from
  drand48(), in increasing index order.  The caller must supply storage
  for at least that many doubles.
*/
void
matrix_init (double *A)
{
     double *p = A;
     double *const end = A + MAX_SIZE * MAX_SIZE;

     while (p != end) {
          *p++ = drand48 ();
     }
}

/*
  Zero every one of the MAX_SIZE*MAX_SIZE entries of C, not just the
  leading M*M block used by a particular test size.
*/
void
matrix_clear (double *C)
{
     const size_t nbytes = sizeof C[0] * MAX_SIZE * MAX_SIZE;

     memset (C, 0, nbytes);
}

/*
  Dot products satisfy the following error bound:
   fl(sum a_i * b_i) = sum a_i * b_i * (1 + delta_i)
  where |delta_i| <= n * epsilon.  In order to check your matrix
  multiply, we compute each element in turn and make sure that
  your product is within three times the given error bound.
  We make it three times because there are three sources of
  error:

   - the roundoff error in your multiply
   - the roundoff error in our multiply
   - the roundoff error in computing the error bound

  That last source of error is not so significant, but that's a
  story for another day.

  Matrices are column-major with leading dimension M, so element
  (row i, column j) of X is X[j*M + i].  C is zeroed before
  square_dgemm is called.  On the first out-of-tolerance element,
  a diagnostic is printed and the program exits with EXIT_FAILURE.
 */
void
validate_dgemm (const unsigned M,
                const double *A, const double *B, double *C)
{
    unsigned i, j, k;

    matrix_clear (C);
    square_dgemm (M, A, B, C);

    for (i = 0; i < M; ++i) {
        for (j = 0; j < M; ++j) {

            double dotprod = 0;
            double errorbound = 0;
            double err;

            /* Reference value C(i,j) = sum_k A(i,k) * B(k,j), along with
               sum_k |A(i,k) * B(k,j)| for the error bound. */
            for (k = 0; k < M; ++k) {
                double prod = A[k*M + i] * B[j*M + k];
                dotprod += prod;
                errorbound += fabs(prod);
            }
            errorbound *= (M * DBL_EPSILON);

            err = fabs(C[j*M + i] - dotprod);
            /* Written as !(err <= limit) rather than err > limit so that
               a NaN in C (which makes err NaN) is reported as a failure
               instead of silently passing. */
            if (!(err <= 3*errorbound)) {
                printf("Matrix multiply failed.\n");
                printf("C(%u,%u) should be %g, was %g\n", i, j,
                       dotprod, C[j*M + i]);
                printf("Error of %g, acceptable limit %g\n",
                       err, 3*errorbound);
                exit(EXIT_FAILURE);
            }
        }
    }
}

double
time_dgemm (const unsigned M,
            const double *A, const double *B, double *C)
{
    /*
      Measure square_dgemm throughput and return it in Mflop/s.

      clock() reports processor time in units of CLOCKS_PER_SEC.
      XSI-conformant systems define that as 1000000.  Its real resolution
      can be much coarser, though: older non-Alpha Linux kernels limited
      it by the HZ setting in include/asm-i386/param.h.  To get a usable
      measurement we start with MIN_RUNS multiplies and keep doubling the
      batch until one batch consumes at least MIN_CPU_SECS of CPU time.
      Only that final batch is reported.

      C is zeroed once per batch, not before each call, and its contents
      are not checked here.

      Timing under Linux is not fun.
    */
    unsigned num_iterations = MIN_RUNS;
    double secs = -1;
    double mflop_s = 0;

    while (secs < MIN_CPU_SECS) {
        clock_t start, stop;
        double mflops;
        unsigned i;

        matrix_clear (C);
        start = clock();
        for (i = 0; i < num_iterations; ++i) {
            square_dgemm (M, A, B, C);
        }
        stop = clock();

        /* clock() returns (clock_t)-1 when processor time is unavailable.
           Without this check secs would stay at 0 and the loop would
           never terminate. */
        if (start == (clock_t) -1 || stop == (clock_t) -1) {
            fprintf (stderr, "clock() is unavailable; cannot time square_dgemm.\n");
            exit (EXIT_FAILURE);
        }

        /* Each multiply is M*M dot products of length M: 2*M^3 flops. */
        mflops  = 2.0 * num_iterations * M * M * M / 1.0e6;
        secs    = (stop - start) / ((double) CLOCKS_PER_SEC);
        mflop_s = mflops / secs;

        num_iterations *= 2;
    }

    return mflop_s;
}

/*
  Fill A and B once, then for each entry of test_sizes:
  - optionally validate square_dgemm (skipped when NO_VALIDATE is defined);
  - time it;
  - print one "Size: ...\tmflop/s: ..." line.
*/
int
main (void)
{
     unsigned idx;

     matrix_init (A);
     matrix_init (B);

     for (idx = 0; idx < N_SIZES; idx++) {
          const unsigned M = test_sizes[idx];
          double rate;

#ifndef NO_VALIDATE
          validate_dgemm (M, A, B, C);
#endif
          rate = time_dgemm (M, A, B, C);
          printf ("Size: %u\tmflop/s: %g\n", M, rate);
     }

     return 0;
}

