#! /usr/bin/env python3
"""unwrap_parallel — run snaphu unwrapping in parallel across multiple intf dirs.

Python port of csh unwrap_parallel.csh (X. Xu 2018). The csh version used
GNU parallel; this port uses Python's multiprocessing.Pool for the same
fan-out without the external dependency.

For each line in `intflist`, runs the local `unwrap_intf` script with that
directory name as its argument (or, if no executable `unwrap_intf` /
`unwrap_intf.csh` script is present in the current dir, falls back to
running `snaphu.py THRESHOLD 0 INTERP` inside the dir directly).

Usage:  unwrap_parallel intflist Ncores [snaphu_threshold] [interp]

Defaults:
  snaphu_threshold = 0.1
  interp           = 0 (use plain snaphu; 1 = snaphu with nearest_grid interp)

Each interferogram dir gets its own log file, log_<dir>.txt, holding the
job's combined stdout/stderr.

Note: the legacy convention expects a local `unwrap_intf.csh` script in
cwd. If an executable `unwrap_intf` or `unwrap_intf.csh` is present
(checked in that order), the Python port honors it. If absent, runs
snaphu.py directly with the given threshold/interp.
"""
import multiprocessing
import os
import shlex
import subprocess
import sys


def _unwrap_one(args):
    """Run one unwrapping job and report its exit status.

    Args:
        args: ``(intf_dir, threshold, interp, custom_script)`` packed into a
            single tuple so the function can be used with
            ``multiprocessing.Pool.map``.
            intf_dir: interferogram directory (one line of intflist).
            threshold, interp: strings forwarded to ``snaphu.py``.
            custom_script: path of a local unwrap script to run as
                ``<custom_script> <intf_dir>``, or None to run
                ``snaphu.py <threshold> 0 <interp>`` inside ``intf_dir``.

    Returns:
        ``(intf_dir, returncode)``. If the log file cannot be opened the job
        is reported with return code 1 instead of raising, so one bad job
        does not abort the whole pool.

    The job's combined stdout/stderr is written to ``log_<intf_dir>.txt`` in
    the current directory. Any ``/`` in the directory name is replaced by
    ``_`` (leading/trailing slashes dropped) so the log is a plain file name.
    """
    intf_dir, threshold, interp, custom_script = args
    # "a/b" or "a/" (e.g. from `ls -d */`) would otherwise point the log at a
    # non-existent sub-directory.
    log_file = f"log_{intf_dir.strip('/').replace('/', '_')}.txt"
    # Quote everything taken from intflist/argv so names containing spaces or
    # shell metacharacters reach the command as single literal arguments.
    # shell=True is kept so non-shebang scripts and "command not found"
    # (exit 127) behave as before.
    quoted_dir = shlex.quote(intf_dir)
    if custom_script:
        cmd = f"{shlex.quote(custom_script)} {quoted_dir}"
    else:
        cmd = (
            f"cd {quoted_dir} && snaphu.py "
            f"{shlex.quote(threshold)} 0 {shlex.quote(interp)}"
        )
    try:
        # Open the log here, relative to the pool's cwd, instead of using a
        # shell redirect after `cd`. The log then lands in the same place
        # however deep intf_dir is, and a failing `cd` is captured as well.
        log = open(log_file, "w")
    except OSError as err:
        print(f"ERROR: cannot open {log_file}: {err}", file=sys.stderr)
        return (intf_dir, 1)
    with log:
        rc = subprocess.run(
            cmd, shell=True, stdout=log, stderr=subprocess.STDOUT
        ).returncode
    return (intf_dir, rc)


def unwrap_parallel():
    """Parse argv, fan unwrap jobs out over a process pool, and summarize.

    Command line: ``intflist Ncores [snaphu_threshold] [interp]``
    (defaults: threshold "0.1", interp "0"). Each non-blank line of
    ``intflist`` is one interferogram directory, handled by ``_unwrap_one``.

    Exits with a message on a wrong argument count, a non-positive or
    non-integer ``Ncores``, or an unreadable intflist. Exits with status 1
    when any job returned a non-zero exit code. The csh version propagated
    GNU parallel's non-zero status in that case.
    """
    usage = (
        "Usage: unwrap_parallel intflist Ncores [snaphu_threshold] [interp]\n"
        "  Run snaphu unwrapping jobs in parallel.\n"
        "  Defaults: threshold=0.1, interp=0.\n"
        "  Run from the intf_all/ folder containing the interferogram subdirs."
    )
    if len(sys.argv) not in (3, 4, 5):
        sys.exit(usage)
    intflist = sys.argv[1]
    try:
        ncores = int(sys.argv[2])
    except ValueError:
        ncores = 0  # reported by the range check below
    if ncores < 1:
        sys.exit(
            f"unwrap_parallel: Ncores must be a positive integer, "
            f"got {sys.argv[2]!r}"
        )
    threshold = sys.argv[3] if len(sys.argv) >= 4 else "0.1"
    interp = sys.argv[4] if len(sys.argv) >= 5 else "0"

    # Legacy convention: a custom unwrap_intf or unwrap_intf.csh in cwd.
    custom_script = None
    for cand in ("unwrap_intf", "unwrap_intf.csh"):
        if not os.path.isfile(cand):
            continue
        if os.access(cand, os.X_OK):
            custom_script = f"./{cand}"
            break
        # Without this warning a forgotten `chmod +x` silently switches the
        # run to the plain snaphu.py fallback.
        print(
            f"WARN: ./{cand} exists but is not executable; ignoring it",
            file=sys.stderr,
        )

    try:
        with open(intflist) as f:
            intf_dirs = [ln.strip() for ln in f if ln.strip()]
    except OSError as err:
        sys.exit(f"unwrap_parallel: cannot read intflist {intflist!r}: {err}")

    print(f"unwrap_parallel: {len(intf_dirs)} jobs across {ncores} cores")
    print(f"  threshold={threshold}, interp={interp}, custom={custom_script or 'none'}")

    args_iter = [(d, threshold, interp, custom_script) for d in intf_dirs]
    # Do not spawn more workers than there are jobs, but always at least one.
    workers = max(1, min(ncores, len(args_iter)))
    with multiprocessing.Pool(processes=workers) as pool:
        # chunksize=1: jobs are long and uneven, so hand them out one at a
        # time (like GNU parallel) instead of pre-batching several per worker.
        results = pool.map(_unwrap_one, args_iter, chunksize=1)

    failed = [d for d, rc in results if rc != 0]
    if failed:
        print(f"WARN: {len(failed)} job(s) exited non-zero: {failed}", file=sys.stderr)
    print(f"Finished {len(results)} unwrap jobs ({len(results) - len(failed)} OK)")
    if failed:
        sys.exit(1)


if __name__ == "__main__":
    unwrap_parallel()
