Source code for ndreg

#!/usr/bin/env python
# -*- coding: utf-8 -*-
from __future__ import print_function
import numpy as np
import SimpleITK as sitk
import matplotlib.pyplot as plt
import tempfile
import shutil
from matplotlib.ticker import ScalarFormatter
import util, registerer, preprocessor, plotter

def register_brain(atlas, img, modality, outdir=None):
    """Register a 3D mouse brain to an atlas using affine then deformable (LDDMM) registration.

    Parameters
    ----------
    atlas : SimpleITK.Image
        Allen reference atlas or other atlas to register to ``img``.
    img : SimpleITK.Image
        Input observed 3D mouse brain volume.
    modality : str
        Can be 'lavision' or 'colm'.
        NOTE(review): not used by this function; kept for interface
        compatibility.
    outdir : str, optional
        Directory in which intermediate outputs are stored; it is created if
        it does not exist. Defaults to './'. A trailing path separator is
        optional.

    Returns
    -------
    SimpleITK.Image
        The atlas deformed to fit the input image.
    """
    import os

    if outdir is None:
        outdir = './'

    # Affine stage (mean-squares metric, 3-level pyramid). Only the atlas is
    # intensity-normalized here.
    final_transform = register_affine(sitk.Normalize(atlas),
                                      img,
                                      learning_rate=1e-1,
                                      grad_tol=4e-6,
                                      use_mi=False,
                                      iters=50,
                                      shrink_factors=[4, 2, 1],
                                      sigmas=[0.4, 0.2, 0.1],
                                      verbose=False)

    # Save the affine transformation to outdir, creating the dir if needed.
    # os.path.join is used instead of string concatenation so that an outdir
    # without a trailing separator still yields a path inside that directory.
    util.dir_make(outdir)
    sitk.WriteTransform(final_transform,
                        os.path.join(outdir, 'atlas_to_observed_affine.txt'))

    # Resample the atlas onto the observed image's grid; out-of-field voxels
    # get the atlas's 0.01 intensity percentile.
    atlas_affine = registerer.resample(
        atlas, final_transform, img,
        default_value=util.img_percentile(atlas, 0.01))

    # Whiten both images only before LDDMM.
    atlas_affine_w = sitk.AdaptiveHistogramEqualization(
        atlas_affine, [10, 10, 10], alpha=0.25, beta=0.25)
    img_w = sitk.AdaptiveHistogramEqualization(
        img, [10, 10, 10], alpha=0.25, beta=0.25)

    # Deformable stage; only the deformed atlas is returned.
    e = 5e-3
    s = 0.1
    atlas_lddmm, _, _ = register_lddmm(sitk.Normalize(atlas_affine_w),
                                       sitk.Normalize(img_w),
                                       alpha_list=[0.05],
                                       scale_list=[0.0625, 0.125, 0.25, 0.5, 1.0],
                                       epsilon_list=e, sigma=s,
                                       min_epsilon_list=e * 1e-6,
                                       use_mi=False, iterations=50,
                                       verbose=True,
                                       out_dir=os.path.join(outdir, 'lddmm'))
    return atlas_lddmm
def register_affine(atlas, img, learning_rate=1e-2, iters=200, min_step=1e-10,
                    shrink_factors=None, sigmas=None, use_mi=False,
                    grad_tol=1e-6, verbose=False):
    """Perform affine registration between an atlas and an image.

    ``img`` is used as the fixed image and ``atlas`` as the moving image.
    The images are expected to have the same spacing.

    Parameters
    ----------
    atlas : SimpleITK.Image
        Moving image. Its geometric center (in physical space) is used as
        the center of the affine transform.
    img : SimpleITK.Image
        Fixed image.
    learning_rate : float
        Learning rate of the regular-step gradient-descent optimizer.
    iters : int
        Number of optimizer iterations.
    min_step : float
        Minimum step length of the optimizer.
    shrink_factors : sequence of int, optional
        Per-level shrink factors of the multi-resolution pyramid
        (default ``[1]``).
    sigmas : sequence of float, optional
        Per-level smoothing sigmas, in physical units (default ``[0.15]``).
    use_mi : bool
        Use Mattes mutual information (128 bins) instead of mean squares.
    grad_tol : float
        Gradient-magnitude tolerance of the optimizer.
    verbose : bool
        Attach the ``util`` plotting observers to the registration.

    Returns
    -------
    SimpleITK.Transform
        The optimized transform returned by
        ``ImageRegistrationMethod.Execute(img, atlas)``.
    """
    # None sentinels instead of shared mutable default lists.
    shrink_factors = [1] if shrink_factors is None else list(shrink_factors)
    sigmas = [.150] if sigmas is None else list(sigmas)

    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings.
    if use_mi:
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=128)
    else:
        registration_method.SetMetricAsMeanSquares()
    registration_method.SetInterpolator(sitk.sitkBSpline)

    # Optimizer settings.
    registration_method.SetOptimizerAsRegularStepGradientDescent(
        learningRate=learning_rate,
        minStep=min_step,
        gradientMagnitudeTolerance=grad_tol,
        numberOfIterations=iters)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=shrink_factors)
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=sigmas)
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Initial transform centered on the atlas's geometric center. Mapping the
    # central continuous index to physical space honours the image origin and
    # direction; the previous size*spacing/2 was only right for a zero origin
    # with identity direction.
    initial_transform = sitk.AffineTransform(atlas.GetDimension())
    center_index = [(n - 1) / 2.0 for n in atlas.GetSize()]
    initial_transform.SetCenter(
        atlas.TransformContinuousIndexToPhysicalPoint(center_index))
    registration_method.SetInitialTransform(initial_transform)

    # Connect the observers so that we can plot during registration.
    if verbose:
        registration_method.AddCommand(sitk.sitkStartEvent, util.start_plot)
        registration_method.AddCommand(sitk.sitkEndEvent, util.end_plot)
        registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                                       util.update_multires_iterations)
        registration_method.AddCommand(sitk.sitkIterationEvent,
                                       lambda: util.plot_values(registration_method))

    final_transform = registration_method.Execute(sitk.Cast(img, sitk.sitkFloat32),
                                                  sitk.Cast(atlas, sitk.sitkFloat32))
    return final_transform
def register_lddmm(affine_img, target_img, alpha_list=0.05, scale_list=None,
                   epsilon_list=1e-4, min_epsilon_list=1e-10, sigma=0.1,
                   use_mi=False, iterations=200, inMask=None, refMask=None,
                   verbose=True, out_dir=''):
    """Deformably register ``affine_img`` to ``target_img`` with multi-step LDDMM.

    Runs ``imgMetamorphosisComposite`` with the given per-step settings, then
    warps ``affine_img`` with the resulting field onto the grid (size and
    spacing) of ``target_img``.

    Parameters
    ----------
    affine_img : SimpleITK.Image
        Image to deform (typically already affinely aligned).
    target_img : SimpleITK.Image
        Reference image.
    alpha_list, epsilon_list, min_epsilon_list : float or list
        Per-step settings forwarded to ``imgMetamorphosisComposite``.
    scale_list : float or list, optional
        Per-step scales; defaults to ``[0.0625, 0.125, 0.25, 0.5, 1.0]``.
    sigma : float or None
        Forwarded sigma; when None, ``0.1 / number of pixels of target_img``.
    use_mi : bool
        Forwarded as ``useMI``.
    iterations : int
        Iterations per step.
    inMask, refMask : SimpleITK.Image, optional
        Optional masks forwarded to the metamorphosis runs.
    verbose : bool
        Forwarded verbosity flag.
    out_dir : str
        Output directory; ``''`` means a temporary directory.

    Returns
    -------
    tuple
        ``(deformed_image, field, inverse_field)``.
    """
    # None sentinel instead of a shared mutable default list.
    if scale_list is None:
        scale_list = [0.0625, 0.125, 0.25, 0.5, 1.0]
    if sigma is None:
        sigma = 0.1 / target_img.GetNumberOfPixels()

    field, inv_field = imgMetamorphosisComposite(
        affine_img, target_img,
        alphaList=alpha_list,
        scaleList=scale_list,
        epsilonList=epsilon_list,
        minEpsilonList=min_epsilon_list,
        sigma=sigma,
        useMI=use_mi,
        inMask=inMask,
        refMask=refMask,
        iterations=iterations,
        verbose=verbose,
        outDirPath=out_dir)

    source_lddmm = registerer.imgApplyField(affine_img, field,
                                            size=target_img.GetSize(),
                                            spacing=target_img.GetSpacing())
    return source_lddmm, field, inv_field
def register_rigid(atlas, img, learning_rate=1e-2, iters=200, min_step=1e-10,
                   shrink_factors=None, sigmas=None, use_mi=False,
                   grad_tol=1e-6, verbose=False):
    """Perform rigid (rotation + translation) registration between an atlas and an image.

    Uses a ``VersorRigid3DTransform``, so the images must be 3D. ``img`` is
    the fixed image and ``atlas`` the moving image; the images are expected
    to have the same spacing.

    Parameters
    ----------
    atlas : SimpleITK.Image
        Moving 3D image. Its geometric center (in physical space) is used as
        the center of rotation.
    img : SimpleITK.Image
        Fixed 3D image.
    learning_rate : float
        Learning rate of the regular-step gradient-descent optimizer.
    iters : int
        Number of optimizer iterations.
    min_step : float
        Minimum step length of the optimizer.
    shrink_factors : sequence of int, optional
        Per-level shrink factors of the multi-resolution pyramid
        (default ``[1]``).
    sigmas : sequence of float, optional
        Per-level smoothing sigmas, in physical units (default ``[0.15]``).
    use_mi : bool
        Use Mattes mutual information (128 bins) instead of mean squares.
    grad_tol : float
        Gradient-magnitude tolerance of the optimizer.
    verbose : bool
        Attach the ``util`` plotting observers to the registration.

    Returns
    -------
    SimpleITK.Transform
        The optimized transform returned by
        ``ImageRegistrationMethod.Execute(img, atlas)``.
    """
    # None sentinels instead of shared mutable default lists.
    shrink_factors = [1] if shrink_factors is None else list(shrink_factors)
    sigmas = [.150] if sigmas is None else list(sigmas)

    registration_method = sitk.ImageRegistrationMethod()

    # Similarity metric settings.
    if use_mi:
        registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=128)
    else:
        registration_method.SetMetricAsMeanSquares()
    registration_method.SetInterpolator(sitk.sitkBSpline)

    # Optimizer settings.
    registration_method.SetOptimizerAsRegularStepGradientDescent(
        learningRate=learning_rate,
        minStep=min_step,
        gradientMagnitudeTolerance=grad_tol,
        numberOfIterations=iters)
    registration_method.SetOptimizerScalesFromPhysicalShift()

    # Setup for the multi-resolution framework.
    registration_method.SetShrinkFactorsPerLevel(shrinkFactors=shrink_factors)
    registration_method.SetSmoothingSigmasPerLevel(smoothingSigmas=sigmas)
    registration_method.SmoothingSigmasAreSpecifiedInPhysicalUnitsOn()

    # Initial transform rotating about the atlas's geometric center. Mapping
    # the central continuous index to physical space honours the image origin
    # and direction; size*spacing/2 was only right for a zero origin.
    initial_transform = sitk.VersorRigid3DTransform()
    center_index = [(n - 1) / 2.0 for n in atlas.GetSize()]
    initial_transform.SetCenter(
        atlas.TransformContinuousIndexToPhysicalPoint(center_index))
    registration_method.SetInitialTransform(initial_transform)

    # Connect the observers so that we can plot during registration.
    if verbose:
        registration_method.AddCommand(sitk.sitkStartEvent, util.start_plot)
        registration_method.AddCommand(sitk.sitkEndEvent, util.end_plot)
        registration_method.AddCommand(sitk.sitkMultiResolutionIterationEvent,
                                       util.update_multires_iterations)
        registration_method.AddCommand(sitk.sitkIterationEvent,
                                       lambda: util.plot_values(registration_method))

    final_transform = registration_method.Execute(sitk.Cast(img, sitk.sitkFloat32),
                                                  sitk.Cast(atlas, sitk.sitkFloat32))
    return final_transform
def imgMetamorphosis(inImg, refImg, alpha=0.02, beta=0.05, scale=1.0,
                     iterations=1000, epsilon=None, minEpsilon=None,
                     sigma=1e-4, useNearest=False, useBias=False, useMI=False,
                     verbose=False, debug=False, inMask=None, refMask=None,
                     outDirPath=""):
    """Perform metamorphic LDDMM between input and reference images.

    The images (and optional masks) are written to ``outDirPath``. The
    external ``metamorphosis`` binary under ``util.ndregDirPath`` is then run,
    and the forward and inverse fields it produces are read back.

    Parameters
    ----------
    inImg, refImg : SimpleITK.Image
        Input and reference images.
    alpha, beta, scale, iterations, sigma :
        Passed to the binary as ``--alpha``, ``--beta``, ``--scale``,
        ``--iterations`` and ``--sigma``.
    epsilon : float, optional
        ``--epsilon`` value; ``1e-3`` is passed when None.
    minEpsilon : float, optional
        ``--epsilonmin`` value; omitted when None.
    useNearest : bool
        NOTE(review): not used by this function.
    useBias : bool
        When False, ``--mu 0`` is passed.
    useMI : bool
        When True, ``--cost 1`` is passed.
    verbose : bool
        Forwarded to ``util.run_shell_command``.
    debug : bool
        Prefix the command with ``/usr/bin/time -v`` and print it.
    inMask, refMask : SimpleITK.Image, optional
        Masks written next to the images and passed as ``--inmask`` /
        ``--refmask``.
    outDirPath : str
        Working directory. When empty, a temporary directory is used and
        removed afterwards, including when the run fails.

    Returns
    -------
    tuple
        ``(field, invField)`` read from ``field.vtk`` and ``invField.vtk``.
    """
    useTempDir = outDirPath == ""
    if useTempDir:
        outDirPath = tempfile.mkdtemp() + "/"
    else:
        outDirPath = util.dir_make(outDirPath)

    try:
        inPath = outDirPath + "in.img"
        util.imgWrite(inImg, inPath)
        refPath = outDirPath + "ref.img"
        util.imgWrite(refImg, refPath)
        outPath = outDirPath + "out.img"
        fieldPath = outDirPath + "field.vtk"
        invFieldPath = outDirPath + "invField.vtk"

        # NOTE(review): paths are concatenated unquoted, so a directory
        # containing spaces may break the command if it goes through a shell.
        binPath = util.ndregDirPath + "metamorphosis "
        steps = 5
        command = binPath + " --in {0} --ref {1} --out {2} --alpha {3} --beta {4} --field {5} --invfield {6} --iterations {7} --scale {8} --steps {9} --verbose ".format(
            inPath, refPath, outPath, alpha, beta, fieldPath, invFieldPath,
            iterations, scale, steps)
        if not useBias:
            command += " --mu 0"
        # The MI and mean-squares variants differ only by "--cost 1"; both
        # then pass sigma and epsilon in the same order.
        if useMI:
            command += " --cost 1"
        command += " --sigma {}".format(sigma)
        if epsilon is None:
            command += " --epsilon 1e-3"
        else:
            command += " --epsilon {0}".format(epsilon)
        if minEpsilon is not None:
            command += " --epsilonmin {0}".format(minEpsilon)
        if inMask is not None:
            inMaskPath = outDirPath + "inMask.img"
            util.imgWrite(inMask, inMaskPath)
            command += " --inmask " + inMaskPath
        if refMask is not None:
            refMaskPath = outDirPath + "refMask.img"
            util.imgWrite(refMask, refMaskPath)
            command += " --refmask " + refMaskPath

        if debug:
            command = "/usr/bin/time -v " + command
            print(command)

        (_, logText) = util.run_shell_command(command, verbose=verbose)
        util.txt_write(logText, outDirPath + "log.txt")

        field = util.imgRead(fieldPath)
        invField = util.imgRead(invFieldPath)
    finally:
        # Always clean up a temporary working directory, even on failure;
        # cleanup errors must not mask the original exception.
        if useTempDir:
            shutil.rmtree(outDirPath, ignore_errors=True)

    return (field, invField)
def imgMetamorphosisComposite(inImg, refImg, alphaList=0.02, betaList=0.05,
                              scaleList=1.0, iterations=1000,
                              epsilonList=None, minEpsilonList=None,
                              sigma=1e-4, useNearest=False, useBias=False,
                              useMI=False, inMask=None, refMask=None,
                              verbose=True, debug=False, outDirPath=""):
    """Perform multi-step metamorphic LDDMM and compose the per-step fields.

    Each step runs ``imgMetamorphosis`` with its own alpha/beta/scale and
    epsilon settings. After each step, the composite field is applied to the
    original input image (and input mask) to produce the input of the next
    step.

    Parameters
    ----------
    inImg, refImg : SimpleITK.Image
        Input and reference images.
    alphaList, betaList, scaleList : float or list
        Per-step values. A number, or a list of length 1, is repeated for
        every step; the number of steps is the longest of the three lists.
    epsilonList, minEpsilonList : float, list or None, optional
        Per-step values. None means "not set" for every step, a number (or a
        length-1 list) is repeated, and otherwise there must be at least one
        value per step.
    iterations, sigma, useNearest, useBias, useMI, verbose, debug :
        Forwarded to every ``imgMetamorphosis`` call.
    inMask, refMask : SimpleITK.Image, optional
        Masks; the input mask is re-warped after each step with
        nearest-neighbour interpolation.
    outDirPath : str
        Output directory. When empty, a temporary directory is used and
        removed afterwards (also on failure), and no results are written.

    Returns
    -------
    tuple
        ``(compositeField, compositeInvField)``.

    Raises
    ------
    Exception
        If a per-step list has an incompatible length, or all of
        alphaList/betaList/scaleList are empty.
    """

    def _broadcast(values, numSteps, name, others):
        # Return a NEW list of length numSteps, repeating a single value.
        # Building a new list avoids the old in-place ``values *= numSteps``,
        # which grew the caller's list.
        if len(values) == numSteps:
            return list(values)
        if len(values) != 1:
            raise Exception(
                "Length of {0} must be 1 or same length as {1}".format(name, others))
        return list(values) * numSteps

    def _optional_per_step(values, numSteps, name):
        # None -> unset for all steps; number or length-1 list -> repeated;
        # otherwise at least numSteps values are required (checked up front
        # so a short list fails before any expensive step runs).
        if values is None:
            return [None] * numSteps
        if util.is_number(values):
            return [float(values)] * numSteps
        values = list(values)
        if len(values) == 1:
            return values * numSteps
        if len(values) < numSteps:
            raise Exception(
                "Length of {0} must be 1 or at least {1}".format(name, numSteps))
        return values

    if util.is_number(alphaList):
        alphaList = [float(alphaList)]
    if util.is_number(betaList):
        betaList = [float(betaList)]
    if util.is_number(scaleList):
        scaleList = [float(scaleList)]

    numSteps = max(len(alphaList), len(betaList), len(scaleList))
    if numSteps == 0:
        raise Exception("alphaList, betaList and scaleList must not all be empty")

    epsilonList = _optional_per_step(epsilonList, numSteps, "epsilonList")
    minEpsilonList = _optional_per_step(minEpsilonList, numSteps, "minEpsilonList")
    alphaList = _broadcast(alphaList, numSteps, "alphaList", "betaList or scaleList")
    betaList = _broadcast(betaList, numSteps, "betaList", "alphaList or scaleList")
    scaleList = _broadcast(scaleList, numSteps, "scaleList", "alphaList or betaList")

    useTempDir = outDirPath == ""
    if useTempDir:
        outDirPath = tempfile.mkdtemp() + "/"
    else:
        outDirPath = util.dir_make(outDirPath)

    origInImg = inImg
    origInMask = inMask
    try:
        for step in range(numSteps):
            alpha = alphaList[step]
            beta = betaList[step]
            scale = scaleList[step]
            stepDirPath = outDirPath + "step" + str(step) + "/"
            if verbose:
                print("\nStep {0}: alpha={1}, beta={2}, scale={3}".format(
                    step, alpha, beta, scale))

            (field, invField) = imgMetamorphosis(
                inImg, refImg, alpha=alpha, beta=beta, scale=scale,
                iterations=iterations, epsilon=epsilonList[step],
                minEpsilon=minEpsilonList[step], sigma=sigma,
                useNearest=useNearest, useBias=useBias, useMI=useMI,
                verbose=verbose, debug=debug, inMask=inMask,
                refMask=refMask, outDirPath=stepDirPath)

            if step == 0:
                compositeField = field
                compositeInvField = invField
            else:
                # Field helpers live in the registerer module (the names were
                # previously used unqualified and are not defined here).
                compositeField = registerer.fieldApplyField(field, compositeField)
                # force invField to be same size as field
                compositeInvField = registerer.fieldApplyField(
                    compositeInvField, invField,
                    size=field.GetSize(), spacing=field.GetSpacing())

            # Per-step composite results are only useful in a user directory;
            # a temporary directory is deleted at the end anyway.
            if not useTempDir:
                util.imgWrite(compositeInvField, stepDirPath + "invField.vtk")
                util.imgWrite(compositeField, stepDirPath + "field.vtk")

            inImg = registerer.imgApplyField(origInImg, compositeField,
                                             size=refImg.GetSize())
            if inMask is not None:
                inMask = registerer.imgApplyField(origInMask, compositeField,
                                                  size=refImg.GetSize(),
                                                  useNearest=True)

        # Write final results.
        if not useTempDir:
            util.imgWrite(compositeField, outDirPath + "field.vtk")
            util.imgWrite(compositeInvField, outDirPath + "invField.vtk")
            util.imgWrite(inImg, outDirPath + "out.img")
            util.imgWrite(plotter.imgChecker(inImg, refImg),
                          outDirPath + "checker.img")
    finally:
        # Always remove a temporary working directory, even on failure.
        if useTempDir:
            shutil.rmtree(outDirPath, ignore_errors=True)

    return (compositeField, compositeInvField)