/* 
   written by: PA Taylor

   Aug. 2016: v1.0, starting

   [PT: Aug 26,2019] v1.2, add in option for freezing one dset's map
   at one location, whilst allowing the other to roam in the usual
   fashion.  Options/functionality added for Zhihao Li.

*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <unistd.h>
#include <time.h>
#include "debugtrace.h"
#include "mrilib.h"
#include "mrilib.h"    
#include "3ddata.h"    
#include "DoTrackit.h"
#include "checks_and_balances.h"
#include "rsfc.h"

/*
   Print the program help text to stdout.

   detail : requested verbosity level (callers pass 1 or 2); currently
            the same text is printed regardless of its value.
*/
void usage_SpaceTimeCorr(int detail) 
{
   printf(
"\n"
"  3dSpaceTimeCorr\n"
"  v1.2 (PA Taylor, Aug. 2019)\n"
"\n"
"  This program is for calculating something *similar* to the (Pearson)\n"
"   correlation coefficient between corresponding voxels between two data\n"
"   sets, which is what 3dTcorrelate does.  However, this program \n"
"   operates differently. Here, two data sets are loaded in, and for each \n"
"   voxel in the brain:\n"
"      + for each data set, an ijk-th voxel is used as a seed to generate a\n"
"        correlation map within a user-defined mask (e.g., whole brain,\n"
"        excluding the seed location where r==1, by definition);\n"
"      + that correlation map is Fisher Z transformed;\n"
"      + the Z-correlation maps are (Pearson) correlated with each other,\n"
"        generating a single correlation coefficient;\n"
"      + the correlation coefficient is stored at the same ijk-th voxel\n"
"        location in the output data set;\n"
"   and the process is repeated.  Thus, the output is a whole brain map\n"
"   of r-correlation coefficients for corresponding voxels from the two data\n"
"   sets, generated by temporal and spatial patterns (-> space+time \n"
"   correlation!).\n"
"\n"
"   This could be useful when someone *wishes* that s/he could use \n"
"   3dTcorrelate on something like resting state FMRI data.  Maybe.\n"
"   Note that this program could take several minutes or more to run,\n"
"   depending on the size of the data set and mask.\n"
"\n"
"* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\n"
"  \n"
"  + USAGE: Load in 2 data sets and a mask.  This computation can get pretty\n"
"           time consuming-- it depends on the number of voxels N like N**2.\n"
"\n"
"* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\n"
"\n"
"  + COMMAND:  two 4D data sets need to be put in (order doesn't matter), \n"
"              and a mask also *should* be.\n"
"\n"
"    3dSpaceTimeCorr -insetA FILEA -insetB FILEB -prefix PREFIX   \\\n"
"                   {-mask MASK} {-out_Zcorr}  \n"
"                   {-freeze_insetA_ijk II JJ KK} \n"
"                   {-freeze_insetA_xyz XX YY ZZ} \n"
"\n"
"    where:\n"
"\n"
"  -insetA FILEA  :one 4D data set.\n"
"  -insetB FILEB  :another 4D data set; must have same spatial dimensions as\n"
"                  FILEA, as well as same number of time points.\n"
"\n"
"  -mask MASK     :optional mask.  Highly recommended to use for speed of\n"
"                  calcs (and probably for interpretability, too).\n"
"\n"
"  -prefix PREFIX :output filename/base.\n"
"\n"
"  -out_Zcorr     :switch to output Fisher Z transform of spatial map\n"
"                  correlation (default is Pearson r values).\n"
"\n"
"-freeze_insetA_ijk II JJ KK\n"
"                 :instead of correlating the spatial correlation maps\n"
"                  of A and B that have matching seed locations, with this\n"
"                  option you can 'freeze' the seed voxel location in \n"
"                  the input A dset, while the seed location in B moves\n"
"                  throughout the volume or mask as normal.\n"
"                  Here, one inputs three values, the ijk indices in\n"
"                  the dataset. (See next opt for freezing at xyz location.)\n"
"\n"
"-freeze_insetA_xyz XX YY ZZ\n"
"                 :same behavior as using '-freeze_insetA_ijk ..', but here\n"
"                  one inputs the xyz (physical) coordinates.\n"
"\n"
"* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\n"
"\n"
"  + OUTPUT: \n"
"      A data set with one value at each voxel, representing the space-time \n"
"      correlation of the two input data sets within the input mask.\n"
"\n"
"* * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *\n"
"\n"
"  + EXAMPLE:\n"
"    3dSpaceTimeCorr                       \\\n"
"        -insetA SUB_01.nii.gz             \\\n"
"        -insetB SUB_02.nii.gz             \\\n"
"        -mask   mask_GM.nii.gz            \\\n"
"        -prefix stcorr_01_02\n"
"\n"
"____________________________________________________________________________\n"
          );
	return;
}

/*
   Entry point for 3dSpaceTimeCorr.

   Processing steps:
   - Parse the command line.
   - Load the two 4D input dsets (A and B) and the optional mask.
   - For every in-mask voxel, build a seed-based correlation map in each
     dset via WB_corr_loop.
   - Store THD_pearson_corr of those two maps at that voxel of the
     output dset.

   With -freeze_insetA_{ijk,xyz}, dset A's map is computed once at the
   fixed seed and reused for every seed location in B.

   Returns 0 on success.  On bad input or allocation failure it exits
   (via ERROR_exit or exit() with a nonzero code).
*/
int main(int argc, char *argv[]) {
   int i,j,k,mm;
   int idx, ctr;
   int iarg;
   THD_3dim_dataset *insetTIMEA=NULL, *insetTIMEB=NULL;
   THD_3dim_dataset *MASK=NULL;
   THD_3dim_dataset *outset=NULL;
   char *prefix="SPACETIMECORR" ;

   int Nvox=-1;                  // tot number vox in data set
   int Nmskd=-1,Nmskdm1=-1;      // tot number vox in mask, and that minus 1
   int Dim[4] = {0,0,0,0};       // nx, ny, nz, nvals of insetA
   byte ***mskd=NULL;            // 1 where voxel is in mask (all if no mask)

   float *mapA=NULL, *mapB=NULL; // will hold time series correlations
   float *scorrAB=NULL;          // will be spatial correlation map
   float *tsX=NULL, *tsY=NULL;
   int np=0, nprog = 0;          // count progress.
   time_t t_start;

   int ZOUT = 0;

   int myloc[3] = {0,0,0};        // general location, will move around

   int myloc_A[3] = {0,0,0};      // for special case of 'freezing' loc in A
   float myloc_xyz_A[3] = {0,0,0};
   int DO_FREEZE_A = 0;           // 1 for ijk; 2 for xyz; 0 for not used

   int coord_warn[1] = {0};       // for checking xyz->ijk being in FOV
   THD_ivec3 ivec3_aaa;
   THD_fvec3 fv;

   mainENTRY("3dSpaceTimeCorr"); machdep(); 
  
   // ****************************************************************
   // ****************************************************************
   //                    load AFNI stuff
   // ****************************************************************
   // ****************************************************************

   /** scan args **/
   if (argc == 1) { usage_SpaceTimeCorr(1); exit(0); }
   iarg = 1; 
   while( iarg < argc && argv[iarg][0] == '-' ){
      if( strcmp(argv[iarg],"-help") == 0 || 
          strcmp(argv[iarg],"-h") == 0 ) {
         usage_SpaceTimeCorr(strlen(argv[iarg])>3 ? 2:1);
         exit(0);
      }
		
      if( strcmp(argv[iarg],"-out_Zcorr") == 0) {
         ZOUT=1;
         iarg++ ; continue ;
      }

      if( strcmp(argv[iarg],"-prefix") == 0 ){
         iarg++ ; if( iarg >= argc ) 
                     ERROR_exit("Need argument after '-prefix'");
         prefix = strdup(argv[iarg]) ;
         if( !THD_filename_ok(prefix) ) 
            ERROR_exit("Illegal name after '-prefix'");
         iarg++ ; continue ;
      }
	 
      if( strcmp(argv[iarg],"-insetA") == 0 ){
         iarg++ ; if( iarg >= argc ) 
                     ERROR_exit("Need argument after '-insetA'");

         insetTIMEA = THD_open_dataset(argv[iarg]);
         if( insetTIMEA == NULL )
            ERROR_exit("Can't open time series dataset '%s'.",argv[iarg]);

         DSET_load(insetTIMEA); CHECK_LOAD_ERROR(insetTIMEA);
         Nvox = DSET_NVOX(insetTIMEA) ;
         Dim[0] = DSET_NX(insetTIMEA); Dim[1] = DSET_NY(insetTIMEA); 
         Dim[2] = DSET_NZ(insetTIMEA); Dim[3] = DSET_NVALS(insetTIMEA); 

         iarg++ ; continue ;
      }

      if( strcmp(argv[iarg],"-insetB") == 0 ){
         iarg++ ; if( iarg >= argc ) 
                     ERROR_exit("Need argument after '-insetB'");

         insetTIMEB = THD_open_dataset(argv[iarg]);
         if( insetTIMEB == NULL )
            ERROR_exit("Can't open time series dataset '%s'.",argv[iarg]);

         DSET_load(insetTIMEB); CHECK_LOAD_ERROR(insetTIMEB);

         iarg++ ; continue ;
      }

      if( strcmp(argv[iarg],"-mask") == 0 ){
         iarg++ ; if( iarg >= argc ) 
                     ERROR_exit("Need argument after '-mask'");

         MASK = THD_open_dataset(argv[iarg]) ;
         if( MASK == NULL )
            ERROR_exit("Can't open mask dataset '%s'.",argv[iarg]);

         DSET_load(MASK); CHECK_LOAD_ERROR(MASK);
			
         iarg++ ; continue ;
      }

      if( strcmp(argv[iarg],"-freeze_insetA_ijk") == 0 ){
         iarg++ ;
         // all three values must be present, else argv is read past argc
         if( iarg + 2 >= argc ) 
            ERROR_exit("Need 3 arguments after '-freeze_insetA_ijk'");

         DO_FREEZE_A = 1;
         myloc_A[0] = atoi(argv[iarg]);  iarg++ ;
         myloc_A[1] = atoi(argv[iarg]);  iarg++ ;
         myloc_A[2] = atoi(argv[iarg]);  
			
         iarg++ ; continue ;
      }

      if( strcmp(argv[iarg],"-freeze_insetA_xyz") == 0 ){
         iarg++ ;
         // all three values must be present, else argv is read past argc
         if( iarg + 2 >= argc ) 
            ERROR_exit("Need 3 arguments after '-freeze_insetA_xyz'");

         DO_FREEZE_A = 2;
         myloc_xyz_A[0] = atof(argv[iarg]);  iarg++ ;
         myloc_xyz_A[1] = atof(argv[iarg]);  iarg++ ;
         myloc_xyz_A[2] = atof(argv[iarg]);  
			
         iarg++ ; continue ;
      }


      ERROR_message("Bad option '%s'\n",argv[iarg]) ;
      suggest_best_prog_option(argv[0], argv[iarg]);
      exit(1);
   }
   
	INFO_message("Have read in options");
   
   // TEST BASIC INPUT PROPERTIES
   if (iarg < 3) {
      ERROR_message("Too few options. Try -help for details.\n");
      exit(1);
   }
	
   if( (!insetTIMEA) || (!insetTIMEB) )
      ERROR_exit("Need both insetA and insetB to be input!");
   
   // check dataset fitting:
   i = CompareSetDims(insetTIMEA, insetTIMEB, 4);
   
   if ( !MASK )
      WARNING_message("No mask input-- "
                      "will correlate across whole volume!\n");
   else
      i = CompareSetDims(insetTIMEA, MASK, 3); // don't check time dim
   
   // [PT: Aug 26, 2019] now allow for 'freezing' a dset...
   if ( DO_FREEZE_A == 2 ) {
      // First, check that XYZ stay within FOV; all 3 components must be
      // set, since fv is otherwise uninitialized.
      // NOTE(review): confirm THD_3dmm_to_3dind_warn and AFNI_xyz_to_ijk
      // interpret myloc_xyz_A in the same coordinate convention.
      for ( i=0 ; i<3 ; i++ )
         fv.xyz[i] = myloc_xyz_A[i];

      ivec3_aaa = THD_3dmm_to_3dind_warn(insetTIMEA,
                                         fv, 
                                         coord_warn);
      if( coord_warn[0] )
         ERROR_exit("%d of the xyz values for the 'frozen' coordinate in\n"
                    "   dset A appear(s) to be outside the FOV.\n"
                    "   These are the estimated ijk indices (check for one\n"
                    "   floored at 0 or ceilinged at max dim ind: %d %d %d):\n"
                    "   %d  %d  %d", coord_warn[0], 
                    Dim[0]-1, Dim[1]-1, Dim[2]-1,
                    ivec3_aaa.ijk[0], ivec3_aaa.ijk[1], ivec3_aaa.ijk[2]);
      else
         INFO_message("XYZ coordinates of 'frozen' point in dset A appear "
                      "to be in FOV.");

      // if user input xyz, get ijk for this dset
      AFNI_xyz_to_ijk( insetTIMEA,
         myloc_xyz_A[0], myloc_xyz_A[1], myloc_xyz_A[2],
         &myloc_A[0], &myloc_A[1], &myloc_A[2] );
      INFO_message("Converted xyz = %f, %f, %f\n"
                   "       to ijk = %d, %d, %d.",
                   myloc_xyz_A[0], myloc_xyz_A[1], myloc_xyz_A[2],
                   myloc_A[0], myloc_A[1], myloc_A[2]);

   } 
   if ( DO_FREEZE_A ) {
      // ... now, check every ijk index (either directly entered or from
      // xyz); later code indexes mskd with these directly.
      for ( i=0 ; i<3 ; i++ )
         if( (myloc_A[i] < 0) || (myloc_A[i] >= Dim[i]) ) 
            ERROR_exit("Dimension [%d] has illegal ijk value: %d.\n"
                       "This is not in allow index values here: [%d, %d).",
                       i, myloc_A[i], 0, Dim[i]);
      INFO_message("IJK coordinates for 'freezing' dset A "
                   "appear to be OK.");
      INFO_message("XYZ: %f %f %f", myloc_xyz_A[0], myloc_xyz_A[1], 
                   myloc_xyz_A[2]);
      INFO_message("IJK: %d %d %d", myloc_A[0], myloc_A[1], myloc_A[2]);
   }

   INFO_message("Checked inputs.");

	
   // ****************************************************************
   // ****************************************************************
   //                    pre-stuff, make storage
   // ****************************************************************
   // ****************************************************************

   // Check each level of the 3D mask array as soon as it is allocated,
   // before it is dereferenced.
   mskd = (byte ***) calloc( Dim[0], sizeof(byte **) );
   if( mskd == NULL ) { 
      fprintf(stderr, "\n\n MemAlloc failure (masks).\n\n");
      exit(4);
   }
   for ( i = 0 ; i < Dim[0] ; i++ ) {
      mskd[i] = (byte **) calloc( Dim[1], sizeof(byte *) );
      if( mskd[i] == NULL ) { 
         fprintf(stderr, "\n\n MemAlloc failure (masks).\n\n");
         exit(4);
      }
      for ( j = 0 ; j < Dim[1] ; j++ ) {
         mskd[i][j] = (byte *) calloc( Dim[2], sizeof(byte) );
         if( mskd[i][j] == NULL ) { 
            fprintf(stderr, "\n\n MemAlloc failure (masks).\n\n");
            exit(4);
         }
      }
   }

   // *************************************************************
   // *************************************************************
   //                    Beginning of main loops
   // *************************************************************
   // *************************************************************
	

   // go through once: define data vox (idx runs with i fastest, which is
   // the flat voxel ordering used for the dsets below)
   ctr = 0;
   idx = 0;
   for( k=0 ; k<Dim[2] ; k++ ) 
      for( j=0 ; j<Dim[1] ; j++ ) 
         for( i=0 ; i<Dim[0] ; i++ ) {
            if( MASK ) {
               if( THD_get_voxel(MASK,idx,0)>0 ){
                  mskd[i][j][k] = 1;
                  ctr++;
               }
            }
            else{
               mskd[i][j][k] = 1;
               ctr++;
            }
            idx+= 1; // out-of-mask vox stay 0 from calloc
         }

   Nmskd = ctr;

   if( MASK )
      INFO_message("Made mask: %d voxels", Nmskd);
   else
      INFO_message("No mask: %d voxels in data set", Nvox);

   // Each seed map excludes the seed itself (Nmskd-1 values), and a
   // Pearson correlation needs at least 2 values per map.
   if( Nmskd < 3 )
      ERROR_exit("Only %d voxel(s) to correlate over: need at least 3.",
                 Nmskd);

   Nmskdm1 = Nmskd - 1; // use, because we skip self in ts-corr
   np = Nmskd / 10;     // report progress roughly every 10%
   if( np < 1 )         // guard the modulo in the progress report
      np = 1;


   // **************************************************************
   // **************************************************************
   //                 Store and output
   // **************************************************************
   // **************************************************************

   mapA = (float *)calloc(Nmskdm1, sizeof(float)); 
   mapB = (float *)calloc(Nmskdm1, sizeof(float)); 

   scorrAB = (float *)calloc(Nvox, sizeof(float)); 

   tsX = (float *)calloc(Dim[3], sizeof(float)); 
   tsY = (float *)calloc(Dim[3], sizeof(float)); 

   if( (mapA == NULL) || (mapB == NULL) || (scorrAB == NULL) ||
       (tsX == NULL) || (tsY == NULL)) {
      fprintf(stderr, "\n\n MemAlloc failure.\n\n");
      exit(2);
   }

   
   INFO_message("Now the work begins!");
   t_start = time(NULL);

   if ( DO_FREEZE_A ) {
      // If we have a special location in A to use, make the mapA for
      // it once, before everything else.  Indices were bounds-checked
      // above, so the voxel can be addressed directly.
      if( !mskd[myloc_A[0]][myloc_A[1]][myloc_A[2]] )
         ERROR_exit("Selected voxel for 'freezing' is "
                    "outside the given mask");

      // same i-fastest flattening as the mask loop above
      idx = myloc_A[0] + Dim[0]*(myloc_A[1] + Dim[1]*myloc_A[2]);
      mm = THD_extract_float_array(idx,insetTIMEA,tsX);  
      mm = WB_corr_loop(
                        tsX,tsY,
                        insetTIMEA,
                        Dim,
                        mskd,
                        mapA,
                        myloc_A
                        );

      INFO_message("Set 'frozen' time series from dset A.");
   }


   // go through again: make seed maps for A (unless frozen) and B at
   // each in-mask voxel, and correlate them
   ctr = 0;
   for( k=0 ; k<Dim[2] ; k++ ) 
      for( j=0 ; j<Dim[1] ; j++ ) 
         for( i=0 ; i<Dim[0] ; i++ ) {
            if( mskd[i][j][k] ) {

               myloc[0]=i;
               myloc[1]=j;
               myloc[2]=k;

               if ( !DO_FREEZE_A ) {
                  mm = THD_extract_float_array(ctr,insetTIMEA,tsX);  
                  mm = WB_corr_loop(
                                    tsX,tsY,
                                    insetTIMEA,
                                    Dim,
                                    mskd,
                                    mapA,
                                    myloc
                                    );
                  }
               
               mm = THD_extract_float_array(ctr,insetTIMEB,tsX);  
               mm = WB_corr_loop(
                                 tsX,tsY,
                                 insetTIMEB,
                                 Dim,
                                 mskd,
                                 mapB,
                                 myloc
                                 );

               scorrAB[ctr] = THD_pearson_corr(Nmskdm1, mapA, mapB); 

               nprog++;
               if (nprog % np == 0) {
                  fprintf(stderr,"\t%s %3.0f%% %s -> %.2f min\n",
                          "[", nprog * 100. / Nmskd,"]", 
                          (float) difftime( time(NULL), t_start)/60. );
               }
            }
            ctr++;
         }

   if(ZOUT) {
      INFO_message("Doing Fisher Z transform of output at user behest.");
      ctr = 0;
      for( k=0 ; k<Dim[2] ; k++ ) 
         for( j=0 ; j<Dim[1] ; j++ ) 
            for( i=0 ; i<Dim[0] ; i++ ) {
               if( mskd[i][j][k] ) 
                  scorrAB[ctr] = BOBatanhf(scorrAB[ctr]);
               ctr++;
            }
   }

   // **************************************************************
   // **************************************************************
   //                 Store and output
   // **************************************************************
   // **************************************************************
	
   outset = EDIT_empty_copy( insetTIMEA ) ; 

   EDIT_dset_items( outset,
                    ADN_datum_all , MRI_float, 
                    ADN_ntt       , 1, 
                    ADN_nvals     , 1,
                    ADN_prefix    , prefix,
                    ADN_none ) ;

   if( !THD_ok_overwrite() && THD_is_ondisk(DSET_HEADNAME(outset)) )
      ERROR_exit("Can't overwrite existing dataset '%s'",
                 DSET_HEADNAME(outset));

   // the output dset now owns scorrAB; drop our reference to it
   EDIT_substitute_brick(outset, 0, MRI_float, scorrAB); 
   scorrAB=NULL;
  
   THD_load_statistics(outset);
   tross_Make_History("3dSpaceTimeCorr", argc, argv, outset);
   THD_write_3dim_dataset(NULL, NULL, outset, True);

   // ************************************************************
   // ************************************************************
   //                    Freeing
   // ************************************************************
   // ************************************************************
	
   free(mapA);
   free(mapB);
   free(scorrAB);
   free(tsX);
   free(tsY);

   if(outset){
      DSET_delete(outset);
      free(outset);
   }
   if(insetTIMEA){
      DSET_delete(insetTIMEA);
      free(insetTIMEA);
   }
   if(insetTIMEB){
      DSET_delete(insetTIMEB);
      free(insetTIMEB);
   }

   if( MASK ){
      DSET_delete(MASK);
      free(MASK);
   }
   if(mskd){
      for( i=0 ; i<Dim[0] ; i++) 
         for( j=0 ; j<Dim[1] ; j++) 
            free(mskd[i][j]);
      for( i=0 ; i<Dim[0] ; i++) 
         free(mskd[i]);
      free(mskd);
   }
	
   return 0;
}
