123/registration/elastix_reg.py
2025-02-01 15:57:22 +08:00

257 lines
9.3 KiB
Python
Executable file
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pprint import pprint
import os
import re
import shutil
import tempfile
from skimage.metrics import normalized_mutual_information
import itk
# Default elastix "MaximumNumberOfIterations" budget for each registration run.
NumberOfIterations = 2100
# NumberOfIterations = 4000

# Extracts the number printed after "Final metric value" in an elastix.log.
# Any run of spaces/tabs is accepted around '=' so the match does not depend
# on the exact padding elastix uses for that label.
# (The historical spelling "metric_patern" is kept for existing importers.)
metric_patern = r'Final metric value[ \t]*=[ \t]*(\S+)'
metric_prog = re.compile(metric_patern)
class elastix_reg:
    """Bidirectional rigid registration of two image files with itk-elastix.

    The constructor registers the moving image onto the fixed image and the
    fixed image onto the moving image. For each direction it also estimates
    the inverse rigid transform (via elastix's DisplacementMagnitudePenalty
    metric), doubling the elastix iteration budget until that inversion's
    final metric value drops below 1. The direction with the higher score
    (the larger of its two normalized-mutual-information values) is stored in
    ``self.res``, expressed in the fixed<-moving convention, with the keys
    ``fwdtransforms``, ``invtransforms``, ``warpedfixout``, ``warpedmovout``,
    ``metrics`` and ``DisplacementMagnitudePenalty``.
    """

    def register_aux(self, fi, mv, debug=False, MaximumNumberOfIterations=None):
        """Rigidly register ``mv`` onto ``fi`` and estimate the inverse transform.

        Parameters
        ----------
        fi, mv : itk.Image
            Fixed and moving images.
        debug : bool
            Passed to elastix as ``log_to_console`` for the inversion run.
        MaximumNumberOfIterations : list of str, optional
            Value for elastix's ``MaximumNumberOfIterations`` parameter, used
            by both runs. Defaults to ``[str(NumberOfIterations)]``.

        Returns
        -------
        dict
            On success:
            ``fwdtransforms`` (parameters mapping mv onto fi),
            ``invtransforms`` (estimated inverse, mapping fi onto mv),
            ``warpedmovout`` (mv resampled into fi space),
            ``warpedfixout`` (fi resampled into mv space),
            ``metrics`` (max of the NMI of both warped/reference pairs) and
            ``DisplacementMagnitudePenalty`` (final metric value of the
            inversion run; lower means the composition is closer to identity).
            If the forward registration raises: ``{'metrics': 0}`` only.

        Raises
        ------
        RuntimeError
            If the inversion run's log contains no "Final metric value".
        """
        # A None default avoids sharing one mutable list between calls.
        if MaximumNumberOfIterations is None:
            MaximumNumberOfIterations = [str(NumberOfIterations)]

        # --- Forward rigid registration: mv -> fi -------------------------
        parameter_object = itk.ParameterObject.New()
        default_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
        default_rigid_parameter_map["AutomaticTransformInitialization"] = ["true"]
        default_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object.AddParameterMap(default_rigid_parameter_map)

        outdir1 = tempfile.mkdtemp()
        try:
            fm, params1 = itk.elastix_registration_method(
                fi, mv,
                parameter_object=parameter_object,
                log_to_file=True,
                output_directory=outdir1,
            )
        except Exception as ex:
            # Best-effort: report and signal failure through 'metrics': 0.
            # outdir1 is deliberately kept so the printed log path stays valid.
            print(ex)
            print(os.path.join(outdir1, 'elastix.log'))
            return {
                'metrics': 0,
            }

        outdir2 = tempfile.mkdtemp()
        try:
            TransformParameterFileName = os.path.join(outdir1, 'TransformParameters.0.txt')

            # --- Inverse of the forward transform -------------------------
            # Method from the elastix manual (DisplacementMagnitudePenalty,
            # Metz et al. [2011]): start from the forward transform as the
            # initial transform with HowToCombineTransforms "Compose", register
            # an image to itself while penalising ||T_mu(x) - x||^2, then set
            # the result's initial transform to "NoInitialTransform"; what
            # remains is the inverse. Using mv as both fixed and moving image
            # gives the result mv's Size/Spacing/Origin/Direction, which the
            # manual says the inverse should have.
            parameter_object2 = itk.ParameterObject.New()
            inverse_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
            inverse_rigid_parameter_map["HowToCombineTransforms"] = ["Compose"]
            inverse_rigid_parameter_map["Metric"] = ["DisplacementMagnitudePenalty"]
            inverse_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
            parameter_object2.AddParameterMap(inverse_rigid_parameter_map)

            _, params2 = itk.elastix_registration_method(
                mv, mv,
                parameter_object=parameter_object2,
                initial_transform_parameter_file_name=TransformParameterFileName,
                log_to_console=debug,
                log_to_file=True,
                output_directory=outdir2,
            )

            # The penalty value is only reported in the log, so parse it there.
            elastix_log = os.path.join(outdir2, 'elastix.log')
            with open(elastix_log) as log:
                match = metric_prog.search(log.read())
            if match is None:
                raise RuntimeError(
                    'Final metric value not found in the elastix log of the '
                    'inverse (DisplacementMagnitudePenalty) registration')
            DisplacementMagnitudePenalty = float(match[1])

            # Drop the composed forward transform so params2 is the inverse
            # alone. Both spellings are set: the plural one is deprecated in
            # favour of "InitialTransformParameterFileName".
            last_parameter_map = params2.GetParameterMap(0)
            last_parameter_map["InitialTransformParametersFileName"] = ["NoInitialTransform"]
            last_parameter_map["InitialTransformParameterFileName"] = ["NoInitialTransform"]
            params2.SetParameterMap(0, last_parameter_map)

            # fi resampled into mv space with the estimated inverse.
            mf = itk.transformix_filter(fi, params2)

            m1 = normalized_mutual_information(itk.GetArrayViewFromImage(fi), itk.GetArrayViewFromImage(fm))
            m2 = normalized_mutual_information(itk.GetArrayViewFromImage(mv), itk.GetArrayViewFromImage(mf))
            print(MaximumNumberOfIterations, m1, m2, DisplacementMagnitudePenalty)
        finally:
            # Results are held in memory; the on-disk outputs are no longer
            # needed whether or not the inversion succeeded.
            shutil.rmtree(outdir1, ignore_errors=True)
            shutil.rmtree(outdir2, ignore_errors=True)

        return {
            'fwdtransforms': params1,
            'invtransforms': params2,
            'warpedfixout': mf,
            'warpedmovout': fm,
            'metrics': max(m1, m2),
            'DisplacementMagnitudePenalty': DisplacementMagnitudePenalty,
        }

    def _register_until_invertible(self, fi, mv, iterations, debug, max_iterations):
        """Call register_aux, doubling the iteration budget until the inverse is good.

        Stops when the registration failed (result lacks
        'DisplacementMagnitudePenalty'), when the penalty is below 1, or when
        doubling would exceed ``max_iterations`` (if not None).
        Returns ``(result, iterations)``, where ``iterations`` is the budget
        of the last run.
        """
        while True:
            result = self.register_aux(
                fi, mv, debug, MaximumNumberOfIterations=[str(iterations)])
            if 'DisplacementMagnitudePenalty' not in result:
                return result, iterations
            if result['DisplacementMagnitudePenalty'] < 1:
                return result, iterations
            if max_iterations is not None and iterations * 2 > max_iterations:
                return result, iterations
            iterations *= 2

    def __init__(self, fi, mv, warpedfixout=None, warpedmovout=None, debug=False,
                 iterations_init=NumberOfIterations, max_iterations=None):
        """Register image file ``mv`` to image file ``fi`` in both directions.

        Parameters
        ----------
        fi, mv : str
            Paths of the fixed and moving images (read as float images).
        warpedfixout, warpedmovout : str, optional
            Where to write the chosen result's warped images, if given.
        debug : bool
            Print the result and write intermediate images (``0*.nii.gz``)
            to the current working directory.
        iterations_init : int
            Initial elastix iteration budget. It is doubled while the
            inversion penalty stays >= 1. The second direction starts from
            the budget the first direction ended with.
        max_iterations : int, optional
            Upper bound on that doubling. ``None`` (default) means unbounded.

        If both directions fail, the constructor returns early and
        ``self.res`` is not set.

        Raises
        ------
        AssertionError
            If the chosen result's DisplacementMagnitudePenalty is not < 1.
        """
        self.debug = debug
        fixed_image = itk.imread(fi, itk.F)
        moving_image = itk.imread(mv, itk.F)

        # Direction 1: moving -> fixed.
        r1, iterations = self._register_until_invertible(
            fixed_image, moving_image, iterations_init, debug, max_iterations)
        # Direction 2: fixed -> moving, starting from direction 1's final budget.
        r2, _ = self._register_until_invertible(
            moving_image, fixed_image, iterations, debug, max_iterations)

        if r1['metrics'] > r2['metrics']:
            self.res = r1
        else:
            if 'invtransforms' not in r2:
                # r2 failed (metrics 0) and r1 did not beat 0, so r1 failed too.
                return None
            # r2 registered fixed onto moving; swap roles so self.res uses
            # the same fixed<-moving convention as r1.
            self.res = dict(r2)
            self.res.update({
                'fwdtransforms': r2['invtransforms'],
                'invtransforms': r2['fwdtransforms'],
                'warpedfixout': r2['warpedmovout'],
                'warpedmovout': r2['warpedfixout'],
            })

        # Explicit raise (same type as before) so the check survives `python -O`.
        penalty = self.res['DisplacementMagnitudePenalty']
        if not penalty < 1:
            raise AssertionError('DisplacementMagnitudePenalty: %f ' % (penalty))

        if warpedfixout is not None:
            itk.imwrite(self.res['warpedfixout'], warpedfixout)
        if warpedmovout is not None:
            itk.imwrite(self.res['warpedmovout'], warpedmovout)

        if debug:
            pprint(self.res)
            itk.imwrite(fixed_image, '0fixed.nii.gz')
            itk.imwrite(moving_image, '0moving.nii.gz')
            # A failed direction has no warped images; skip those instead of
            # raising KeyError.
            debug_outputs = (
                (r1, 'warpedfixout', '0mf1.nii.gz'),
                (r1, 'warpedmovout', '0fm1.nii.gz'),
                (r2, 'warpedmovout', '0mf2.nii.gz'),
                (r2, 'warpedfixout', '0fm2.nii.gz'),
            )
            for result, key, filename in debug_outputs:
                if key in result:
                    itk.imwrite(result[key], filename)

    def get_metrics(self):
        """Return the score of the chosen registration (``self.res['metrics']``)."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the moving image resampled into fixed space to path ``out``."""
        itk.imwrite(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Resample the image file ``moving`` with the chosen forward transform.

        The image is read, cast to float and passed through transformix with
        the first parameter map of ``self.res['fwdtransforms']``. The result
        is written to ``output_filename``. With ``is_label=True``, nearest-
        neighbour interpolation is used and the output is cast to unsigned
        char. NOTE(review): label values must therefore fit in 0-255.
        """
        transform1 = self.res['fwdtransforms']
        mv = itk.imread(moving)
        last_parameter_map = transform1.GetParameterMap(0)
        if is_label:
            # Nearest neighbour keeps label values intact (no blending).
            last_parameter_map["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
        t2 = itk.ParameterObject.New()
        t2.AddParameterMap(last_parameter_map)
        output = itk.transformix_filter(
            mv.astype(itk.F),
            t2,
            log_to_console=self.debug,
        )
        if is_label:
            output = output.astype(itk.UC)
        itk.imwrite(output, output_filename)