from pprint import pprint
import os
import re
import shutil
import tempfile

from skimage.metrics import normalized_mutual_information
import itk

# Initial optimiser budget for the rigid registrations; doubled on each retry.
NumberOfIterations = 2100

# Line that elastix writes to elastix.log with the metric value.
metric_patern = r'Final metric value = (\S+)'
metric_prog = re.compile(metric_patern)

# A registration is accepted once the DisplacementMagnitudePenalty of the
# forward/inverse round trip is strictly below this value.
DISPLACEMENT_PENALTY_THRESHOLD = 1


def _read_final_metric(log_path):
    """Return the last "Final metric value" reported in an elastix log.

    Raises:
        RuntimeError: if the log contains no line matching ``metric_prog``.
    """
    with open(log_path, encoding='utf-8', errors='replace') as log:
        matches = metric_prog.findall(log.read())
    if not matches:
        raise RuntimeError('No "Final metric value" found in %s' % log_path)
    # Take the LAST match: if elastix reports the value more than once
    # (e.g. per resolution level), the last one belongs to the final result.
    return float(matches[-1])


class elastix_reg:
    """Rigid elastix registration of two image files, plus its inverse.

    The constructor registers ``mv`` onto ``fi`` and ``fi`` onto ``mv``,
    keeps the direction with the higher normalized mutual information and
    stores it in ``self.res`` (keys: 'fwdtransforms', 'invtransforms',
    'warpedfixout', 'warpedmovout', 'metrics', 'DisplacementMagnitudePenalty').
    """

    def register_aux(self, fi, mv, debug=False, MaximumNumberOfIterations=None):
        """Rigidly register ``mv`` onto ``fi`` and estimate the inverse transform.

        Args:
            fi: fixed itk image.
            mv: moving itk image.
            debug: forward elastix logging of the inverse step to the console.
            MaximumNumberOfIterations: elastix parameter value (list of str);
                defaults to ``[str(NumberOfIterations)]``.

        Returns:
            ``{'metrics': 0}`` if the forward registration raised; otherwise a
            dict with 'fwdtransforms' (params of the forward registration),
            'invtransforms' (params of the estimated inverse), 'warpedmovout'
            (``mv`` resampled onto ``fi``), 'warpedfixout' (``fi`` resampled
            with the inverse), 'metrics' (max of the two NMI values) and
            'DisplacementMagnitudePenalty' (final metric of the inverse step).

        Raises:
            RuntimeError: if the inverse-step log has no final metric value.
        """
        if MaximumNumberOfIterations is None:
            MaximumNumberOfIterations = [str(NumberOfIterations)]

        # --- Forward rigid registration: mv -> fi -------------------------
        parameter_object = itk.ParameterObject.New()
        rigid_map = parameter_object.GetDefaultParameterMap('rigid')
        rigid_map['AutomaticTransformInitialization'] = ['true']
        rigid_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object.AddParameterMap(rigid_map)

        outdir1 = tempfile.mkdtemp()
        try:
            fm, params1 = itk.elastix_registration_method(
                fi, mv,
                parameter_object=parameter_object,
                log_to_file=True,
                output_directory=outdir1,
            )
        except Exception as ex:
            # Best-effort: report the failure and keep outdir1 on disk so the
            # printed elastix.log path can still be inspected.
            print(ex)
            print(os.path.join(outdir1, 'elastix.log'))
            return {
                'metrics': 0,
            }
        transform_parameter_file = os.path.join(outdir1, 'TransformParameters.0.txt')

        # --- Inverse transform via DisplacementMagnitudePenalty -----------
        # From the elastix manual: DisplacementMagnitudePenalty penalises
        # ||T_mu(x) - x||^2. To invert a transform, use it as the initial
        # transform, set (HowToCombineTransforms "Compose"), and register the
        # original fixed image against itself with this metric. Then set the
        # initial transform of the resulting parameter map to
        # "NoInitialTransform" to obtain the inverse transform alone (Metz et
        # al. 2011). Strictly, Size/Spacing/Origin/Index/Direction should then
        # be changed to match the moving image; that is not done here.
        # NOTE(review): the call below passes `mv` as both images, whereas the
        # manual describes using the original *fixed* image — confirm intended.
        parameter_object2 = itk.ParameterObject.New()
        inverse_map = parameter_object2.GetDefaultParameterMap('rigid')
        inverse_map['HowToCombineTransforms'] = ['Compose']
        inverse_map['Metric'] = ['DisplacementMagnitudePenalty']
        inverse_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object2.AddParameterMap(inverse_map)

        outdir2 = tempfile.mkdtemp()
        _, params2 = itk.elastix_registration_method(
            mv, mv,
            parameter_object=parameter_object2,
            initial_transform_parameter_file_name=transform_parameter_file,
            log_to_console=debug,
            log_to_file=True,
            output_directory=outdir2,
        )
        DisplacementMagnitudePenalty = _read_final_metric(
            os.path.join(outdir2, 'elastix.log'))

        # Detach the composed initial transform so params2 holds only the
        # inverse. Both spellings are set: elastix deprecated the plural
        # "InitialTransformParametersFileName" in favour of the singular one.
        last_parameter_map = params2.GetParameterMap(0)
        last_parameter_map['InitialTransformParametersFileName'] = ['NoInitialTransform']
        last_parameter_map['InitialTransformParameterFileName'] = ['NoInitialTransform']
        params2.SetParameterMap(0, last_parameter_map)

        mf = itk.transformix_filter(fi, params2)

        m1 = normalized_mutual_information(
            itk.GetArrayViewFromImage(fi), itk.GetArrayViewFromImage(fm))
        m2 = normalized_mutual_information(
            itk.GetArrayViewFromImage(mv), itk.GetArrayViewFromImage(mf))
        print(MaximumNumberOfIterations, m1, m2, DisplacementMagnitudePenalty)

        shutil.rmtree(outdir1)
        shutil.rmtree(outdir2)

        return {
            'fwdtransforms': params1,
            'invtransforms': params2,
            'warpedfixout': mf,
            'warpedmovout': fm,
            'metrics': max(m1, m2),
            'DisplacementMagnitudePenalty': DisplacementMagnitudePenalty,
        }

    def _register_until_converged(self, fi, mv, debug, iterations, max_iterations=None):
        """Retry ``register_aux`` with a doubling budget until the inverse converges.

        Returns ``(result, iterations)`` where ``iterations`` is the budget of
        the last attempt. Stops when the registration fails (result lacks
        'DisplacementMagnitudePenalty'), when the penalty is below
        ``DISPLACEMENT_PENALTY_THRESHOLD``, or when doubling would exceed
        ``max_iterations`` (``None`` means no cap).
        """
        while True:
            result = self.register_aux(
                fi, mv, debug, MaximumNumberOfIterations=[str(iterations)])
            if 'DisplacementMagnitudePenalty' not in result:
                return result, iterations
            if result['DisplacementMagnitudePenalty'] < DISPLACEMENT_PENALTY_THRESHOLD:
                return result, iterations
            if max_iterations is not None and iterations * 2 > max_iterations:
                return result, iterations
            iterations *= 2

    def __init__(self, fi, mv, warpedfixout=None, warpedmovout=None, debug=False,
                 iterations_init=NumberOfIterations, max_iterations=None):
        """Register image files ``fi`` and ``mv`` in both directions.

        Args:
            fi, mv: paths of the fixed and moving images (read as float).
            warpedfixout, warpedmovout: optional output paths for the warped
                fixed/moving images of the chosen result.
            debug: enable elastix console logging and dump intermediate
                images (0*.nii.gz) in the working directory.
            iterations_init: initial MaximumNumberOfIterations; doubled until
                the inverse converges.
            max_iterations: optional upper bound on the doubled budget. With
                the default ``None`` the retries are unbounded.

        If both directions fail, returns early and ``self.res`` is not set.

        Raises:
            AssertionError: if the chosen result's DisplacementMagnitudePenalty
                is not below the threshold (only possible with max_iterations).
        """
        self.debug = debug
        fixed_image = itk.imread(fi, itk.F)
        moving_image = itk.imread(mv, itk.F)

        r1, iterations = self._register_until_converged(
            fixed_image, moving_image, debug, iterations_init, max_iterations)
        # The reverse direction starts from the budget the forward one ended with.
        r2, _ = self._register_until_converged(
            moving_image, fixed_image, debug, iterations, max_iterations)

        if r1['metrics'] > r2['metrics']:
            self.res = r1
        else:
            if 'invtransforms' not in r2:
                # r2 failed and r1 did not score higher: nothing usable.
                return None
            # r2 registered fixed onto moving, so swap its roles.
            self.res = dict(r2)
            self.res.update({
                'fwdtransforms': r2['invtransforms'],
                'invtransforms': r2['fwdtransforms'],
                'warpedfixout': r2['warpedmovout'],
                'warpedmovout': r2['warpedfixout'],
            })

        # Explicit check (an `assert` statement would be stripped under -O).
        if not self.res['DisplacementMagnitudePenalty'] < DISPLACEMENT_PENALTY_THRESHOLD:
            raise AssertionError(
                'DisplacementMagnitudePenalty: %f ' % (self.res['DisplacementMagnitudePenalty']))

        if warpedfixout is not None:
            itk.imwrite(self.res['warpedfixout'], warpedfixout)
        if warpedmovout is not None:
            itk.imwrite(self.res['warpedmovout'], warpedmovout)

        if debug:
            pprint(self.res)
            itk.imwrite(fixed_image, '0fixed.nii.gz')
            itk.imwrite(moving_image, '0moving.nii.gz')
            # A failed direction only carries 'metrics'; skip its images.
            if 'warpedfixout' in r1:
                itk.imwrite(r1['warpedfixout'], '0mf1.nii.gz')
                itk.imwrite(r1['warpedmovout'], '0fm1.nii.gz')
            if 'warpedfixout' in r2:
                itk.imwrite(r2['warpedmovout'], '0mf2.nii.gz')
                itk.imwrite(r2['warpedfixout'], '0fm2.nii.gz')

    def get_metrics(self):
        """Return the NMI score ('metrics') of the chosen registration."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the moving image warped onto the fixed image to path ``out``."""
        itk.imwrite(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Warp the image file ``moving`` with the forward transform and save it.

        Only the first parameter map of 'fwdtransforms' is applied. When
        ``is_label`` is true, nearest-neighbour resampling is used and the
        result is cast to unsigned char (labels above 255 would not survive
        the cast).
        """
        parameter_map = self.res['fwdtransforms'].GetParameterMap(0)
        if is_label:
            parameter_map['ResampleInterpolator'] = ['FinalNearestNeighborInterpolator']
        transform_parameters = itk.ParameterObject.New()
        transform_parameters.AddParameterMap(parameter_map)

        image = itk.imread(moving)
        output = itk.transformix_filter(
            image.astype(itk.F),
            transform_parameters,
            log_to_console=self.debug,
        )
        if is_label:
            output = output.astype(itk.UC)
        itk.imwrite(output, output_filename)