#! /usr/bin/env python3
"""stack — compute mean and standard deviation of a stack of grids + plots.

Python port of csh stack.csh (X. Tong, D. Sandwell 2013). Loops over a
list of .grd files, computes the per-pixel mean and stdev, scales both
by `scale`, and produces PDF plots of each.

Usage:  stack grid.list scale mean.grd std.grd
"""
import os
import subprocess
import sys
from gmtsar_lib import run


def _capture(cmd):
    """Run `cmd` through the shell and return its standard output as text.

    The output is decoded as UTF-8 and stripped of leading/trailing
    whitespace. The command's exit status is not checked (`check=False`),
    and only stdout is piped.
    """
    completed = subprocess.run(
        cmd,
        shell=True,
        stdout=subprocess.PIPE,
        check=False,
    )
    text = completed.stdout.decode("utf-8")
    return text.strip()


def _color_limits(info, default=("-1.0", "1.0")):
    """Return (lower, upper) colour-table limits as unpadded one-decimal strings.

    Args:
        info: whitespace-split output of ``gmt grdinfo -C -L2``. Fields 5 and
            6 (0-based) are read as the limits; in the GMT ``grdinfo -C``
            record layout these are the grid's z-min and z-max.
        default: limits returned when `info` has fewer than 13 fields, the
            two fields are not numeric, or the rounded limits do not form an
            increasing range.

    Returns:
        A ``(lower, upper)`` tuple of strings with no surrounding whitespace,
        so they can be pasted directly into a ``-Tlower/upper/step`` option.
    """
    # NOTE(review): the 13-field guard suggests the -L2 mean/stdev columns may
    # have been the intended source of the limits — confirm against the legacy
    # csh script before changing which fields are read.
    if len(info) < 13:
        return default
    try:
        # No field width here: a padded value such as "  1.2" would put
        # spaces inside the -T option and split it into several shell words.
        lower = f"{float(info[5]):.1f}"
        upper = f"{float(info[6]):.1f}"
    except ValueError:
        return default
    # Compare the rounded values, since those are what ends up on the command
    # line. Any comparison involving NaN is False, so NaN limits fall back too.
    if not float(lower) < float(upper):
        return default
    return lower, upper


def _plot_grid(grid_path, label):
    """Plot one grid as a shaded colour image with a scale bar, as a PDF.

    Reproduces the legacy gmt grdgradient + grdimage + psscale + psconvert
    plot.

    Args:
        grid_path: path of the grid to plot. The grid is read from this path
            as given; all products are named after its base name (with a
            trailing ".grd" removed) and written relative to the current
            directory.
        label: text used as the x-axis label of the colour scale bar.

    Side effects:
        Writes ``{stem}.ps``, runs ``gmt psconvert -Tf`` on it, prints the
        name of the expected PDF, and removes the intermediate ``{stem}.cpt``
        and ``{stem}.grad.grd`` files.
    """
    stem = os.path.basename(grid_path)
    if stem.endswith(".grd"):
        stem = stem[:-4]

    # Read the input from grid_path rather than "{stem}.grd": the stem drops
    # any directory part (and the name need not end in ".grd"), so rebuilding
    # the input name from it can point at a file that does not exist.
    run(f"gmt grdgradient {grid_path} -Nt.9 -A0. -G{stem}.grad.grd")
    info = _capture(f"gmt grdinfo -C -L2 {grid_path}").split()
    limitL, limitU = _color_limits(info)

    run(f"gmt makecpt -Cseis -I -Z -T{limitL}/{limitU}/0.1 -D > {stem}.cpt")
    run(f"gmt grdimage {grid_path} -I{stem}.grad.grd -C{stem}.cpt "
        f"-JX6.5i -Bxaf+lRange -Byaf+lAzimuth -BWSen -X1.3i -Y3i -P -K > {stem}.ps")
    run(f"gmt psscale -R{grid_path} -J -DJTC+w5/0.2+h+e -C{stem}.cpt "
        f"-Bxaf+l\"{label}\" -By -O >> {stem}.ps")
    run(f"gmt psconvert -Tf -P -A -Z {stem}.ps")
    print(f"Stack plot: {stem}.pdf")
    # The colour table and gradient grid are only needed while plotting.
    run(f"rm -f {stem}.cpt {stem}.grad.grd")


def stack():
    """Compute the per-pixel mean and standard deviation of a list of grids.

    Command-line arguments (read from ``sys.argv``):
        grid.list  text file with one grid path per line (blank lines ignored)
        scale      multiplier applied to both output grids
        mean.grd   output path for the mean grid
        std.grd    output path for the standard-deviation grid

    The mean is the sum of the grids divided by their count; the standard
    deviation is the square root of the summed squared differences from the
    (unscaled) mean divided by the same count. Both results are then
    multiplied by `scale` and plotted.

    Scratch files ``sum.grd``, ``sumtmp.grd``, ``sum2.grd``, ``sum2tmp.grd``
    and ``tmp.grd`` are created in the current directory while running.

    Exits through ``sys.exit`` with a message when the argument count is
    wrong, the list file is missing or empty, or a listed grid is missing.
    """
    cli_args = sys.argv[1:]
    if len(cli_args) != 4:
        usage = (
            "Usage: stack grid.list scale mean.grd std.grd\n"
            "  scale: multiplier applied to both outputs (use 1 for no scale).\n"
            "  All grids in grid.list must have consistent dimensions."
        )
        sys.exit(usage)
    grid_list, scale, outmean, outstd = cli_args

    if not os.path.isfile(grid_list):
        sys.exit(f"no input file found: {grid_list}")

    with open(grid_list) as handle:
        trimmed = (line.strip() for line in handle)
        files = [entry for entry in trimmed if entry]
    if not files:
        sys.exit("stack: empty input list")

    # Report the first listed grid that does not exist, if any.
    absent = next((path for path in files if not os.path.isfile(path)), None)
    if absent is not None:
        sys.exit(f" Error: file not found: {absent}")

    count = len(files)
    seed, remainder = files[0], files[1:]

    # Mean: seed the running sum with the first grid, then add the others.
    print("computing the mean of the grids ..")
    run(f"gmt grdmath {seed} = sum.grd")
    for grid in remainder:
        run(f"gmt grdmath {grid} sum.grd ADD = sumtmp.grd")
        run("mv sumtmp.grd sum.grd")
    run(f"gmt grdmath sum.grd {count} DIV = {outmean}")

    # Standard deviation: accumulate squared deviations from the mean grid.
    print("compute the standard deviation ..")
    run(f"gmt grdmath {seed} {outmean} SUB SQR = sum2.grd")
    for grid in remainder:
        run(f"gmt grdmath {grid} {outmean} SUB SQR sum2.grd ADD = sum2tmp.grd")
        run("mv sum2tmp.grd sum2.grd")
    run(f"gmt grdmath sum2.grd {count} DIV SQRT = {outstd}")

    # Scale each product through a scratch grid, then drop the accumulators.
    for product in (outmean, outstd):
        run(f"gmt grdmath {product} {scale} MUL = tmp.grd")
        run(f"mv tmp.grd {product}")
    run("rm -f sum.grd sum2.grd")

    plots = (
        (outmean, "Mean of Image Stack"),
        (outstd, "Std. Dev. of Image Stack"),
    )
    for product, title in plots:
        _plot_grid(product, title)


# Script entry point: run the stacking workflow only when this file is
# executed directly, not when it is imported as a module.
if __name__ == "__main__":
    stack()
