123/itk_elastix.py
2022-05-11 11:38:16 +08:00

302 lines
No EOL
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pprint import pprint
import os
import re
import shutil
import tempfile
from skimage.metrics import normalized_mutual_information
import itk
# Pattern for the "Final metric value = <number>" line elastix writes to
# elastix.log; group 1 captures the number. Any amount of whitespace is
# accepted around "=" so the match does not depend on the exact spacing the
# elastix version prints (NOTE(review): confirm against a real elastix.log).
metric_patern = r'Final metric value\s*=\s*(\S+)'
metric_prog = re.compile(metric_patern)

# Reference setting carried over from ANTs (ants.registration.interface):
#     myiterations = '2100x1200x1200x10'
# Default elastix MaximumNumberOfIterations; 1200 and 4000 were also tried.
NumberOfIterations = 2100
def register_aux(fi, mv, debug=False, MaximumNumberOfIterations=None):
    """Rigidly register ``mv`` to ``fi`` and estimate the inverse transform.

    Step 1: a rigid elastix registration with ``fi`` as fixed and ``mv`` as
    moving image gives the forward transform.
    Step 2: a second elastix run starts from the forward transform
    (``HowToCombineTransforms "Compose"``) and minimises the
    DisplacementMagnitudePenalty metric, which yields the inverse transform.

    Args:
        fi: fixed ITK image (``register`` passes float images).
        mv: moving ITK image.
        debug: passed as ``log_to_console`` to the inversion run.
        MaximumNumberOfIterations: elastix parameter value used for both
            runs, a list holding one string such as ``['2100']``. ``None``
            (the default) means ``[str(NumberOfIterations)]``.

    Returns:
        ``{'metrics': 0}`` if the forward registration raised. In that case
        the exception and the elastix.log path are printed, and the temporary
        directory is kept for inspection. Otherwise the dict contains:
            'fwdtransforms': forward ITK parameter object,
            'invtransforms': inverse ITK parameter object,
            'warpedmovout': result image of the forward registration,
            'warpedfixout': ``fi`` resampled with the inverse transform,
            'metrics': max of the two normalized mutual information scores,
            'DisplacementMagnitudePenalty': final metric value of the
                inversion run (lower means a more accurate inverse).
        Temporary elastix output directories are removed before returning.

    Raises:
        RuntimeError: if no final metric value is found in the inversion
            run's elastix.log. Exceptions raised by the inversion run
            propagate; the temporary directories are still removed.
    """
    if MaximumNumberOfIterations is None:
        # Built per call instead of a shared mutable default list.
        MaximumNumberOfIterations = [str(NumberOfIterations)]

    # --- Step 1: forward rigid registration (fixed=fi, moving=mv). ---
    parameter_object = itk.ParameterObject.New()
    default_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
    default_rigid_parameter_map["AutomaticTransformInitialization"] = ["true"]
    default_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
    parameter_object.AddParameterMap(default_rigid_parameter_map)
    outdir1 = tempfile.mkdtemp()
    try:
        fm, params1 = itk.elastix_registration_method(
            fi, mv,
            parameter_object=parameter_object,
            log_to_file=True,
            output_directory=outdir1,
        )
    except Exception as ex:
        # Best effort: report the failure and leave outdir1 on disk so the
        # printed elastix.log path can still be inspected.
        print(ex)
        print(os.path.join(outdir1, 'elastix.log'))
        return {
            'metrics': 0,
        }

    outdir2 = None
    try:
        # --- Step 2: invert the forward transform. ---
        # From the elastix manual: the DisplacementMagnitudePenalty is a cost
        # function that penalises ||T_mu(x) - x||^2. To invert a transform, set
        # it as the initial transform, set (HowToCombineTransforms "Compose"),
        # and run elastix with this metric, using one image as both fixed and
        # moving image. Then set the initial transform of the resulting
        # parameter file to "NoInitialTransform", and you have the inverse
        # (Metz et al. [2011]). Here mv is used as both images. The result's
        # geometry settings (Size/Spacing/Origin/...) therefore presumably
        # follow mv, which is the grid transformix resamples fi onto below.
        TransformParameterFileName = os.path.join(outdir1, 'TransformParameters.0.txt')
        parameter_object2 = itk.ParameterObject.New()
        inverse_rigid_parameter_map = parameter_object2.GetDefaultParameterMap('rigid')
        inverse_rigid_parameter_map["HowToCombineTransforms"] = ["Compose"]
        inverse_rigid_parameter_map["Metric"] = ["DisplacementMagnitudePenalty"]
        inverse_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object2.AddParameterMap(inverse_rigid_parameter_map)
        outdir2 = tempfile.mkdtemp()
        _, params2 = itk.elastix_registration_method(
            mv, mv,
            parameter_object=parameter_object2,
            initial_transform_parameter_file_name=TransformParameterFileName,
            log_to_console=debug,
            log_to_file=True,
            output_directory=outdir2,
        )

        # The remaining penalty measures how far the composition is from the
        # identity, i.e. how accurate the inverse is.
        elastix_log = os.path.join(outdir2, 'elastix.log')
        with open(elastix_log, encoding='utf-8', errors='replace') as log:
            match = metric_prog.search(log.read())
        if match is None:
            raise RuntimeError(
                'no "Final metric value" line in the elastix log of the inverse registration')
        DisplacementMagnitudePenalty = float(match[1])

        # Detach the inverse from the forward transform file (outdir1 is
        # deleted below) so params2 stands on its own.
        last_parameter_map = params2.GetParameterMap(0)
        last_parameter_map["InitialTransformParametersFileName"] = ["NoInitialTransform"]
        params2.SetParameterMap(0, last_parameter_map)
        mf = itk.transformix_filter(fi, params2)

        m1 = normalized_mutual_information(itk.GetArrayViewFromImage(fi), itk.GetArrayViewFromImage(fm))
        m2 = normalized_mutual_information(itk.GetArrayViewFromImage(mv), itk.GetArrayViewFromImage(mf))
        print(MaximumNumberOfIterations, m1, m2, DisplacementMagnitudePenalty)
        return {
            'fwdtransforms': params1,
            'invtransforms': params2,
            'warpedfixout': mf,
            'warpedmovout': fm,
            'metrics': max(m1, m2),
            'DisplacementMagnitudePenalty': DisplacementMagnitudePenalty,
        }
    finally:
        # Runs on success and on error, so the temp directories never leak;
        # ignore_errors keeps cleanup from masking the original exception.
        shutil.rmtree(outdir1, ignore_errors=True)
        if outdir2 is not None:
            shutil.rmtree(outdir2, ignore_errors=True)
# 3-D float image type. The functions below read images with
# itk.imread(..., itk.F) and do not refer to ImageType directly.
PixelType, Dimension = itk.F, 3
ImageType = itk.Image[PixelType, Dimension]
# Reader-based alternative to itk.imread:
#     reader = itk.ImageFileReader[ImageType].New()
#     reader.SetFileName("image.tif")
#     reader.Update()
#     image = reader.GetOutput()
# METRIC_THRESHOLD = 1.1
def _register_until_invertible(fixed_image, moving_image, debug, iterations, max_iterations=None):
    """Call register_aux, doubling the iteration count until its inverse is accurate.

    Stops when the result has no 'DisplacementMagnitudePenalty' (registration
    failed), when that penalty is below 1, or when doubling would exceed
    ``max_iterations`` (if given). Returns ``(result, iterations)`` where
    ``iterations`` is the count used for the returned result.
    """
    while True:
        result = register_aux(fixed_image, moving_image, debug,
                              MaximumNumberOfIterations=[str(iterations)])
        if 'DisplacementMagnitudePenalty' not in result:
            # register_aux returned its failure dict; stop retrying.
            return result, iterations
        if result['DisplacementMagnitudePenalty'] < 1:
            return result, iterations
        if max_iterations is not None and iterations * 2 > max_iterations:
            return result, iterations
        iterations *= 2


def register(fi, mv, warpedfixout=None, warpedmovout=None, debug=False,
             iterations_init=NumberOfIterations, max_iterations=None):
    """Rigidly register two image files in both directions and keep the better result.

    ``register_aux`` runs once with fi as fixed image and once with mv as fixed
    image. Each direction is retried with a doubled iteration count until the
    estimated inverse is accurate (DisplacementMagnitudePenalty < 1). The
    direction with the higher 'metrics' score wins. A result from the mv-fixed
    run has its keys swapped, so the returned dict always treats fi as fixed
    and mv as moving.

    Args:
        fi: path of the fixed image (read as float).
        mv: path of the moving image (read as float).
        warpedfixout: optional output path for the chosen 'warpedfixout' image.
        warpedmovout: optional output path for the chosen 'warpedmovout' image.
        debug: forwarded to register_aux. Also pprints the result and writes
            the inputs and intermediate warped images (0fixed.nii.gz, ...) to
            the current directory.
        iterations_init: initial elastix MaximumNumberOfIterations.
        max_iterations: optional cap for the doubled iteration count. ``None``
            (the default) keeps retrying without limit.

    Returns:
        A dict with 'fwdtransforms', 'invtransforms', 'warpedfixout',
        'warpedmovout', 'metrics' and 'DisplacementMagnitudePenalty', or None
        when the selected direction produced no transforms (i.e. both failed).

    Raises:
        AssertionError: if the chosen result's DisplacementMagnitudePenalty is
            not below 1. This can only happen when ``max_iterations`` stopped
            the retries.
    """
    fixed_image = itk.imread(fi, itk.F)
    moving_image = itk.imread(mv, itk.F)

    r1, iterations = _register_until_invertible(
        fixed_image, moving_image, debug, iterations_init, max_iterations)
    # The reverse direction starts from the iteration count the first
    # direction ended with, not from iterations_init.
    r2, _ = _register_until_invertible(
        moving_image, fixed_image, debug, iterations, max_iterations)

    if r1['metrics'] > r2['metrics']:
        res = r1
    else:
        if 'invtransforms' not in r2:
            return None
        # r2 used mv as its fixed image: swap forward/inverse and the warped
        # images so res is expressed with fi as fixed and mv as moving.
        res = dict(r2)
        res.update({
            'fwdtransforms': r2['invtransforms'],
            'invtransforms': r2['fwdtransforms'],
            'warpedfixout': r2['warpedmovout'],
            'warpedmovout': r2['warpedfixout'],
        })
    # Explicit raise instead of `assert`, so the check also runs under `python -O`.
    if not res['DisplacementMagnitudePenalty'] < 1:
        raise AssertionError('DisplacementMagnitudePenalty: %f ' % (res['DisplacementMagnitudePenalty']))

    if warpedfixout is not None:
        itk.imwrite(res['warpedfixout'], warpedfixout)
    if warpedmovout is not None:
        itk.imwrite(res['warpedmovout'], warpedmovout)
    if debug:
        pprint(res)
        itk.imwrite(fixed_image, '0fixed.nii.gz')
        itk.imwrite(moving_image, '0moving.nii.gz')
        # A failed direction is {'metrics': 0} and has no warped images.
        if 'warpedfixout' in r1:
            itk.imwrite(r1['warpedfixout'], '0mf1.nii.gz')
            itk.imwrite(r1['warpedmovout'], '0fm1.nii.gz')
        if 'warpedfixout' in r2:
            itk.imwrite(r2['warpedmovout'], '0mf2.nii.gz')
            itk.imwrite(r2['warpedfixout'], '0fm2.nii.gz')
    return res
def transform_write(moving, transform, output_filename, is_label=False):
    """Resample an image file with the first parameter map of ``transform`` and save it.

    Args:
        moving: path of the image to resample; it is cast to float before
            transformix runs.
        transform: ITK parameter object, e.g. 'fwdtransforms' or
            'invtransforms' from ``register``. Only map 0 is applied.
        output_filename: path the resampled image is written to.
        is_label: use nearest-neighbour resampling and write the result as
            unsigned char (for label/segmentation images).
    """
    image = itk.imread(moving)
    parameter_map = transform.GetParameterMap(0)
    if is_label:
        # Nearest neighbour avoids interpolating between label values.
        parameter_map["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
    single_map = itk.ParameterObject.New()
    single_map.AddParameterMap(parameter_map)

    warped = itk.transformix_filter(image.astype(itk.F), single_map)
    if is_label:
        # NOTE(review): label values outside 0..255 do not fit in unsigned char.
        warped = warped.astype(itk.UC)
    itk.imwrite(warped, output_filename)
if __name__ == '__main__':
    import sys

    # (fixed, moving) image pairs that have been tried. Without arguments the
    # last pair is registered; pass two paths to register another pair:
    #     python <this script> FIXED MOVING
    sample_pairs = [
        ('/media/nfs/SRS/TSGH2022G4/register_fwd/2131720/case2017.04.21.11.01.48/patient_6_RT_Cyberknife_1mm_Head_20170420091131_3.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/register_fwd/2131720/case2017.04.21.11.01.48/patient_T1_SE_GD_20170420102216_3.nii.gz'),
        ('/media/nfs/SRS/TSGH2022G4/image/2131720/patient_6_RT_Cyberknife_1mm_Head_20170420091131_3.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/2131720/patient_T1_SE_GD_20170420102216_3.nii.gz'),
        ('/media/nfs/SRS/TSGH2022G4/image/1693329/patient_6_Head_CTA_CNY_Head_20100330100138_18.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/1693329/patient_T1_SE_GD_20100330105831_3.nii.gz'),
        # 4D image
        # ('/media/nfs/SRS/NTUH2022G4/image/6053604/patient_1.7_CTA_+_Perfusion_(Subtraction)_20140121125458_103.nii.gz',
        #  '/media/nfs/SRS/NTUH2022G4/image/6053604/patient_AX_T1+C(3D_2MM)_ZIP_512_20140121105711_5.nii.gz'),
        # bad registration
        ('/media/nfs/SRS/TSGH2022G4/image/2978108/patient_6_RT_Cyberknife_1mm_Head_20170608103902_3.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/2978108/patient_Ax_T1_SE_CK_GD+_20170608094843_4.nii.gz'),
        ('/media/nfs/SRS/TSGH2022G4/image/2511774/patient_6_RT_Cyberknife_1mm_Head_20141215144427_3.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/2511774/patient_T1_SE_GD_20141215140945_3.nii.gz'),
        ('/media/nfs/SRS/TSGH2022G4/image/2305719/patient_6_RT_Cyberknife_1mm_Head_20090519105710_4.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/2305719/patient_T1_SE_GD_20090519104815_3.nii.gz'),
        # insufficient inverse
        ('/media/nfs/SRS/TSGH2022G4/image/2441399/patient_2_Liver_3Phase_CK_Abdomen_20111031104910_9_e1.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/2441399/patient_6_RT_Cyberknife_1mm_Head_20120625103934_3.nii.gz'),
        ('/media/nfs/SRS/TSGH2022G4/image/2232394/patient_6_RT_Cyberknife_1mm_Head_20100222122239_3.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/2232394/patient_6_RT_Cyberknife_1mm_Head_20100222092318_4.nii.gz'),
        ('/media/nfs/SRS/TSGH2022G4/image/1255468/patient_6_RT_Cyberknife_1mm_Head_20080417105648_3.nii.gz',
         '/media/nfs/SRS/TSGH2022G4/image/1255468/patient_ebrain_IR_T1_ax_GD_20110209152748_1102.nii.gz'),
    ]
    if len(sys.argv) == 3:
        fi, mv = sys.argv[1], sys.argv[2]
    else:
        fi, mv = sample_pairs[-1]
    # r = register(fi, mv, debug=True)
    r = register(fi, mv)
    print(r)