Source code for ndreg
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import tempfile
import shutil
from matplotlib.ticker import ScalarFormatter
import util, registerer, preprocessor, plotter
def register_brain(atlas, img, modality, outdir=None):
    """Register an atlas to a 3D mouse brain volume using affine and then deformable (LDDMM) registration.

    Parameters
    ----------
    atlas : SimpleITK.SimpleITK.Image
        Allen reference atlas or other atlas to register data to.
    img : SimpleITK.SimpleITK.Image
        Input observed 3D mouse brain volume.
    modality : str
        Can be 'lavision' or 'colm' for either modality. Accepted for API
        compatibility; this function does not currently read it.
    outdir : str, optional
        Path to output directory to store intermediates (the default is None,
        which will store all outputs in './'). A trailing path separator is
        not required.

    Returns
    -------
    SimpleITK.SimpleITK.Image
        The atlas deformed to fit the input image.
    """
    # Local import: the module's import block does not provide ``os``.
    import os

    if outdir is None:
        outdir = './'

    # Stage 1: coarse-to-fine affine alignment of the normalized atlas to img.
    final_transform = register_affine(sitk.Normalize(atlas),
                                      img,
                                      learning_rate=1e-1,
                                      grad_tol=4e-6,
                                      use_mi=False,
                                      iters=50,
                                      shrink_factors=[4, 2, 1],
                                      sigmas=[0.4, 0.2, 0.1],
                                      verbose=False)

    # Save the affine transformation to outdir; make the dir if it doesn't exist.
    # os.path.join keeps the paths correct whether or not outdir ends with a
    # separator (plain concatenation glued the file name onto the dir name).
    util.dir_make(outdir)
    sitk.WriteTransform(final_transform,
                        os.path.join(outdir, 'atlas_to_observed_affine.txt'))

    # Bring the atlas onto img's grid; pixels with no source data are filled
    # with the value util.img_percentile reports for 0.01 on the atlas.
    atlas_affine = registerer.resample(atlas, final_transform, img,
                                       default_value=util.img_percentile(atlas, 0.01))

    # Whiten both images only before lddmm.
    atlas_affine_w = sitk.AdaptiveHistogramEqualization(atlas_affine, [10, 10, 10], alpha=0.25, beta=0.25)
    img_w = sitk.AdaptiveHistogramEqualization(img, [10, 10, 10], alpha=0.25, beta=0.25)

    # Stage 2: multi-scale LDDMM on the normalized, whitened images.
    e = 5e-3
    s = 0.1
    atlas_lddmm, _, _ = register_lddmm(sitk.Normalize(atlas_affine_w),
                                       sitk.Normalize(img_w),
                                       alpha_list=[0.05],
                                       scale_list=[0.0625, 0.125, 0.25, 0.5, 1.0],
                                       epsilon_list=e, sigma=s,
                                       min_epsilon_list=e*1e-6,
                                       use_mi=False, iterations=50, verbose=True,
                                       out_dir=os.path.join(outdir, 'lddmm'))
    return atlas_lddmm
def register_affine(atlas, img, learning_rate=1e-2, iters=200, min_step=1e-10, shrink_factors=[1], sigmas=[.150], use_mi=False, grad_tol=1e-6, verbose=False):
    """
    Run a multi-resolution affine registration between an atlas and an image.

    ``img`` is handed to SimpleITK as the first (fixed) image and ``atlas`` as
    the second (moving) image; both are cast to 32-bit float first. The two
    images are expected to have the same spacing.

    Parameters
    ----------
    atlas, img : SimpleITK.SimpleITK.Image
        Images to align.
    learning_rate, min_step, grad_tol, iters
        Forwarded to the regular-step gradient descent optimizer as
        learningRate, minStep, gradientMagnitudeTolerance and
        numberOfIterations.
    shrink_factors, sigmas : list
        Per-level shrink factors and smoothing sigmas (sigmas are in physical
        units).
    use_mi : bool
        Use Mattes mutual information (128 bins) instead of mean squares.
    verbose : bool
        Attach the plotting callbacks from ``util`` to the registration.

    Returns
    -------
    The transform returned by ``ImageRegistrationMethod.Execute``.
    """
    reg = sitk.ImageRegistrationMethod()

    # Similarity metric and interpolation.
    if not use_mi:
        reg.SetMetricAsMeanSquares()
    else:
        reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=128)
    reg.SetInterpolator(sitk.sitkBSpline)

    # Optimizer.
    optimizer_settings = dict(learningRate=learning_rate,
                              minStep=min_step,
                              gradientMagnitudeTolerance=grad_tol,
                              numberOfIterations=iters)
    reg.SetOptimizerAsRegularStepGradientDescent(**optimizer_settings)
    reg.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution pyramid.
    reg.SetShrinkFactorsPerLevel(shrinkFactors=shrink_factors)
    reg.SetSmoothingSigmasPerLevel(smoothingSigmas=sigmas)
    reg.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Start from an identity affine centred at half of size * spacing of the
    # atlas (the image origin is not taken into account here).
    affine = sitk.AffineTransform(atlas.GetDimension())
    extent = np.array(atlas.GetSize()) * np.array(atlas.GetSpacing())
    affine.SetCenter(extent / 2.0)
    reg.SetInitialTransform(affine)

    # Optional observers used for plotting while the registration runs.
    if verbose:
        observers = [
            (sitk.sitkStartEvent, util.start_plot),
            (sitk.sitkEndEvent, util.end_plot),
            (sitk.sitkMultiResolutionIterationEvent, util.update_multires_iterations),
            (sitk.sitkIterationEvent, lambda: util.plot_values(reg)),
        ]
        for event, callback in observers:
            reg.AddCommand(event, callback)

    fixed = sitk.Cast(img, sitk.sitkFloat32)
    moving = sitk.Cast(atlas, sitk.sitkFloat32)
    return reg.Execute(fixed, moving)
def register_lddmm(affine_img, target_img, alpha_list=0.05, scale_list=[0.0625, 0.125, 0.25, 0.5, 1.0],
                   epsilon_list=1e-4, min_epsilon_list=1e-10, sigma=0.1, use_mi=False, iterations=200, inMask=None,
                   refMask=None, verbose=True, out_dir=''):
    """
    Deformably register ``affine_img`` to ``target_img`` with multi-step LDDMM.

    Thin wrapper around ``imgMetamorphosisComposite`` that also applies the
    resulting displacement field to ``affine_img``.

    Parameters
    ----------
    affine_img : SimpleITK.SimpleITK.Image
        Image to deform (passed as the input image).
    target_img : SimpleITK.SimpleITK.Image
        Image to match (passed as the reference image).
    alpha_list, scale_list, epsilon_list, min_epsilon_list
        Forwarded as alphaList, scaleList, epsilonList and minEpsilonList.
    sigma : float or None
        Forwarded as sigma. If None, 0.1 divided by the number of pixels in
        ``target_img`` is used.
    use_mi, iterations, inMask, refMask, verbose
        Forwarded unchanged (use_mi as useMI).
    out_dir : str
        Forwarded as outDirPath.

    Returns
    -------
    tuple
        ``(source_lddmm, field, invField)``: the deformed ``affine_img``
        sampled with the size and spacing of ``target_img``, and the two
        fields returned by ``imgMetamorphosisComposite``.
    """
    # Identity check: ``== None`` would dispatch to a custom ``__eq__``.
    if sigma is None:
        sigma = (0.1/target_img.GetNumberOfPixels())

    (field, invField) = imgMetamorphosisComposite(affine_img, target_img,
                                                  alphaList=alpha_list,
                                                  scaleList=scale_list,
                                                  epsilonList=epsilon_list,
                                                  minEpsilonList=min_epsilon_list,
                                                  sigma=sigma,
                                                  useMI=use_mi,
                                                  inMask=inMask,
                                                  refMask=refMask,
                                                  iterations=iterations,
                                                  verbose=verbose,
                                                  outDirPath=out_dir)

    source_lddmm = registerer.imgApplyField(affine_img, field,
                                            size=target_img.GetSize(),
                                            spacing=target_img.GetSpacing())
    return source_lddmm, field, invField
def register_rigid(atlas, img, learning_rate=1e-2, iters=200, min_step=1e-10, shrink_factors=[1], sigmas=[.150], use_mi=False, grad_tol=1e-6, verbose=False):
    """
    Run a multi-resolution rigid registration between an atlas and an image.

    The optimized transform is a ``sitk.VersorRigid3DTransform``. ``img`` is
    handed to SimpleITK as the first (fixed) image and ``atlas`` as the second
    (moving) image; both are cast to 32-bit float first. The two images are
    expected to have the same spacing.

    Parameters
    ----------
    atlas, img : SimpleITK.SimpleITK.Image
        Images to align.
    learning_rate, min_step, grad_tol, iters
        Forwarded to the regular-step gradient descent optimizer as
        learningRate, minStep, gradientMagnitudeTolerance and
        numberOfIterations.
    shrink_factors, sigmas : list
        Per-level shrink factors and smoothing sigmas (sigmas are in physical
        units).
    use_mi : bool
        Use Mattes mutual information (128 bins) instead of mean squares.
    verbose : bool
        Attach the plotting callbacks from ``util`` to the registration.

    Returns
    -------
    The transform returned by ``ImageRegistrationMethod.Execute``.
    """
    method = sitk.ImageRegistrationMethod()

    # Similarity metric and interpolation.
    if use_mi:
        method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=128)
    else:
        method.SetMetricAsMeanSquares()
    method.SetInterpolator(sitk.sitkBSpline)

    # Optimizer.
    method.SetOptimizerAsRegularStepGradientDescent(
        learningRate=learning_rate,
        minStep=min_step,
        gradientMagnitudeTolerance=grad_tol,
        numberOfIterations=iters,
    )
    method.SetOptimizerScalesFromPhysicalShift()

    # Multi-resolution pyramid.
    method.SetShrinkFactorsPerLevel(shrinkFactors=shrink_factors)
    method.SetSmoothingSigmasPerLevel(smoothingSigmas=sigmas)
    method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Start from an identity rigid transform centred at half of
    # size * spacing of the atlas (the image origin is not taken into account).
    rigid = sitk.VersorRigid3DTransform()
    half_extent = np.array(atlas.GetSize()) * np.array(atlas.GetSpacing()) / 2.0
    rigid.SetCenter(half_extent)
    method.SetInitialTransform(rigid)

    # Optional observers used for plotting while the registration runs.
    if verbose:
        method.AddCommand(sitk.sitkStartEvent, util.start_plot)
        method.AddCommand(sitk.sitkEndEvent, util.end_plot)
        method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                          util.update_multires_iterations)
        method.AddCommand(sitk.sitkIterationEvent,
                          lambda: util.plot_values(method))

    fixed_image = sitk.Cast(img, sitk.sitkFloat32)
    moving_image = sitk.Cast(atlas, sitk.sitkFloat32)
    return method.Execute(fixed_image, moving_image)
def imgMetamorphosis(inImg, refImg, alpha=0.02, beta=0.05, scale=1.0, iterations=1000, epsilon=None, minEpsilon=None, sigma=1e-4, useNearest=False,
                     useBias=False, useMI=False, verbose=False, debug=False, inMask=None, refMask=None, outDirPath=""):
    """
    Performs Metamorphic LDDMM between input and reference images.

    The images are written to a working directory, the external
    ``metamorphosis`` executable (located via ``util.ndregDirPath``) is run on
    them, and the displacement fields it produces are read back.

    Parameters
    ----------
    inImg, refImg
        Input and reference images; written to ``in.img`` / ``ref.img``.
    alpha, beta, scale, iterations
        Passed as ``--alpha``, ``--beta``, ``--scale`` and ``--iterations``.
    epsilon
        Passed as ``--epsilon``; ``1e-3`` is used when None.
    minEpsilon
        Passed as ``--epsilonmin`` when not None.
    sigma
        Passed as ``--sigma``.
    useNearest
        Not used by this function.
    useBias : bool
        When False, ``--mu 0`` is added to the command.
    useMI : bool
        When True, ``--cost 1`` is added to the command.
    verbose : bool
        Forwarded to ``util.run_shell_command``.
    debug : bool
        Prefix the command with ``/usr/bin/time -v`` and print it.
    inMask, refMask
        Optional masks; written to disk and passed as ``--inmask`` /
        ``--refmask`` when truthy.
    outDirPath : str
        Working directory. When empty, a temporary directory is created and
        removed again before returning.

    Returns
    -------
    tuple
        ``(field, invField)`` read from ``field.vtk`` and ``invField.vtk``.
    """
    useTempDir = (outDirPath == "")
    if useTempDir:
        outDirPath = tempfile.mkdtemp() + "/"
    else:
        # presumably dir_make returns the path with a trailing separator,
        # since file names are appended to it directly below -- TODO confirm
        outDirPath = util.dir_make(outDirPath)

    try:
        inPath = outDirPath + "in.img"
        util.imgWrite(inImg, inPath)
        refPath = outDirPath + "ref.img"
        util.imgWrite(refImg, refPath)
        outPath = outDirPath + "out.img"

        fieldPath = outDirPath + "field.vtk"
        invFieldPath = outDirPath + "invField.vtk"

        binPath = util.ndregDirPath + "metamorphosis "
        steps = 5
        command = binPath + " --in {0} --ref {1} --out {2} --alpha {3} --beta {4} --field {5} --invfield {6} --iterations {7} --scale {8} --steps {9} --verbose ".format(
            inPath, refPath, outPath, alpha, beta, fieldPath, invFieldPath, iterations, scale, steps)

        if not useBias:
            command += " --mu 0"

        # Only the cost flag depends on useMI; sigma and epsilon are always
        # passed.
        if useMI:
            command += " --cost 1"
        command += " --sigma {}".format(sigma)
        if epsilon is not None:
            command += " --epsilon {0}".format(epsilon)
        else:
            command += " --epsilon 1e-3"

        if minEpsilon is not None:
            command += " --epsilonmin {0}".format(minEpsilon)

        if inMask:
            inMaskPath = outDirPath + "inMask.img"
            util.imgWrite(inMask, inMaskPath)
            command += " --inmask " + inMaskPath

        if refMask:
            refMaskPath = outDirPath + "refMask.img"
            util.imgWrite(refMask, refMaskPath)
            command += " --refmask " + refMaskPath

        if debug:
            command = "/usr/bin/time -v " + command
            # NOTE(review): the listing this block was recovered from had lost
            # its indentation; the print is assumed to belong to the debug
            # branch -- confirm.
            print(command)

        # NOTE(review): the first element returned by run_shell_command is
        # discarded, so a failed run only surfaces when the fields are read.
        (_, logText) = util.run_shell_command(command, verbose=verbose)

        logPath = outDirPath + "log.txt"
        util.txt_write(logText, logPath)

        field = util.imgRead(fieldPath)
        invField = util.imgRead(invFieldPath)
    finally:
        # Remove the temporary directory even when a step above raised, so
        # failed runs do not leave it behind.
        if useTempDir:
            shutil.rmtree(outDirPath)
    return (field, invField)
def imgMetamorphosisComposite(inImg, refImg, alphaList=0.02, betaList=0.05, scaleList=1.0, iterations=1000, epsilonList=None, minEpsilonList=None, sigma=1e-4,
                              useNearest=False, useBias=False, useMI=False, inMask=None, refMask=None, verbose=True, debug=False, outDirPath=""):
    """
    Performs Metamorphic LDDMM between input and reference images in several steps.

    ``imgMetamorphosis`` is run once per step and the per-step fields are
    composed. After each step the original input image (and mask) is
    re-sampled through the composite field and used as input to the next step.

    Parameters
    ----------
    inImg, refImg
        Input and reference images.
    alphaList, betaList, scaleList : number or sequence
        Per-step parameters. Each may be a single number or a sequence; a
        sequence of length 1 is repeated to the length of the longest of the
        three. The caller's sequences are not modified.
    epsilonList, minEpsilonList : number, sequence or None
        Per-step values; a number or None is repeated for every step, a
        sequence is indexed by step.
    iterations, sigma, useNearest, useBias, useMI, verbose, debug
        Forwarded to ``imgMetamorphosis`` for every step.
    inMask, refMask
        Optional masks; the input mask is re-sampled after each step.
    outDirPath : str
        Output directory; each step uses the sub-directory ``step<N>/``. When
        empty, a temporary directory is created and removed before returning.

    Returns
    -------
    tuple
        ``(compositeField, compositeInvField)``.

    Raises
    ------
    Exception
        If alphaList, betaList and scaleList have incompatible lengths.
    """
    useTempDir = (outDirPath == "")
    if useTempDir:
        outDirPath = tempfile.mkdtemp() + "/"
    else:
        outDirPath = util.dir_make(outDirPath)

    try:
        # Normalize the per-step parameters to private lists. Copying matters:
        # the length-1 expansion below must not grow the caller's own list.
        if util.is_number(alphaList):
            alphaList = [float(alphaList)]
        else:
            alphaList = list(alphaList)
        if util.is_number(betaList):
            betaList = [float(betaList)]
        else:
            betaList = list(betaList)
        if util.is_number(scaleList):
            scaleList = [float(scaleList)]
        else:
            scaleList = list(scaleList)

        numSteps = max(len(alphaList), len(betaList), len(scaleList))

        if util.is_number(epsilonList):
            epsilonList = [float(epsilonList)] * numSteps
        elif epsilonList is None:
            epsilonList = [None] * numSteps

        if util.is_number(minEpsilonList):
            minEpsilonList = [float(minEpsilonList)] * numSteps
        elif minEpsilonList is None:
            minEpsilonList = [None] * numSteps

        if len(alphaList) != numSteps:
            if len(alphaList) != 1:
                raise Exception(
                    "Legth of alphaList must be 1 or same length as betaList or scaleList")
            alphaList = alphaList * numSteps

        if len(betaList) != numSteps:
            if len(betaList) != 1:
                raise Exception(
                    "Legth of betaList must be 1 or same length as alphaList or scaleList")
            betaList = betaList * numSteps

        if len(scaleList) != numSteps:
            if len(scaleList) != 1:
                raise Exception(
                    "Legth of scaleList must be 1 or same length as alphaList or betaList")
            scaleList = scaleList * numSteps

        origInImg = inImg
        origInMask = inMask
        for step in range(numSteps):
            alpha = alphaList[step]
            beta = betaList[step]
            scale = scaleList[step]
            epsilon = epsilonList[step]
            minEpsilon = minEpsilonList[step]
            stepDirPath = outDirPath + "step" + str(step) + "/"
            if verbose:
                print("\nStep {0}: alpha={1}, beta={2}, scale={3}".format(
                    step, alpha, beta, scale))

            (field, invField) = imgMetamorphosis(inImg, refImg,
                                                 alpha=alpha,
                                                 beta=beta,
                                                 scale=scale,
                                                 iterations=iterations,
                                                 epsilon=epsilon,
                                                 minEpsilon=minEpsilon,
                                                 sigma=sigma,
                                                 useNearest=useNearest,
                                                 useBias=useBias,
                                                 useMI=useMI,
                                                 verbose=verbose,
                                                 debug=debug,
                                                 inMask=inMask,
                                                 refMask=refMask,
                                                 outDirPath=stepDirPath)

            if step == 0:
                compositeField = field
                compositeInvField = invField
            else:
                # The field helpers are reached through the imported
                # ``registerer`` module; the bare names are not defined here.
                # NOTE(review): assumes fieldApplyField lives in registerer
                # next to imgApplyField -- confirm.
                compositeField = registerer.fieldApplyField(field, compositeField)
                # force invField to be same size as field
                compositeInvField = registerer.fieldApplyField(compositeInvField, invField,
                                                               size=field.GetSize(),
                                                               spacing=field.GetSpacing())

                # Replace this step's fields on disk with the composite ones.
                # NOTE(review): nesting recovered from an unindented listing;
                # assumed to belong to this branch -- confirm.
                if outDirPath != "":
                    fieldPath = stepDirPath + "field.vtk"
                    invFieldPath = stepDirPath + "invField.vtk"
                    util.imgWrite(compositeInvField, invFieldPath)
                    util.imgWrite(compositeField, fieldPath)

            # Always deform the ORIGINAL input so interpolation error does not
            # accumulate from step to step.
            inImg = registerer.imgApplyField(origInImg, compositeField,
                                             size=refImg.GetSize())
            if inMask:
                inMask = registerer.imgApplyField(origInMask, compositeField,
                                                  size=refImg.GetSize(),
                                                  useNearest=True)

        # Write final results
        if outDirPath != "":
            util.imgWrite(compositeField, outDirPath + "field.vtk")
            util.imgWrite(compositeInvField, outDirPath + "invField.vtk")
            util.imgWrite(inImg, outDirPath + "out.img")
            util.imgWrite(plotter.imgChecker(inImg, refImg), outDirPath + "checker.img")
    finally:
        # Remove the temporary directory even when a step above raised.
        if useTempDir:
            shutil.rmtree(outDirPath)
    return (compositeField, compositeInvField)