#!/bin/tcsh

# pass the program name and the full argument list to @global_parse;
# if it returns a non-zero status there is nothing left to do, so exit quietly
# (NOTE(review): presumably this handles options common to all AFNI
#  programs - confirm against @global_parse itself)
@global_parse `basename $0` "$*" ; if ($status) exit 0

# compute optimally combining weights for echoes
#
# inputs:  echo times and one run of multi-echo EPI data
# outputs: a multi-volume dataset of voxelwise weights
#
# use:     OC for a run would be the sum of each echo time series times
#          its computed weight.
#
#          Voxels that do not solve to an appropriate decay get weights based
#          on the T2* limit by default, or the average of the echoes when
#          '-def_to_equal yes' is given.
#
# Should this be left as a little script or turned into a
# more formal python program?  Either way, add some options...

# system of equations from Javier's presentation:
#
#   log(mean(S(TE_1))) ~= -mean(R2s(x))*TE_1 + log(So(x))
#   log(mean(S(TE_2))) ~= -mean(R2s(x))*TE_2 + log(So(x))
#   log(mean(S(TE_3))) ~= -mean(R2s(x))*TE_3 + log(So(x))
#
#
# New method: solve the same system of equations, but without taking means.
#             So solve this across all time points.  To be efficient, demean
#             each time point (across echoes), demean the TEs, and make a
#             single regressor with the appropriate resulting TE values.
#
# Then T2* = 1/mean(R2s(x)), and weights come from:
#
#                TE_n*e^-(TE_n/T2*)
#   w(TE_n) = -------------------------
#             sum_n[TE_n*e^-(TE_n/T2*)]
#
#
# Note that T2* is the reciprocal of m(R2s), and that it is subsequently
# applied via TE_n/T2*.  So apply via T2s_inv (multiplicative), to avoid
# repeated division.
# 
# 
# ** bad voxels (e.g. m(R2s)<0) get weights based on the T2* limit by default,
#    or weights of 1.0/necho under '-def_to_equal yes'
# 

# ----------------------------------------------------------------------
# program identity
# (VERSION should match the most recent entry in the HIST section)
set VERSION = 0.5
set prog = @compute_OC_weights


# ----------------------------------------------------------------------
# user options, with their defaults

# 1 : 'bad' voxels get equal weights, 0 : they get T2* limit based weights
set def_to_equal = 0
# input echo datasets (-echo_dsets)
set echo_dsets = ()
# file of echo times (-echo_times_file)
set etimes_file = ''
# OC_A or OC_B (-oc_method)
set oc_method = 'OC_A'
# prefix of the output weights dataset (-prefix)
set prefix = OC_weights
# upper limit applied to T2* (-t2_star_limit)
set t2s_limit = 300
# allowed difference of summed weights from 1 (-sum_weight_tolerance)
set tolerance = 0.001
# directory in which all intermediate results are computed (-work_dir)
set work_dir = OC.weight.results
# verbosity level, incremented by each -verb
set verb = 1

# ----------------------------------------------------------------------
# derived values, filled in after option processing

# echo times given directly (-echo_times) or read from etimes_file
set etimes_list = ( )
set times_in_ms = 1
set necho = 0
# remember where we started, to return for the final copy
set root_dir = `pwd`

# ----------------------------------------------------------------------
# process the command line
#
# Terminal options (-help, -hist, -ver) act immediately.  Every other
# option stores its parameter in one of the variables initialized above,
# and a missing or invalid parameter is fatal (exit status 1).

# with no arguments at all, just show the help
if ( $#argv < 1 ) goto HELP

set ac = 1
while ( $ac <= $#argv )
   if ( "$argv[$ac]" == "-help" || "$argv[$ac]" == "-h" ) then
      goto HELP
   else if ( "$argv[$ac]" == "-hist" ) then
      goto HIST
   else if ( "$argv[$ac]" == "-ver" ) then
      # listed under 'terminal options' in the help
      echo "version $VERSION"
      exit 0

   else if ( "$argv[$ac]" == "-def_to_equal" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-def_to_equal'"
         exit 1
      endif
      if ( "$argv[$ac]" == yes ) then
         set def_to_equal = 1
      else if ( "$argv[$ac]" == no ) then
         set def_to_equal = 0
      else
         echo "** -def_to_equal requires yes/no, have $argv[$ac]"
         exit 1
      endif

   else if ( "$argv[$ac]" == "-echo_dsets" ) then
      # collect every following argument, up to the next one starting with '-'
      @ ac ++
      set acbase = $ac
      while ( $ac <= $#argv )
         set o0 = "`echo $argv[$ac] | cut -b 1`"
         if ( "$o0" == '-' ) break
         # have something
         @ ac += 1
      end

      # ac index is never good here
      # (it is either past the end or on the next option, so step back)
      @ ac -= 1

      if ( $ac < ($acbase + 1) ) then
         echo "** -echo_dsets requires at least 2 datasets"
         exit 1
      endif

      # take them as echo datasets
      set echo_dsets = ( $argv[$acbase-$ac] )

   else if ( "$argv[$ac]" == "-echo_times_file" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-echo_times_file'"
         exit 1
      endif
      set etimes_file = $argv[$ac]

   else if ( "$argv[$ac]" == "-echo_times" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-echo_times'"
         exit 1
      endif
      # the list comes as one quoted parameter; leave it unquoted here
      # so that it is split into one word per echo time
      set etimes_list = ( $argv[$ac] )

   else if ( "$argv[$ac]" == "-oc_method" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-oc_method'"
         exit 1
      endif
      set oc_method = $argv[$ac]
      if ( "$oc_method" != OC_A && "$oc_method" != 'OC_B' ) then
         echo "** invalid -oc_method = '$oc_method'"
         exit 1
      endif

   else if ( "$argv[$ac]" == "-prefix" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-prefix'"
         exit 1
      endif
      set prefix = $argv[$ac]

      # if there is a view, whine
      # (the view is appended later, so the prefix must not include one)
      set vstr = `@parse_afni_name $prefix | awk '{print $3}'`
      set vchr = `echo $vstr | cut -b 1`
      if ( "$vchr" == "+" ) then
         echo "** please remove the view ($vstr) from the -prefix param"
         echo "   at the end of: -prefix $prefix"
         exit 1
      endif

   else if ( "$argv[$ac]" == "-sum_weight_tolerance" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-sum_weight_tolerance'"
         exit 1
      endif
      set tolerance = $argv[$ac]
   else if ( "$argv[$ac]" == "-t2_star_limit" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-t2_star_limit'"
         exit 1
      endif
      set t2s_limit = $argv[$ac]
   else if ( "$argv[$ac]" == "-verb" ) then
      # each -verb raises the level; from level 3 on, also trace commands
      @ verb += 1
      if ( $verb > 2 ) then
         set echo
      endif
   else if ( "$argv[$ac]" == "-work_dir" ) then
      @ ac ++
      if ( $ac > $#argv ) then
         echo "** missing parameter for option '-work_dir'"
         exit 1
      endif
      set work_dir = $argv[$ac]
   else
      echo "** bad option $ac : '$argv[$ac]'"
      exit 1
   endif

   @ ac += 1
end

# ----------------------------------------------------------------------
# check on the state of things, and try to make an echo list

# check for echo datasets
if ( $#echo_dsets < 2 ) then
   echo "** must have at least 2 -echo_dsets files"
   exit 1
endif


# if etimes_list is not already input, try to set from file
if ( $#etimes_list == 0 ) then
   # with neither -echo_times nor -echo_times_file there is no file name
   # to test, so report the missing echo times directly
   if ( "$etimes_file" == "" ) then
      echo "** missing echo times from -echo_times_file or -echo_times"
      exit 1
   endif
   if ( ! -f "$etimes_file" ) then
      echo "** missing -echo_times_file $etimes_file"
      exit 1
   endif
   set etimes_list = ( `1dcat $etimes_file` )
endif

# see if we have echo times at all now
# (the list can still be empty, e.g. if the file held no values)
set necho = $#etimes_list
if ( $necho == 0 ) then
   echo "** missing echo times from -echo_times_file or -echo_times"
   exit 1
endif

# if first echo is greater than 1, assume times are in ms
# (otherwise treat the times as seconds and scale them by 1000)
set times_in_ms = `ccalc -i "step($etimes_list[1]-1.0)"`
if ( ! $times_in_ms ) then
    echo "++ times seem to be in seconds, will scale to ms"
    set etimes_list = ( `echo $etimes_list | 1deval -a stdin: -expr 'a*1000'` )
endif

# there must be exactly one echo time per echo dataset
if ( $necho != $#echo_dsets ) then
   echo "** have $necho echo times, but $#echo_dsets echo datasets"
   exit 1
endif

# and note nt
# (number of time points, taken from the first echo dataset)
set nt = `3dinfo -nt $echo_dsets[1]`

# ----------------------------------------------------------------------
# inputs look usable: report the main settings when verbose
if ( $verb != 0 ) then
   echo ""
   echo "++ num echo dsets = $#echo_dsets"
   echo "++ echo times     = $etimes_list"
   echo "++ weight prefix  = $prefix"
   echo "++ work_dir       = $work_dir"
   echo "++ T2* limit      = $t2s_limit"
   echo "++ NT             = $nt"
   echo ""
endif



# determine the dataset view from the first echo
# (an empty result or NO-DSET is treated as failure)
set view = `3dinfo -av_space $echo_dsets[1]`
set view_ok = 1
if ( "$view" == "" ) set view_ok = 0
if ( "$view" == "NO-DSET" ) set view_ok = 0
if ( $view_ok == 0 ) then
   echo "** failed to set view from $echo_dsets[1]"
   exit 1
endif

if ( $verb != 0 ) then
   echo "-- have $necho echoes"
   if ( $verb == 2 ) echo "-- have view $view"
endif

# list of 1-based echo indices, 1 .. necho
set eind_list = ( `count -digits 1 1 $necho` )

# ----------------------------------------------------------------------
# results go into a new work directory, so refuse to reuse an existing one
if ( -d $work_dir ) then
   echo "** output dir $work_dir already exists, exiting..."
   exit 1
endif

# create the work directory, and stop here if that is not possible
# (everything below writes into it)
mkdir $work_dir
if ( $status ) then
   echo "** failed to create work dir $work_dir"
   exit 1
endif


# ----------------------------------------------------------------------
# get inputs or basic computations into working directory
#
# cur_prefix names the first-stage datasets, and is used again after
# entering the work directory

if ( $oc_method == OC_A ) then
   # get means for each echo
   # (OC_A works from the temporal mean of each echo)
   set cur_prefix = s1.mean
   foreach eind ( $eind_list )
      3dTstat -mean -prefix $work_dir/$cur_prefix.e$eind $echo_dsets[$eind]
   end
else if ( $oc_method == OC_B ) then
   # OC_B works from the log of each full echo time series
   set cur_prefix = s1.log
   foreach eind ( $eind_list )
      3dcalc -a $echo_dsets[$eind] -expr 'log(a)' -float \
             -prefix $work_dir/$cur_prefix.0.e$eind
   end
else
   echo "** unknown oc_method $oc_method"
   exit 1
endif

# create new echo times in working dir and adjust var
# (echo times may have been converted to ms; times are vertical)
set etimes_file = echo_times.1D
echo $etimes_list | tr ' ' '\n' > $work_dir/$etimes_file


# ----------------------------------------------------------------------
# enter results dir and proceed
#
# Either method ends by writing s4.R2, the fitted coefficient of the echo
# time regressor (the "-R2(x)" term), which the T2* steps below read.
cd $work_dir

if ( $oc_method == OC_A ) then
   # take the log of each per-echo mean volume (s1.mean.e* -> s2.log.e*)
   foreach eind ( $eind_list ) 
      3dcalc -a $cur_prefix.e$eind$view -expr 'log(a)' -prefix s2.log.e$eind
   end

   # "temporal" catenation
   # (one volume per echo, to be regressed against the echo times file)
   3dTcat -TR 1 -prefix s3.tcat.logs s2.log.e*.HEAD


   # regress to solve for R(x) and log(So(x))
   # (-polort 0 supplies the constant term, the echo times are the one
   #  stim regressor, and -bout includes the baseline in the bucket)
   3dDeconvolve -input s3.tcat.logs$view   \
           -polort 0                       \
           -num_stimts 1                   \
           -stim_file 1 $etimes_file       \
           -x1D x.xmat.1D                  \
           -fitts s4.fitts                 \
           -bucket s4.bucket               \
           -bout

   # and create an R2 volume for the next step
   # (sub-brick 2 is the linear term, as listed after this if-block)
   3dbucket -prefix s4.R2 s4.bucket$view'[2]'

else if ( $oc_method == OC_B ) then

   # demean the log time series (remove the mean log at each time point)
   # (the mean is taken across the per-echo log datasets)
   3dMean -prefix $cur_prefix.1.mean $cur_prefix.0.e*$view.HEAD
   foreach eind ( $eind_list ) 
      3dcalc -a $cur_prefix.0.e$eind$view -b $cur_prefix.1.mean$view \
             -expr a-b -prefix $cur_prefix.2.e$eind.demean
   end

   # demean echo times
   set emean = `echo $etimes_list | 3dTstat -prefix - 1D:stdin\'`
   set etimes_dem = ( `1deval -a $etimes_file -expr a-$emean` )
   echo "++ demeaned echoes: $etimes_dem"

   # now create echo regressor: NT TE1 vals, NT TE2 vals, etc
   echo "-- creating regressor with $nt of each demeaned echo time"
   set ereg_vals = ()
   foreach eind ( $eind_list )
      # set replist = ( `python -c "print('$etimes_dem[$eind] '*$nt)"` )
      # use printf and sed instead of python, seems safer
      # (printf emits $nt spaces, and sed turns each into one copy of the
      #  demeaned echo time)
      set replist = ( `printf "%${nt}s" " " | sed "s/ /$etimes_dem[$eind] /g"` )
      set ereg_vals = ( $ereg_vals $replist )
   end
   # write the regressor vertically, one value per line
   echo $ereg_vals | tr ' ' '\n' > reg_etimes_demean.1D
   # the demeaned echo datasets are all given as input, against the
   # single demeaned echo time regressor, with no baseline (-polort -1)
   3dDeconvolve -input $cur_prefix.2.e*.demean$view.HEAD   \
        -polort -1 -num_stimts 1                           \
        -stim_file 1 reg_etimes_demean.1D -stim_label 1 R2 \
        -x1D x.xmat.1D -tout -bucket s4.bucket
   # with no baseline term here, the R2 coefficient is taken from sub-brick 1
   3dbucket -prefix s4.R2 s4.bucket$view'[1]'
   
endif

# output volumes (of the OC_A bucket, which includes the baseline) are:
#    0 : full-F
#    1 : constant term : log(S0(x))
#    2 : linear term   : -R2(x)
#
# (OC_B has no constant term, which is why sub-brick 1 is extracted there)

# ---------------------------------------------------------------------------
# compute T2* in multiple forms (for QC):
#
#    raw:     without any limits -1/mean(R2)
#    limited: if T2* < 0 or T2* > $t2s_limit, use $t2s_limit
#    inverse: (reciprocal) use 1/T2*, since there is little gain
#             in dividing twice
#             - if T2* < 0, apply as 0
#             - if sum of TE_i*e^-(TE_i / T2*) is not close to 1, apply as 0
#           ==> for 0, use weight = 1/num_echoes
#
# Note that s4.R2 holds the fitted slope, i.e. -R2, hence the negations.
# Of the three s5 datasets, only s5.t2.star.inv is read by later steps.

# raw T2*: save a volume that has no applied limit, just for QC
3dcalc -a s4.R2$view -expr "-1/a" -datum float -prefix s5.t2.star.raw

# limited: T2*, but $t2s_limit if T2* < 0 or T2* > $t2s_limit
3dcalc -a s4.R2$view                                            \
       -expr "step(-a)*min(-1/a,$t2s_limit)+step(a)*$t2s_limit" \
       -datum float -prefix s5.t2.star

# inverse: 1/T2*, to use multiplication rather than repeated division
#          (so no t2s_limit?)
#        * 1/T2* is actually applied, if positive
# NOTE(review): no upper T2* limit is applied to this map, so a positive
#               T2* above $t2s_limit is used as is, while the help text says
#               such values are replaced by the limit - confirm the intent
3dcalc -a s4.R2$view -expr "-a*step(-a)" -datum float -prefix s5.t2.star.inv

# ---------------------------------------------------------------------------
# make a product time series (for any <= 0 values, replace with 1)
# TE_n * e^-(TE_n/T2*)
# ** use 1/T2* here, so product is   TE_n * e^-(TE_n * t2s_inv)
# ** relatedly, apply $t2s_limit here too, so that as R2 goes to zero and
#    negative, results converge to it, rather than jumping to a flat average
#
# if we default to equal weights, denote with zero in the prod/sum values
# - the default default is, uh, what are we talking about?
#   no wait, the default default is to use t2s_limit
#
# (b is the vertical echo times file; presumably this yields one output
#  volume per echo time - confirm against 3dcalc's handling of 1D input)
if ( $def_to_equal ) then
   # voxels without a positive 1/T2* get a product of 0 at every echo
   3dcalc -a s5.t2.star.inv$view -b $etimes_file        \
          -expr "step(a)*b*exp(-b*a)"                   \
          -prefix s6.prod
else
   # voxels without a positive 1/T2* get the product for T2* = $t2s_limit
   3dcalc -a s5.t2.star.inv$view -b $etimes_file                           \
          -expr "step(a)*b*exp(-b*a)+(1-step(a))*b*exp(-b/$t2s_limit)"     \
          -prefix s6.prod
endif

# and sum for weight denominator
3dTstat -sum -prefix s7.sum.te s6.prod$view

# make a time series of weights per echo, which is now just prod/sum
# (guard against truncation errors by limiting to 1)
3dcalc -a s6.prod$view -b s7.sum.te$view -expr 'min(a/b,1)' \
       -prefix s8.weights.orig
# (provisional: result_dset is set again in both branches further down)
set result_dset = s8.weights.orig

# as a safeguard, mask the weights by voxels with valid sums
# (within 'tolerance' of 1)
#
#    s9.tolerance.fail : positive weight sum that is outside the tolerance
#    s9.tolerance.mask : weight sum that is inside the tolerance
3dTstat -sum -prefix s9.sum.weights s8.weights.orig$view
3dcalc -a s9.sum.weights$view -prefix s9.tolerance.fail \
       -expr "step(a)*(1-within(a,1-$tolerance,1+$tolerance))"
3dcalc -a s9.sum.weights$view -prefix s9.tolerance.mask \
       -expr "within(a,1-$tolerance,1+$tolerance)"

# if default is mean across echoes, insert into result
# (the tolerance mask is applied only in this case; otherwise the
#  original weights are the result)
if ( $def_to_equal ) then
   # equal weighting is 1/necho for every echo
   set wval = `ccalc 1/$necho`
   3dcalc -a s8.weights.orig$view -b s9.tolerance.mask$view \
          -expr "bool(b)*a+(1-bool(b))*$wval" -prefix s10.masked.weights
   set result_dset = s10.masked.weights
else
   set result_dset = s8.weights.orig
endif
        

# copy back to $prefix
# (return to the starting directory first, where the result is written)
cd $root_dir
3dbucket -prefix $prefix $work_dir/$result_dset$view

# check the copy itself, rather than only the history note that follows:
# if $prefix already existed, annotating the old dataset could succeed
# even though no new result was written
set copy_status = $status
if ( $copy_status == 0 ) then
   # record this command in the history of the new dataset
   3dNotes -h "$prog $argv" $prefix$view
   set copy_status = $status
endif

# report failure with a non-zero exit status, so callers can detect it
if ( $copy_status ) then
   echo "** failed to create $prefix$view"
   exit 1
endif

# success: name the result and suggest how to apply it
echo ""
echo "============================================================"
echo ""
echo "++ final output is $prefix$view"
echo ""
echo "   consider the command:"
echo ""
echo "   3dMean -weightset $prefix$view -prefix opt.combined \\"
echo "          $echo_dsets"
echo ""
echo "============================================================"
echo ""

exit 0


# ===========================================================================
# show help
# (reached via -help, -h, or an empty command line; prints the text below
#  and exits;  \$prog is expanded inside the here-document)
HELP:
cat << EOF

$prog           - compute optimally combined weights dataset

   Given echo times (in a text file) and one run of multi-echo EPI data,
   compute a dataset that can be used to combine the echoes.  The weight
   dataset would have one volume per echo, which can be used to combine
   the echoes into a single dataset.  The same weights can be applied to
   all runs.

       3dMean -weightset weights+tlrc -prefix opt.combined \
              echo1+tlrc echo2+tlrc echo3+tlrc

   For clarity, a similar 3dcalc computation would look like:

       3dcalc -a echo1+tlrc        -b echo2+tlrc        -c echo3+tlrc        \
              -d weights+tlrc'[0]' -e weights+tlrc'[1]' -f weights+tlrc'[2]' \
              -expr 'a*d+b*e+c*f'  -prefix opt.combined

   

   ----------------------------------------------------------------------

   These computations are based on the system of equations from:

      o Posse, S., Wiese, S., Gembris, D., Mathiak, K., Kessler, C.,
        Grosse-Ruyken, M.L., Elghahwagi, B., Richards, T., Dager, S.R.,
        Kiselev, V.G.
        Enhancement of BOLD-contrast sensitivity by single-shot multi-echo
        functional MR imaging.
        Magnetic Resonance in Medicine 42:87–97 (1999)

      o Prantik Kundu, Souheil J. Inati, Jennifer W. Evans, Wen-Ming Luh,
        Peter A. Bandettini
        Differentiating BOLD and non-BOLD signals in fMRI time series using
        multi-echo EPI
        NeuroImage 60 (2012) 1759–1770

      o a summer 2017 presentation by Javier Gonzalez-Castillo

   ----------------------------------------------------------------------

   After solving:
  
     log(mean(S(TE_1))) ~= -mean(R2s(x))*TE_1 + log(So(x))
     log(mean(S(TE_2))) ~= -mean(R2s(x))*TE_2 + log(So(x))
     log(mean(S(TE_3))) ~= -mean(R2s(x))*TE_3 + log(So(x))
  
   then T2* = 1/mean(R2s(x)), and weights come from:
  
                  TE_n*e^-(TE_n/T2*)
     w(TE_n) = -------------------------
               sum_n[TE_n*e^-(TE_n/T2*)]

   Bad, naughty voxels are defined as those with either negative T2* values,
   or for which the sum of the weights is not sufficiently close to 1, which
   would probably mean that there were computational truncation errors, likely
   due to R2s being very close to 0.

   so "fail" if
         mean(R2s) <= 0
   or
         abs(1-sum[weights]) > 'tolerance'

   In such cases, the weights will default to the result based on the maximum
   T2* value (unless "-def_to_equal yes" is applied, in which case the default
   is 1/number_of_echoes, which is equal weighting across echoes).
  
   ----------------------------------------------------------------------

   examples:

      1. basic

         @compute_OC_weights -echo_times_file etimes.1D \
                -echo_dsets pb02*r01*volreg*.HEAD
  
      2. Specify working directory and resulting weights dataset prefix.
         Then use the weight dataset to combine the echoes.

         @compute_OC_weights -echo_times_file etimes.1D \
                -echo_dsets pb02*r01*volreg*.HEAD       \
                -prefix OC.weights.run1 -work_dir OC.work.run1

         3dMean -weightset OC.weights.run1+tlrc -prefix epi_run1_OC \
                pb02*r01*volreg*.HEAD

   ----------------------------------------------------------------------

   random babble:

      The T2* map is not actually used, but rather 1/T2* (to avoid repeated
      division).

      T2* is restricted to the range (0, T2S_LIMIT), where the default limit is
      300 (see -t2_star_limit).

      A "bad" T2* value (T2* <= 0 or T2* > T2S_LIMIT) will lead to use of the
      limit T2S_LIMIT, so that as R2 decreases and goes negative, the results
      converge.

      If the sum of the weights is not almost exactly 1.0 (see the option,
      -sum_weight_tolerance), the weights will also default to equal (see
      option -def_to_equal).

      Basically, the program is designed such that either a reasonable T2*
      is computed and applied, or the weighting result will be 1/num_echoes.

   ----------------------------------------------------------------------

   required parameters:

      -echo_times "TE1 TE2 ..." - specify echo times
                                  (use quotes to pass list as one parameter)

           e.g. -echo_times "15 30.5 41"

        Specify echo times as a list.

        Use either -echo_times or -echo_times_file.


      -echo_times_file FILE     - specify file with echo times
                                  (e.g. it could contain 15 30.5 41)

        Specify echo times from a text file.

        Use either -echo_times or -echo_times_file.


      -echo_dsets D1 D2 D3      - specify one run of multi-echo EPI data, e.g.:

           e.g. -echo_dsets pb03.SUBJ.r01.e*.volreg+tlrc.HEAD

        Provide the echo datasets for a single run of multi-echo EPI data.


   general options:

      -def_to_equal yes/no      - specify whether to default to equal weights
                                  (default = no)

        In the case where T2* seems huge or <= 0, or if the sum of the
        fractional weights is not close to 1 (see -sum_weight_tolerance), one
        might want to apply default weights equal to 1/num_echoes (so echoes
        are weighted equally).

        Without this, the weighting for such 'bad' voxels is based on the
        T2* limit.  See -t2_star_limit.

      -oc_method METHOD         - specify which method to employ

           e.g.     -oc_method OC_B
           default: -oc_method OC_A

        The OC_B method differs from OC_A by going from solving for T2* using
        log(mean()) to solving using log() over time, with the intention of
        being more accurate.

           methods:

              OC_A      : compute T2* from log(mean(time series))
                          this is the original implementation

              OC_B      : compute T2* from log(time series)

         * So far, testing has shown almost undetectable differences, so it
           may be a moot point.

      -prefix PREFIX            - specify prefix of resulting OC weights dataset

           e.g. -prefix OC.weights.SUBJ

      -sum_weight_tolerance TOL - tolerance for summed weight diff from 1.0
                                  (default = 0.001)

           e.g. -sum_weight_tolerance 0.0001

        This option only applies to the "-def_to_equal yes" case.

        If echo means (at some voxel) do not follow a decay curve, there
        could be truncation errors in weighting computation that lead to
        weights which do not sum to 1.0.  If abs(1-sum) > tolerance, such a
        voxel will be set in the tolerance.fail dataset.

        The default effect of this failure is to get equal weights across
        the echoes.

      -t2_star_limit LIMIT      - specify limit for T2* values
                                  (default = 300)

        When the system of equations does not show a reasonably fast decay,
        the slopes will be such that T2* is huge or possibly negative.  In such
        cases, it is applied as the LIMIT from this option.

      -work_dir WDIR            - specify directory to compute results in

        All the processing is done in a new sub-directory.  If this program
        is to be applied one run at a time, it is important to specify such
        working directories to keep the names unique.

      -verb                     - increase verbosity of output


   terminal options:

       -help
       -hist
       -ver

   ----------------------------------------------------------------------
   R Reynolds, Feb, 2016               Thanks to Javier Gonzalez-Castillo
  
EOF
exit


# ===========================================================================
# show history
# (reached via -hist; prints the modification history below and exits;
#  \$prog is expanded inside the here-document)
HIST:
cat << EOF

modification history for $prog :

   0.0  12 Feb 2018 : started
   0.1  16 Feb 2018 : ready for use
                    - added -def_to_equal, -sum_weight_tolerance
                    - allow echo times file to be horizontal, too
                    - do not apply T2*, but leave as reciprocal
                    - help updates
   0.2  21 Feb 2018 : more babble; apply T2* < 0 as T2* limit
   0.3  23 Feb 2018 : added -echo_times, to list on command line
   0.4  10 Dec 2019 : added 3dMean -weightset example
   0.5  20 Dec 2019 : added 3dMean hint at end of execution

EOF
exit
