try:
from urlparse import urlparse
except ImportError:
from urllib.parse import urlparse
import contextlib
import joblib
import SimpleITK as sitk
import math
import boto3
import numpy as np
import os
from tqdm import tqdm
import paramiko
from awscli.clidriver import create_clidriver
### image preprocessing
def get_bias_field(img, mask=None, scale=1.0, niters=[50, 50, 50, 50]):
    """Estimate the bias field of an image using the N4ITK algorithm (http://bit.ly/2oFwAun)

    Args:
        img (SimpleITK.Image): Input image with bias field.
        mask (SimpleITK.Image or ndarray, optional): If used, the bias field will only be estimated within the mask. (the default is None, which results in the whole image being used.)
        scale (float, optional): Scale at which to compute the bias correction. (the default is 1.0, which estimates the bias at full resolution; e.g. 0.25 estimates it on an image downsampled to 1/4 of its original size.)
        niters (list, optional): Number of iterations per resolution. Each additional entry in the list adds an additional resolution at which the bias is estimated. (the default is [50, 50, 50, 50] which results in 50 iterations per resolution at 4 resolutions)

    Returns:
        SimpleITK.Image: Estimated bias field (ratio of corrected to input intensities),
            with the same size and spacing as the input image.
    """
    # Shift intensities so the minimum becomes 1.0 — done in case the image
    # contains 0 (or negative) intensities, which N4 cannot handle.
    minmaxfilter = sitk.MinimumMaximumImageFilter()
    minmaxfilter.Execute(img)
    minval = minmaxfilter.GetMinimum()
    img_rescaled = sitk.Cast(img, sitk.sitkFloat32) - minval + 1.0

    # Downsample to the requested scale before estimating the bias.
    spacing = np.array(img_rescaled.GetSpacing()) / scale
    img_ds = imgResample(img_rescaled, spacing=spacing)
    img_ds = sitk.Cast(img_ds, sitk.sitkFloat32)

    # Calculate bias
    if mask is None:
        # no mask given: estimate the bias over the entire image
        mask = sitk.Image(img_ds.GetSize(), sitk.sitkUInt8) + 1
        mask.CopyInformation(img_ds)
    else:
        # accept a numpy array mask and convert it to a SimpleITK image
        if not isinstance(mask, sitk.Image):
            mask_sitk = sitk.GetImageFromArray(mask)
            mask_sitk.CopyInformation(img)
            mask = mask_sitk
        mask = imgResample(mask, spacing=spacing)

    img_ds_bc = sitk.N4BiasFieldCorrection(img_ds, mask, 0.001, niters)
    # bias = corrected / original, computed at the downsampled resolution
    bias_ds = img_ds_bc / sitk.Cast(img_ds, img_ds_bc.GetPixelID())

    # Upsample the bias field back onto the input image's grid
    bias = imgResample(bias_ds, spacing=img.GetSpacing(), size=img.GetSize())
    return bias
def imgResample(img, spacing, size=None, useNearest=False, origin=None, outsideValue=0):
    """Resample image to certain spacing and size.

    Args:
        img (SimpleITK.Image): Input 3D image.
        spacing (list): List of length 3 indicating the voxel spacing as [x, y, z]
        size (list, optional): List of length 3 indicating the number of voxels per dim [x, y, z] (the default is None, which will compute the appropriate size based on the spacing.)
        useNearest (bool, optional): If True use nearest neighbor interpolation. (the default is False, which will use linear interpolation.)
        origin (list, optional): The location in physical space representing the [0,0,0] voxel in the input image. (the default is [0,0,0])
        outsideValue (int, optional): value used to pad area outside image (the default is 0)

    Raises:
        Exception: If len(spacing) or len(size) does not match the image dimension.

    Returns:
        SimpleITK.Image: Resampled input image.
    """
    if origin is None:
        origin = [0] * 3
    if len(spacing) != img.GetDimension():
        raise Exception("len(spacing) != " + str(img.GetDimension()))

    # Set size: if not given (None or empty), choose the size that preserves
    # the image's physical extent at the new spacing.
    if not size:
        inSpacing = img.GetSpacing()
        inSize = img.GetSize()
        size = [
            int(math.ceil(inSize[i] * (inSpacing[i] / spacing[i])))
            for i in range(img.GetDimension())
        ]
    elif len(size) != img.GetDimension():
        raise Exception("len(size) != " + str(img.GetDimension()))

    # Resample input image
    interpolator = sitk.sitkNearestNeighbor if useNearest else sitk.sitkLinear
    identityTransform = sitk.Transform()
    return sitk.Resample(
        img,
        size,
        identityTransform,
        interpolator,
        origin,
        spacing,
        img.GetDirection(),
        outsideValue,
    )
[docs]def get_reorientations(in_orient, out_orient):
"""Generates a list of axes flips and swaps to convert from in_orient to out_orient
Args:
in_orient (str): 3-letter input orientation
out_orient (str): 3-letter output orientation
Raises:
Exception: Exception raised if in_orient or out_orient not valid
Returns:
tuple of lists: New axis order and whether or not each axis needs to be flipped
"""
dimension = len(in_orient)
in_orient = str(in_orient).lower()
out_orient = str(out_orient).lower()
inDirection = ""
outDirection = ""
orientToDirection = {"r": "r", "l": "r", "s": "s", "i": "s", "a": "a", "p": "a"}
for i in range(dimension):
try:
inDirection += orientToDirection[in_orient[i]]
except BaseException:
raise Exception("in_orient '{0}' is invalid.".format(in_orient))
try:
outDirection += orientToDirection[out_orient[i]]
except BaseException:
raise Exception("out_orient '{0}' is invalid.".format(out_orient))
if len(set(inDirection)) != dimension:
raise Exception("in_orient '{0}' is invalid.".format(in_orient))
if len(set(outDirection)) != dimension:
raise Exception("out_orient '{0}' is invalid.".format(out_orient))
order = []
flip = []
for i in range(dimension):
j = inDirection.find(outDirection[i])
order += [j]
flip += [-1 if in_orient[j] != out_orient[i] else 1]
return order, flip
### AWS stuff
# below code from https://stackoverflow.com/questions/42641315/s3-urls-get-bucket-name-and-path
class S3Url(object):
    """Split an s3:// URL into its bucket, key, and full-URL parts.

    >>> s = S3Url("s3://bucket/hello/world")
    >>> s.bucket
    'bucket'
    >>> s.key
    'hello/world'
    >>> s.url
    's3://bucket/hello/world'
    >>> s = S3Url("s3://bucket/hello/world?qwe1=3#ddd")
    >>> s.key
    'hello/world?qwe1=3#ddd'
    >>> s = S3Url("s3://bucket/hello/world#foo?bar=2")
    >>> s.key
    'hello/world#foo?bar=2'
    """

    def __init__(self, url):
        # allow_fragments=False keeps any '#...' inside the path so that
        # S3 keys containing '#' round-trip intact
        self._parsed = urlparse(url, allow_fragments=False)

    @property
    def bucket(self):
        # the netloc of an s3:// URL is the bucket name
        return self._parsed.netloc

    @property
    def key(self):
        key = self._parsed.path.lstrip("/")
        # re-attach a query string: '?' is legal inside an S3 key
        if self._parsed.query:
            key = key + "?" + self._parsed.query
        return key

    @property
    def url(self):
        return self._parsed.geturl()
def upload_file_to_s3(local_path, s3_bucket, s3_key):
    """Upload a file from local storage to S3.

    Args:
        local_path (str): Local path to file
        s3_bucket (str): S3 bucket name
        s3_key (str): S3 key to store file at
    """
    # upload through the resource's underlying client
    boto3.resource("s3").meta.client.upload_file(local_path, s3_bucket, s3_key)
# below code from https://github.com/boto/boto3/issues/358#issuecomment-372086466
def aws_cli(*cmd):
    """Run an AWS CLI command in the current process.

    Raises:
        RuntimeError: Error running aws cli command.
    """
    # snapshot the environment so it can be restored afterwards
    saved_environ = dict(os.environ)
    try:
        # force a UTF locale while the CLI runs
        os.environ["LC_CTYPE"] = "en_US.UTF"
        # Run awscli in the same process
        rc = create_clidriver().main(*cmd)
        # Deal with problems
        if rc > 0:
            raise RuntimeError("AWS CLI exited with code {}".format(rc))
    finally:
        # restore the original environment even on failure
        os.environ.clear()
        os.environ.update(saved_environ)
def start_ec2_instance(instance_id, instance_type):
    """Start an EC2 instance of the given type and return its public IP.

    Stops the instance first (if running), sets its instance type, then
    starts it and blocks until AWS reports the instance status as OK.

    Args:
        instance_id (str): ID of EC2 instance to start
        instance_type (str): Type of EC2 instance to start

    Returns:
        str: Public IP address of EC2 instance
    """
    # get ec2 client
    ec2 = boto3.resource("ec2")
    # stop instance in case it is running; the instance type can only be
    # modified while the instance is stopped, so wait for the stopped state
    ec2.meta.client.stop_instances(InstanceIds=[instance_id])
    waiter = ec2.meta.client.get_waiter("instance_stopped")
    waiter.wait(InstanceIds=[instance_id])
    # make sure instance is the right type
    ec2.meta.client.modify_instance_attribute(
        InstanceId=instance_id, Attribute="instanceType", Value=instance_type
    )
    # start instance
    ec2.meta.client.start_instances(InstanceIds=[instance_id])
    # wait until instance is started up and passes its status checks
    waiter = ec2.meta.client.get_waiter("instance_status_ok")
    waiter.wait(InstanceIds=[instance_id])
    # get instance ip address
    instance = ec2.Instance(instance_id)
    return instance.public_ip_address
# code from https://alexwlchan.net/2019/07/listing-s3-keys/
def get_matching_s3_keys(bucket, prefix="", suffix=""):
    """
    Generate the keys in an S3 bucket.

    Args:
        bucket (str): Name of the S3 bucket.
        prefix (str): Only fetch keys that start with this prefix (optional).
        suffix (str): Only fetch keys that end with this suffix (optional).

    Yields:
        str: S3 keys if they exist with given prefix and suffix
    """
    s3 = boto3.client("s3")
    kwargs = {"Bucket": bucket, "Prefix": prefix}
    while True:
        resp = s3.list_objects_v2(**kwargs)
        # "Contents" is absent when no keys match the prefix on this page
        if "Contents" not in resp:
            return
        for obj in resp["Contents"]:
            key = obj["Key"]
            if key.endswith(suffix):
                yield key
        # keep paginating until S3 stops returning a continuation token
        token = resp.get("NextContinuationToken")
        if token is None:
            break
        kwargs["ContinuationToken"] = token
### stitching
def download_terastitcher_files(s3_path, local_path):
    """Download terastitcher files from S3

    Args:
        s3_path (str): S3 path where Terastitcher files might live
        local_path (str): Local path to save Terastitcher files

    Returns:
        bool: True if files exist at s3 path, else False
    """
    # the full set of xml files Terastitcher produces for a stitched run
    expected_files = [
        "xml_import.xml",
        "xml_displcompute.xml",
        "xml_dislproj.xml",
        "xml_displthres.xml",
        "xml_merging.xml",
    ]
    s3 = boto3.resource("s3")
    s3_url = S3Url(s3_path)
    candidates = get_matching_s3_keys(s3_url.bucket, prefix=s3_url.key, suffix="xml")
    xml_paths = [key for key in candidates if key.split("/")[-1] in expected_files]
    if len(xml_paths) < len(expected_files):
        # not every expected xml file was present at s3_path
        return False
    # download xml results to local_path
    for key in tqdm(xml_paths, desc="downloading xml files from S3"):
        fname = key.split("/")[-1]
        s3.meta.client.download_file(s3_url.bucket, key, f"{local_path}/{fname}")
    return True
### create precomputed volume
def calc_hierarchy_levels(img_size, lowest_res=1024):
    """Compute max number of mips for given chunk size

    Args:
        img_size (list): Size of image in x,y,z
        lowest_res (int, optional): minimum chunk size in XY. Defaults to 1024.

    Returns:
        int: Number of mips
    """
    # consider both X and Y when taking the largest in-plane extent
    # (the original slice [0:1] only looked at X)
    max_xy = max(img_size[0:2])
    # we add one because 0 is included in the number of downsampling levels
    num_levels = max(1, math.ceil(math.log(max_xy / lowest_res, 2)) + 1)
    return num_levels
### misc
@contextlib.contextmanager
def tqdm_joblib(tqdm_object):
    """Context manager to patch joblib to report into tqdm progress bar given as argument

    Temporarily replaces ``joblib.parallel.BatchCompletionCallBack`` so that
    each completed batch advances ``tqdm_object`` by one tick. On exit the
    original callback class is restored and the progress bar is closed.
    """
    class TqdmBatchCompletionCallback:
        # Matches the (time, index, parallel) constructor signature of the
        # callback it replaces; ``time`` is accepted but unused.
        def __init__(self, time, index, parallel):
            self.index = index
            self.parallel = parallel
        def __call__(self, index):
            # one tick per completed batch
            tqdm_object.update()
            # NOTE(review): this touches joblib private state
            # (_original_iterator / dispatch_next) and re-implements only the
            # dispatch step of joblib's own callback — confirm it matches the
            # joblib version in use.
            if self.parallel._original_iterator is not None:
                self.parallel.dispatch_next()
    old_batch_callback = joblib.parallel.BatchCompletionCallBack
    joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
    try:
        yield tqdm_object
    finally:
        # always restore joblib's callback, even if the with-body raises
        joblib.parallel.BatchCompletionCallBack = old_batch_callback
        tqdm_object.close()
def chunks(l, n):
    """Split a list into successive n-element chunks.

    Args:
        l (list): List to chunk
        n (int): Elements per chunk

    Yields:
        list: n-size chunk from l (last chunk may have fewer than n elements)
    """
    start = 0
    while start < len(l):
        yield l[start : start + n]
        start += n
def run_command_on_server(command, ssh_key_path, ip_address, username="ubuntu"):
    """Run command on remote server

    Args:
        command (str): Command to run
        ssh_key_path (str): Local path to ssh key needed for this server
        ip_address (str): IP Address of server to connect to
        username (str, optional): Username on remote server. Defaults to "ubuntu".

    Returns:
        str: Errors encountered on remote server if any; None if the
            connection or command raised an exception (the error is printed).
    """
    key = paramiko.RSAKey.from_private_key_file(ssh_key_path)
    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    # Connect/ssh to an instance
    try:
        # Here 'ip_address' is public IP of EC2
        client.connect(hostname=ip_address, username=username, pkey=key)
        # Execute a command after connecting/ssh to an instance
        _, stdout, stderr = client.exec_command(command, get_pty=True)
        # stream remote stdout line-by-line as it is produced
        for line in iter(stdout.readline, ""):
            print(line, end="")
        errors = stderr.read().decode("utf-8")
        return errors
    except Exception as e:
        print(e)
    finally:
        # close the connection even when connect/exec raises (the original
        # leaked the client on the exception path)
        client.close()