# 257 lines · 9.3 KiB · Python · executable file
from pprint import pprint
|
||
|
||
import os
|
||
import re
|
||
import shutil
|
||
import tempfile
|
||
|
||
from skimage.metrics import normalized_mutual_information
|
||
|
||
import itk
|
||
|
||
# Initial elastix "MaximumNumberOfIterations" used by elastix_reg (the
# constructor doubles it while the inversion penalty is still >= 1).
NumberOfIterations = 2100

# Captures the value on elastix's "Final metric value = <number>" log line.
# Spaces/tabs around "=" are matched loosely so differently padded variants
# of the line are accepted; [ \t] (not \s) keeps the match on one line.
metric_patern = r'Final metric value[ \t]*=[ \t]*(\S+)'
metric_prog = re.compile(metric_patern)
class elastix_reg:
    """Rigid elastix registration between two image files, with a numeric inverse.

    The constructor registers the images in both directions, keeps the
    direction whose result has the higher normalized mutual information and
    stores it in ``self.res`` (keys: ``fwdtransforms``, ``invtransforms``,
    ``warpedfixout``, ``warpedmovout``, ``metrics``,
    ``DisplacementMagnitudePenalty``).
    """

    @staticmethod
    def _read_final_metric(elastix_log):
        """Return the last "Final metric value" found in an elastix log file.

        The log may contain one such line per resolution level; the last one is
        taken so the value reflects the end of the whole run.

        Raises:
            RuntimeError: if the log contains no such line.
        """
        with open(elastix_log) as log:
            values = metric_prog.findall(log.read())
        if not values:
            raise RuntimeError('No "Final metric value" found in %s' % elastix_log)
        return float(values[-1])

    def register_aux(self, fi, mv, debug=False, MaximumNumberOfIterations=None):
        """Rigidly register ``mv`` onto ``fi`` and numerically invert the result.

        A forward rigid registration is run first. Its transform is then
        inverted by a second elastix run of ``mv`` against itself with the
        ``DisplacementMagnitudePenalty`` metric, composed on top of the forward
        transform.

        Args:
            fi: fixed image (an ``itk`` image).
            mv: moving image (an ``itk`` image).
            debug: forwarded as ``log_to_console`` to the inversion run.
            MaximumNumberOfIterations: elastix parameter value, a list holding
                one string (e.g. ``['2100']``). Defaults to
                ``[str(NumberOfIterations)]``.

        Returns:
            ``{'metrics': 0}`` if the forward registration raises. Otherwise a
            dict with ``fwdtransforms`` / ``invtransforms`` (elastix parameter
            objects), ``warpedfixout`` (``fi`` resampled with the inverse),
            ``warpedmovout`` (``mv`` registered onto ``fi``), ``metrics`` (the
            larger of the two normalized mutual information values) and
            ``DisplacementMagnitudePenalty`` (final metric of the inversion run).

        Raises:
            RuntimeError: if no final metric value is found in the log of the
                inversion run.
        """
        if MaximumNumberOfIterations is None:
            # Fresh list per call instead of a shared mutable default.
            MaximumNumberOfIterations = [str(NumberOfIterations)]

        parameter_object = itk.ParameterObject.New()
        default_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
        default_rigid_parameter_map["AutomaticTransformInitialization"] = ["true"]
        default_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object.AddParameterMap(default_rigid_parameter_map)

        outdir1 = tempfile.mkdtemp()
        try:
            fm, params1 = itk.elastix_registration_method(
                fi, mv,
                parameter_object=parameter_object,
                log_to_file=True,
                output_directory=outdir1,
            )
        except Exception as ex:
            # Best effort: report and let the caller treat this direction as
            # failed. outdir1 is deliberately kept so the printed log can be
            # inspected.
            print(ex)
            print(os.path.join(outdir1, 'elastix.log'))
            return {
                'metrics': 0,
            }

        outdir2 = None
        try:
            outdir2 = tempfile.mkdtemp()
            TransformParameterFileName = os.path.join(outdir1, 'TransformParameters.0.txt')

            # Numeric inversion (elastix manual, DisplacementMagnitudePenalty):
            # the metric penalises ||T_mu(x) - x||^2. To invert a transform, use
            # it as the initial transform (-t0), set
            # (HowToCombineTransforms "Compose") and run elastix with this
            # metric, using the same image as both fixed (-f) and moving (-m)
            # image. Afterwards set the initial transform in the last parameter
            # map to "NoInitialTransform" and the result is the inverse.
            # Strictly speaking, Size/Spacing/Origin/Index/Direction should then
            # also be changed to match the moving image. Inversion thus becomes
            # an image registration, so the same choices (optimiser,
            # multi-resolution, ...) matter. See Metz et al. [2011].
            parameter_object2 = itk.ParameterObject.New()
            inverse_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
            inverse_rigid_parameter_map["HowToCombineTransforms"] = ["Compose"]
            inverse_rigid_parameter_map["Metric"] = ["DisplacementMagnitudePenalty"]
            inverse_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
            parameter_object2.AddParameterMap(inverse_rigid_parameter_map)

            _, params2 = itk.elastix_registration_method(
                mv, mv,
                parameter_object=parameter_object2,
                initial_transform_parameter_file_name=TransformParameterFileName,
                log_to_console=debug,
                log_to_file=True,
                output_directory=outdir2,
            )

            DisplacementMagnitudePenalty = self._read_final_metric(
                os.path.join(outdir2, 'elastix.log'))

            # Detach the forward transform so params2 is the inverse on its own.
            # Both spellings are set because elastix reports the
            # "...ParametersFileName" form as deprecated.
            last_parameter_map = params2.GetParameterMap(0)
            last_parameter_map["InitialTransformParametersFileName"] = ["NoInitialTransform"]
            last_parameter_map["InitialTransformParameterFileName"] = ["NoInitialTransform"]
            params2.SetParameterMap(0, last_parameter_map)

            mf = itk.transformix_filter(
                fi,
                params2)

            m1 = normalized_mutual_information(itk.GetArrayViewFromImage(fi), itk.GetArrayViewFromImage(fm))
            m2 = normalized_mutual_information(itk.GetArrayViewFromImage(mv), itk.GetArrayViewFromImage(mf))

            print(MaximumNumberOfIterations, m1, m2, DisplacementMagnitudePenalty)

            return {
                'fwdtransforms': params1,
                'invtransforms': params2,
                'warpedfixout': mf,
                'warpedmovout': fm,
                'metrics': max(m1, m2),
                'DisplacementMagnitudePenalty': DisplacementMagnitudePenalty,
            }
        finally:
            # Working directories are removed on success and on every error
            # raised after the forward registration succeeded.
            shutil.rmtree(outdir1, ignore_errors=True)
            if outdir2 is not None:
                shutil.rmtree(outdir2, ignore_errors=True)

    def _register_until_converged(self, fi, mv, debug, iterations, iterations_max):
        """Run ``register_aux`` with doubling iteration budgets until it settles.

        Stops when the run fails (no ``DisplacementMagnitudePenalty`` in the
        result), when the penalty drops below 1, or when doubling the budget
        would exceed ``iterations_max`` (only if it is not ``None``).

        Returns:
            ``(result, iterations)``: the last ``register_aux`` result and the
            iteration budget it was run with.
        """
        while True:
            result = self.register_aux(fi, mv, debug, MaximumNumberOfIterations=[str(iterations)])
            if 'DisplacementMagnitudePenalty' not in result:
                return result, iterations
            if result['DisplacementMagnitudePenalty'] < 1:
                return result, iterations
            if iterations_max is not None and iterations * 2 > iterations_max:
                return result, iterations
            iterations *= 2

    def __init__(self, fi, mv, warpedfixout=None, warpedmovout=None, debug=False,
                 iterations_init=NumberOfIterations, iterations_max=None):
        """Register the image files ``fi`` and ``mv`` both ways and keep the better.

        Args:
            fi: path of the fixed image (read as ``itk.F``).
            mv: path of the moving image (read as ``itk.F``).
            warpedfixout: optional path to write ``self.res['warpedfixout']`` to.
            warpedmovout: optional path to write ``self.res['warpedmovout']`` to.
            debug: enables elastix console logging of the inversion runs, prints
                ``self.res`` and writes intermediate "0*.nii.gz" images to the
                current directory.
            iterations_init: initial ``MaximumNumberOfIterations``; doubled while
                the inversion penalty is still >= 1.
            iterations_max: optional upper bound for that doubling. ``None``
                (default) doubles without limit.

        Raises:
            AssertionError: if the kept result's DisplacementMagnitudePenalty
                is not below 1.

        If the moving->fixed run failed and the fixed->moving run did not score
        higher, the constructor returns early and ``self.res`` stays unset.
        """
        self.debug = debug

        fixed_image = itk.imread(fi, itk.F)
        moving_image = itk.imread(mv, itk.F)

        r1, iterations = self._register_until_converged(
            fixed_image, moving_image, debug, iterations_init, iterations_max)
        # The reverse direction starts from the budget the forward one ended with.
        r2, _ = self._register_until_converged(
            moving_image, fixed_image, debug, iterations, iterations_max)

        if r1['metrics'] > r2['metrics']:
            self.res = r1
        else:
            if 'invtransforms' not in r2:
                return None

            # r2 was computed with the images' roles swapped; swap its
            # transforms and warped outputs back to the fixed->moving view.
            self.res = dict(r2)
            self.res.update({
                'fwdtransforms': r2['invtransforms'],
                'invtransforms': r2['fwdtransforms'],
                'warpedfixout': r2['warpedmovout'],
                'warpedmovout': r2['warpedfixout'],
            })

        assert self.res['DisplacementMagnitudePenalty'] < 1, 'DisplacementMagnitudePenalty: %f ' % (self.res['DisplacementMagnitudePenalty'])

        if warpedfixout is not None:
            itk.imwrite(self.res['warpedfixout'], warpedfixout)
        if warpedmovout is not None:
            itk.imwrite(self.res['warpedmovout'], warpedmovout)

        if debug:
            pprint(self.res)
            itk.imwrite(fixed_image, '0fixed.nii.gz')
            itk.imwrite(moving_image, '0moving.nii.gz')
            # A failed direction has no warped images; skip instead of KeyError.
            for result, key, filename in (
                (r1, 'warpedfixout', '0mf1.nii.gz'),
                (r1, 'warpedmovout', '0fm1.nii.gz'),
                (r2, 'warpedmovout', '0mf2.nii.gz'),
                (r2, 'warpedfixout', '0fm2.nii.gz'),
            ):
                if key in result:
                    itk.imwrite(result[key], filename)

    def get_metrics(self):
        """Return the normalized mutual information score of the kept result."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the moving image registered onto the fixed image to ``out``."""
        itk.imwrite(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Apply the stored forward transform to the image file ``moving``.

        Args:
            moving: path of the image to resample (converted to ``itk.F``).
            output_filename: path the resampled image is written to.
            is_label: if true, resample with nearest-neighbour interpolation and
                cast the result to ``itk.UC``.
                NOTE(review): values above 255 cannot survive that cast.
        """
        forward = self.res['fwdtransforms']
        image = itk.imread(moving)
        parameter_map = forward.GetParameterMap(0)

        if is_label:
            # Nearest neighbour keeps label values discrete.
            parameter_map["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]

        transform_parameters = itk.ParameterObject.New()
        transform_parameters.AddParameterMap(parameter_map)

        output = itk.transformix_filter(
            image.astype(itk.F),
            transform_parameters,
            log_to_console=self.debug,
        )

        if is_label:
            output = output.astype(itk.UC)
        itk.imwrite(output, output_filename)