2023-08-08 22:04:06 +00:00
from pprint import pprint
import os
import re
import shutil
import tempfile
from skimage . metrics import normalized_mutual_information
import itk
# Default elastix MaximumNumberOfIterations for the rigid registrations below.
NumberOfIterations = 2100
# Captures the value from "Final metric value = <number>" lines in elastix.log.
# Whitespace around "=" is matched loosely so that extra spaces in the log
# line do not break the match.
metric_patern = r'Final metric value\s*=\s*(\S+)'
metric_prog = re.compile(metric_patern)
class elastix_reg:
    """Rigid elastix registration of two image files, with an inverse transform.

    Registration is run in both directions (fixed onto moving and moving onto
    fixed). For each direction the inverse transform is fitted with elastix's
    DisplacementMagnitudePenalty metric, and the iteration budget is doubled
    until that penalty drops below 1. The direction with the higher normalized
    mutual information is kept in ``self.res``, always expressed as
    fixed->moving keys.
    """

    def register_aux(self, fi, mv, debug=False, MaximumNumberOfIterations=None):
        """Rigidly register ``mv`` onto ``fi`` and estimate the inverse transform.

        Args:
            fi: fixed itk image.
            mv: moving itk image.
            debug: forwarded as ``log_to_console`` to the inverse elastix run.
            MaximumNumberOfIterations: elastix parameter value (a list holding
                one string). Defaults to ``[str(NumberOfIterations)]``.

        Returns:
            ``{'metrics': 0}`` if the forward registration raised. Otherwise a
            dict with:
              'fwdtransforms': forward transform parameter object,
              'invtransforms': inverse transform parameter object (its initial
                  transform replaced by "NoInitialTransform"),
              'warpedfixout': ``fi`` resampled with the inverse transform,
              'warpedmovout': result image of registering ``mv`` onto ``fi``,
              'metrics': max of the two normalized mutual information values,
              'DisplacementMagnitudePenalty': final metric of the inverse run.

        Raises:
            RuntimeError: if no "Final metric value" appears in the inverse
                run's elastix.log.
        """
        # None instead of a list default avoids a shared mutable default.
        if MaximumNumberOfIterations is None:
            MaximumNumberOfIterations = [str(NumberOfIterations)]

        parameter_object = itk.ParameterObject.New()
        default_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
        default_rigid_parameter_map["AutomaticTransformInitialization"] = ["true"]
        default_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object.AddParameterMap(default_rigid_parameter_map)

        outdir1 = tempfile.mkdtemp()
        try:
            fm, params1 = itk.elastix_registration_method(
                fi, mv,
                parameter_object=parameter_object,
                log_to_file=True,
                output_directory=outdir1,
            )
        except Exception as ex:
            # Best effort: report the failure and signal it with metrics 0.
            # outdir1 is deliberately kept so the printed log can be inspected.
            print(ex)
            print(os.path.join(outdir1, 'elastix.log'))
            return {
                'metrics': 0,
            }
        TransformParameterFileName = os.path.join(outdir1, 'TransformParameters.0.txt')

        # Invert the forward transform with the DisplacementMagnitudePenalty
        # metric (elastix manual; Metz et al. [2011]). That metric penalises
        # ||T_mu(x) - x||^2. The transform to invert is given as the initial
        # transform and combined with HowToCombineTransforms "Compose"; after
        # the run the initial transform in the resulting parameter map is set
        # to "NoInitialTransform", leaving only the inverse. The manual uses
        # the original fixed image as both fixed and moving image and then
        # adjusts Size/Spacing/Origin/Index/Direction to match the moving
        # image; here the moving image is used for both instead, so the result
        # is presumably already defined on the moving image's grid.
        parameter_object2 = itk.ParameterObject.New()
        inverse_rigid_parameter_map = parameter_object.GetDefaultParameterMap('rigid')
        inverse_rigid_parameter_map["HowToCombineTransforms"] = ["Compose"]
        inverse_rigid_parameter_map["Metric"] = ["DisplacementMagnitudePenalty"]
        inverse_rigid_parameter_map['MaximumNumberOfIterations'] = MaximumNumberOfIterations
        parameter_object2.AddParameterMap(inverse_rigid_parameter_map)

        # On any failure from here on, the temporary directories are left in
        # place (as for the forward run) so the logs remain inspectable.
        outdir2 = tempfile.mkdtemp()
        _, params2 = itk.elastix_registration_method(
            mv, mv,
            parameter_object=parameter_object2,
            initial_transform_parameter_file_name=TransformParameterFileName,
            log_to_console=debug,
            log_to_file=True,
            output_directory=outdir2,
        )
        elastix_log = os.path.join(outdir2, 'elastix.log')
        with open(elastix_log) as log:
            metric_values = metric_prog.findall(log.read())
        if not metric_values:
            raise RuntimeError('No "Final metric value" found in %s' % elastix_log)
        # Take the last occurrence in case the value is reported more than
        # once (e.g. per resolution level); the last one is the final result.
        DisplacementMagnitudePenalty = float(metric_values[-1])

        last_parameter_map = params2.GetParameterMap(0)
        # Drop the initial (forward) transform so params2 holds the inverse
        # alone. Both spellings are set because elastix deprecated
        # "InitialTransformParametersFileName" in favour of
        # "InitialTransformParameterFileName" (without the 's').
        last_parameter_map["InitialTransformParametersFileName"] = ["NoInitialTransform"]
        last_parameter_map["InitialTransformParameterFileName"] = ["NoInitialTransform"]
        params2.SetParameterMap(0, last_parameter_map)
        mf = itk.transformix_filter(
            fi,
            params2)

        m1 = normalized_mutual_information(itk.GetArrayViewFromImage(fi), itk.GetArrayViewFromImage(fm))
        m2 = normalized_mutual_information(itk.GetArrayViewFromImage(mv), itk.GetArrayViewFromImage(mf))
        # Progress report: iteration budget, both NMI values, inverse penalty.
        print(MaximumNumberOfIterations, m1, m2, DisplacementMagnitudePenalty)
        shutil.rmtree(outdir1)
        shutil.rmtree(outdir2)
        return {
            'fwdtransforms': params1,
            'invtransforms': params2,
            'warpedfixout': mf,
            'warpedmovout': fm,
            'metrics': max(m1, m2),
            'DisplacementMagnitudePenalty': DisplacementMagnitudePenalty,
        }

    def _register_until_invertible(self, fi, mv, iterations, debug, iterations_max=None):
        """Call register_aux, doubling the iteration budget until the inverse fits.

        Stops when the registration failed (no 'DisplacementMagnitudePenalty'
        key), when the penalty is below 1, or when doubling would exceed
        ``iterations_max`` (if not None). Returns ``(result, iterations)``
        where ``iterations`` is the budget used for ``result``.
        """
        while True:
            res = self.register_aux(fi, mv, debug, MaximumNumberOfIterations=[str(iterations)])
            penalty = res.get('DisplacementMagnitudePenalty')
            if penalty is None or penalty < 1:
                return res, iterations
            if iterations_max is not None and iterations * 2 > iterations_max:
                return res, iterations
            iterations *= 2

    @staticmethod
    def _selection_key(res):
        """Ranking key for a register_aux result: converged inverse first, then metrics."""
        return (res.get('DisplacementMagnitudePenalty', float('inf')) < 1, res['metrics'])

    def __init__(self, fi, mv, warpedfixout=None, warpedmovout=None, debug=False,
                 iterations_init=NumberOfIterations, iterations_max=None):
        """Register image files ``fi`` and ``mv`` and keep the better direction.

        Args:
            fi: path of the fixed image (read as float).
            mv: path of the moving image (read as float).
            warpedfixout: optional output path for the warped fixed image.
            warpedmovout: optional output path for the warped moving image.
            debug: print the result and write intermediate images
                ('0fixed.nii.gz', '0moving.nii.gz', '0mf1.nii.gz', ...) to the
                current directory.
            iterations_init: starting MaximumNumberOfIterations.
            iterations_max: optional upper bound on the doubled iteration
                count. None (default) keeps doubling until the inverse penalty
                drops below 1.

        Raises:
            AssertionError: if the selected result's DisplacementMagnitudePenalty
                is not below 1 (only possible when ``iterations_max`` is set).

        If neither direction could be registered, ``self.res`` is set to
        ``{'metrics': 0}`` and nothing is written.
        """
        self.debug = debug
        fixed_image = itk.imread(fi, itk.F)
        moving_image = itk.imread(mv, itk.F)

        # The reverse direction starts from the iteration budget the forward
        # direction ended with.
        r1, iterations = self._register_until_invertible(
            fixed_image, moving_image, iterations_init, debug, iterations_max)
        r2, _ = self._register_until_invertible(
            moving_image, fixed_image, iterations, debug, iterations_max)

        # Ties go to r2, matching a plain "r1['metrics'] > r2['metrics']" choice.
        if self._selection_key(r1) > self._selection_key(r2):
            self.res = r1
        elif 'invtransforms' in r2:
            # r2 registered the fixed image onto the moving one: swap its
            # transforms and warped images so the keys read fixed->moving.
            self.res = dict(r2)
            self.res.update({
                'fwdtransforms': r2['invtransforms'],
                'invtransforms': r2['fwdtransforms'],
                'warpedfixout': r2['warpedmovout'],
                'warpedmovout': r2['warpedfixout'],
            })
        else:
            # Neither direction produced a registration; keep a usable failure
            # result so get_metrics() reports 0 instead of raising.
            self.res = dict(r2)
            return

        penalty = self.res['DisplacementMagnitudePenalty']
        if not penalty < 1:
            # Explicit raise instead of assert so the check survives python -O.
            raise AssertionError('DisplacementMagnitudePenalty: %f' % (penalty))

        if warpedfixout is not None:
            itk.imwrite(self.res['warpedfixout'], warpedfixout)
        if warpedmovout is not None:
            itk.imwrite(self.res['warpedmovout'], warpedmovout)
        if debug:
            pprint(self.res)
            itk.imwrite(fixed_image, '0fixed.nii.gz')
            itk.imwrite(moving_image, '0moving.nii.gz')
            # A direction whose registration failed has no images to write.
            if 'warpedfixout' in r1:
                itk.imwrite(r1['warpedfixout'], '0mf1.nii.gz')
                itk.imwrite(r1['warpedmovout'], '0fm1.nii.gz')
            if 'warpedfixout' in r2:
                itk.imwrite(r2['warpedmovout'], '0mf2.nii.gz')
                itk.imwrite(r2['warpedfixout'], '0fm2.nii.gz')

    def get_metrics(self):
        """Return the 'metrics' value of the selected result (0 on failure)."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the selected result's warped moving image to path ``out``."""
        itk.imwrite(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Warp an image file with the selected forward transform and write it.

        Args:
            moving: path of the image to warp (cast to float for transformix).
            output_filename: path the warped image is written to.
            is_label: resample with nearest-neighbour interpolation and write
                the result as unsigned char, for label maps.
        """
        transform1 = self.res['fwdtransforms']
        mv = itk.imread(moving)
        last_parameter_map = transform1.GetParameterMap(0)
        if is_label:
            # Nearest neighbour keeps label values intact.
            last_parameter_map["ResampleInterpolator"] = ["FinalNearestNeighborInterpolator"]
        t2 = itk.ParameterObject.New()
        t2.AddParameterMap(last_parameter_map)
        output = itk.transformix_filter(
            mv.astype(itk.F),
            t2,
            log_to_console=self.debug,
        )
        if is_label:
            output = output.astype(itk.UC)
        itk.imwrite(output, output_filename)