# local imports
from .util import (
    S3Url,
    upload_file_to_s3,
)
# generate xml_import and terastitcher commands
from configparser import ConfigParser
import math
import argparse
from psutil import virtual_memory
import joblib
import boto3
import os
import subprocess
import shlex
from glob import glob
from tqdm import tqdm
# Paths to the MPI-parallel Terastitcher helper scripts that ship alongside
# this module; resolved relative to this file so they work from any CWD.
parastitcher_path = f"{os.path.dirname(os.path.realpath(__file__))}/parastitcher.py"
paraconverter_path = f"{os.path.dirname(os.path.realpath(__file__))}/paraconverter.py"
# Interpreter used to launch the helper scripts under mpirun.
python_path = "python"
# Pipeline-step selectors consumed by write_terastitcher_commands():
# STITCH_ONLY=0, COMPUTE_ONLY=1, ALL_STEPS=2.
STITCH_ONLY, COMPUTE_ONLY, ALL_STEPS = range(3)
def write_import_xml(fname_importxml, scanned_matrix, metadata):
    """Write xml_import file for Terastitcher based on COLM metadata.

    Args:
        fname_importxml (str): Path to where xml_import.xml should be stored
        scanned_matrix (list of lists): List of locations that have been imaged by the microscope
        metadata (dict): Metadata associated with this COLM experiment
    """
    img_regex = ".*.tiff"
    # Terastitcher's XML files use DOS-style line endings.
    eofl = "\r\n"
    with open(fname_importxml, "w") as fp:
        fp.writelines(
            [
                f'<?xml version="1.0" encoding="UTF-8" ?>{eofl}',
                f'<!DOCTYPE TeraStitcher SYSTEM "TeraStitcher.DTD">{eofl}',
                f'<TeraStitcher volume_format="TiledXY|2Dseries">{eofl}',
                f"\t<stacks_dir value=\"{metadata['stack_dir']}\" />{eofl}",
                f'\t<ref_sys ref1="1" ref2="2" ref3="3" />{eofl}',
                f"\t<voxel_dims V=\"{metadata['voxel_size'][1]}\" H=\"{metadata['voxel_size'][0]}\" D=\"{metadata['voxel_size'][2]}\" />{eofl}",
                f"\t<origin V=\"{metadata['origin'][1]}\" H=\"{metadata['origin'][0]}\" D=\"{metadata['origin'][2]}\" />{eofl}",
                f"\t<mechanical_displacements V=\"{metadata['mechanical_displacements'][1]}\" H=\"{metadata['mechanical_displacements'][0]}\" />{eofl}",
                f"\t<dimensions stack_rows=\"{metadata['grid_size_Y']}\" stack_columns=\"{metadata['grid_size_X']}\" stack_slices=\"{metadata['num_slices']}\" />{eofl}",
                f"\t<STACKS>{eofl}",
            ]
        )
        # One <Stack> entry per grid position, row-major (Y outer, X inner).
        for j in range(metadata["grid_size_Y"]):
            for i in range(metadata["grid_size_X"]):
                # Absolute position of this tile in effective pixels.
                abs_X_ef = i * metadata["abs_X"]
                abs_Y_ef = j * metadata["abs_Y"]
                # COLM stores tiles in LOC000, LOC001, ... numbered row-major.
                folder_num = i + j * metadata["grid_size_X"]
                dir_name = f"LOC{folder_num:03}"
                if scanned_matrix[j][i] == "1":
                    loc_string = f"\t\t<Stack N_CHANS=\"1\" N_BYTESxCHAN=\"2\" ROW=\"{j}\" COL=\"{i}\" ABS_V=\"{abs_Y_ef}\" ABS_H=\"{abs_X_ef}\" ABS_D=\"0\" STITCHABLE=\"no\" DIR_NAME=\"{dir_name}\" Z_RANGES=\"[0,{metadata['num_slices']})\" IMG_REGEX=\"{img_regex}\">{eofl}"
                else:
                    # Unscanned position: emit an empty stack so the grid stays rectangular
                    # (Terastitcher's --sparse_data mode handles the holes).
                    loc_string = f'\t\t<Stack N_CHANS="1" N_BYTESxCHAN="2" ROW="{j}" COL="{i}" ABS_V="{abs_Y_ef}" ABS_H="{abs_X_ef}" ABS_D="0" STITCHABLE="no" DIR_NAME="" Z_RANGES="" IMG_REGEX="{img_regex}">{eofl}'
                fp.writelines(
                    [
                        loc_string,
                        f"\t\t\t<NORTH_displacements />{eofl}",
                        f"\t\t\t<EAST_displacements />{eofl}",
                        f"\t\t\t<SOUTH_displacements />{eofl}",
                        f"\t\t\t<WEST_displacements />{eofl}",
                        f"\t\t</Stack>{eofl}",
                    ]
                )
        fp.writelines([f"\t</STACKS>{eofl}", f"</TeraStitcher>{eofl}"])
def write_terastitcher_commands(fname_ts, metadata, stitched_dir, do_steps):
    """Generate Terastitcher commands from metadata.

    Args:
        fname_ts (str): Path to bash file to store Terastitcher commands
        metadata (dict): Metadata information about experiment
        stitched_dir (str): Path to where stitched data will be stored
        do_steps (int): Which steps to run (STITCH_ONLY, COMPUTE_ONLY, or ALL_STEPS)

    Returns:
        list of str: List of Terastitcher commands to run
    """
    eofl = "\n"
    # Number of z-slices each displacement-computation job loads at once.
    subvoldim = 60
    # subvoldim = max(metadata['num_slices']//num_processes,20)
    mem = virtual_memory()
    num_cpus = joblib.cpu_count()
    # Cap MPI workers so the aggregate working set — one subvolume spanning a
    # tile row/column at 4 bytes per pixel — fits in physical RAM.
    # NOTE(review): the "+ 1" accounts for parastitcher's master process, which
    # does no alignment work itself — confirm against parastitcher docs.
    num_processes = min(
        math.floor(
            mem.total
            / (
                (metadata["num_pix"] ** 2)
                * 4
                * (min(metadata["grid_size_X"], metadata["grid_size_Y"]) + 1)
                * subvoldim
            )
        )
        + 1,
        num_cpus,
    )
    # Slab depth (slices per merge job); merge buffers are 2 bytes/pixel (16-bit).
    depth = 5
    num_proc_merge = min(
        math.floor(mem.total / (metadata["height"] * metadata["width"] * 2 * depth)),
        num_cpus,
    )
    print(f"num processes to use for stitching is: {num_processes}")
    # step1 = f"terastitcher --test --projin={metadata['stack_dir']}/xml_import.xml --imout_depth=16 --sparse_data{eofl}"
    # Step 2: pairwise displacement computation (MPI-parallel).
    step2 = f"mpirun -n {num_processes} {python_path} {parastitcher_path} -2 --projin=\"{metadata['stack_dir']}/xml_import.xml\" --projout=\"{metadata['stack_dir']}/xml_displcomp.xml\" --sV={metadata['sV']} --sH={metadata['sH']} --sD={metadata['sD']} --subvoldim={subvoldim} --sparse_data --exectimes --exectimesfile=\"{metadata['stack_dir']}/t_displcomp\"{eofl}"
    # Step 3: project displacements; Step 4: threshold them; Step 5: place tiles.
    step3 = f"terastitcher --displproj --projin=\"{metadata['stack_dir']}/xml_displcomp.xml\" --projout=\"{metadata['stack_dir']}/xml_displproj.xml\" --sparse_data{eofl}"
    step4 = f"terastitcher --displthres --projin=\"{metadata['stack_dir']}/xml_displproj.xml\" --projout=\"{metadata['stack_dir']}/xml_displthres.xml\" --threshold=0.3 --sparse_data{eofl}"
    step5 = f"terastitcher --placetiles --projin=\"{metadata['stack_dir']}/xml_displthres.xml\"{eofl}"
    # Step 6: merge aligned tiles into a 2D TIFF series (MPI-parallel).
    step6 = f"mpirun -n {num_proc_merge} {python_path} {paraconverter_path} -s=\"{metadata['stack_dir']}/xml_merging.xml\" -d=\"{stitched_dir}\" --sfmt=\"TIFF (unstitched, 3D)\" --dfmt=\"TIFF (series, 2D)\" --height={metadata['height']} --width={metadata['width']} --depth={depth}{eofl}"
    ts_commands = []
    if do_steps == STITCH_ONLY:
        # do step 5 in case xml_merging doesn't exist
        ts_commands.extend([step5, step6])
    elif do_steps == COMPUTE_ONLY:
        # do step 5 twice to fix terastitcher issue where placetiles can't find tiffs in folder
        ts_commands.extend([step2, step3, step4, step5, step5])
    else:
        # do step 5 twice to fix terastitcher issue
        ts_commands.extend([step2, step3, step4, step5, step5, step6])
    with open(fname_ts, "w") as fp:
        fp.writelines(ts_commands)
    return ts_commands
def get_scanned_cells(fname_scanned_cells):
    """Read a COLM `Scanned Cells.txt` file into a nested list.

    Args:
        fname_scanned_cells (str): Path to scanned cells file.

    Returns:
        list of lists: Indicates whether or not a given location has been imaged on the COLM
    """
    # Each line is a comma-separated row of "0"/"1" flags; keep them as strings.
    with open(fname_scanned_cells, "r") as fp:
        return [line.strip().split(",") for line in fp]
def generate_stitching_commands(
    stitched_dir, stack_dir, metadata_s3_bucket, metadata_s3_path, do_steps=ALL_STEPS
):
    """Generate Terastitcher stitching commands given COLM metadata files.

    Args:
        stitched_dir (str): Path to store stitched data at.
        stack_dir (str): Path to unstitched raw data.
        metadata_s3_bucket (str): Name of S3 bucket in which metadata is located.
        metadata_s3_path (str): Specific path to metadata files in the bucket
        do_steps (int, optional): Represents which Terastitcher steps to run. Defaults to ALL_STEPS (2).

    Returns:
        tuple (dict, list of str): Metadata and list of Terastitcher commands
    """
    # Download the COLM metadata files from S3 only if they don't exist locally.
    scanned_cells_path = f"{stack_dir}/Scanned Cells.txt"
    config_file_path = f"{stack_dir}/Experiment.ini"
    if not os.path.exists(scanned_cells_path) or not os.path.exists(config_file_path):
        s3 = boto3.resource("s3")
        s3.Object(
            metadata_s3_bucket, f"{metadata_s3_path}/Scanned Cells.txt"
        ).download_file(scanned_cells_path)
        s3.Object(
            metadata_s3_bucket, f"{metadata_s3_path}/Experiment.ini"
        ).download_file(config_file_path)
    # Parse experiment metadata from Experiment.ini.
    # NOTE(review): get_metadata is not imported in this file's visible import
    # block — confirm it is provided elsewhere (e.g. by .util) or this raises
    # NameError at runtime.
    metadata = get_metadata(config_file_path)
    metadata["stack_dir"] = stack_dir
    # load scanned cells to indicate which locations contain data
    scanned_matrix = get_scanned_cells(scanned_cells_path)
    # write xml_import file for terastitcher
    fname_importxml = f"{stack_dir}/xml_import.xml"
    write_import_xml(fname_importxml, scanned_matrix, metadata)
    fname_ts = f"{stack_dir}/terastitcher_commands.sh"
    ts_commands = write_terastitcher_commands(
        fname_ts, metadata, stitched_dir, do_steps
    )
    return metadata, ts_commands
def run_terastitcher(
    raw_data_path,
    stitched_data_path,
    input_s3_path,
    log_s3_path=None,
    stitch_only=False,
    compute_only=False,
):
    """Run Terastitcher commands to fully stitch raw data.

    Args:
        raw_data_path (str): Path to raw data (VW0 folder for COLM data)
        stitched_data_path (str): Path to where stitched data will be stored
        input_s3_path (str): S3 Path to where raw data and metadata live
        log_s3_path (str, optional): S3 path to store intermediates and XML files for Terastitcher. Defaults to None.
        stitch_only (bool, optional): Do stitching only if True. Defaults to False.
        compute_only (bool, optional): Compute alignments only if True. Defaults to False.

    Returns:
        dict: Metadata associated with this sample from Experiment.ini file (COLM data)
    """
    input_s3_url = S3Url(input_s3_path.strip("/"))
    # stitch_only takes precedence over compute_only if both are set.
    if stitch_only:
        do_steps = STITCH_ONLY
    elif compute_only:
        do_steps = COMPUTE_ONLY
    else:
        do_steps = ALL_STEPS
    metadata, commands = generate_stitching_commands(
        stitched_data_path,
        raw_data_path,
        input_s3_url.bucket,
        input_s3_url.key,
        do_steps,
    )
    # Run the Terastitcher commands one at a time.
    # NOTE(review): return codes are not checked — a failed step does not abort
    # the pipeline; confirm this best-effort behavior is intended.
    mdata = f"{raw_data_path}/mdata.bin"
    for i in commands:
        print(i)
        subprocess.run(shlex.split(i))
        # Remove Terastitcher's cached volume metadata between steps so the
        # next command re-reads the (possibly updated) XML project files.
        if os.path.exists(mdata):
            os.remove(mdata)
    # Upload XML results to log_s3_path if one was given.
    if log_s3_path:
        log_s3_url = S3Url(log_s3_path.strip("/"))
        files_to_save = glob(f"{raw_data_path}/*.xml")
        for i in tqdm(files_to_save, desc="saving xml files to S3"):
            out_path = i.split("/")[-1]
            upload_file_to_s3(i, log_s3_url.bucket, f"{log_s3_url.key}/{out_path}")
    return metadata
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "Create xml_import.xml file and terastitcher_commands.sh from Experiment.ini file"
    )
    parser.add_argument(
        "--stitched_dir",
        help="Directory to  store stitched tifs.",
        type=str,
        default="/home/ubuntu/ssd2/stitched_data",
    )
    parser.add_argument(
        "--stack_dir",
        help="Path to VW0 directory with tiles stored in LOC* folders.",
        type=str,
        default="/home/ubuntu/ssd1/VW0",
    )
    parser.add_argument(
        "--config_file",
        help="Path to Experiment.ini file",
        type=str,
        default="/home/ubuntu/ssd1/Experiment.ini",
    )
    parser.add_argument(
        "--scanned_cells",
        help="Path to Scanned Cells.txt file",
        type=str,
        default="/home/ubuntu/ssd1/Scanned Cells.txt",
    )
    # BUGFIX: was `type=bool`, which parses any non-empty string (including
    # "False") as True; store_true gives correct flag semantics.
    parser.add_argument(
        "--stitch_only",
        help="If true, only run the stitching commands from existing xml_merging.xml file",
        action="store_true",
        default=False,
    )
    args = parser.parse_args()
    # BUGFIX: the raw bool was previously passed as do_steps, so False mapped
    # to STITCH_ONLY (0) and True to COMPUTE_ONLY (1) — the inverse of intent.
    # NOTE(review): generate_stitching_commands expects its 3rd/4th arguments
    # to be an S3 bucket and key; args.config_file / args.scanned_cells are
    # local paths and are only (mis)used if the metadata files are missing
    # from stack_dir — confirm intended CLI behavior.
    generate_stitching_commands(
        args.stitched_dir,
        args.stack_dir,
        args.config_file,
        args.scanned_cells,
        STITCH_ONLY if args.stitch_only else ALL_STEPS,
    )