import os
import tempfile
from pprint import pprint

import ants
from skimage.metrics import normalized_mutual_information


class ants_reg:
    """Register two 3D images with ANTs in both directions and keep the better run.

    The constructor reads both images, runs ``register_aux`` once with the
    images in the given order and once with them swapped, and stores the
    result with the higher metric in ``self.res``.

    ``self.res`` is a dict with the keys ``'fwdtransforms'``,
    ``'warpedfixout'``, ``'warpedmovout'``, ``'metrics'``, ``'fix'`` and
    ``'mov'``; it additionally has ``'invtransforms'`` when the swapped
    registration was selected.
    """

    def register_aux(self, fi, mv):
        """Run one ``SyNRA`` registration of ``mv`` onto ``fi``.

        Args:
            fi: ANTs image used as the fixed image.
            mv: ANTs image used as the moving image.

        Returns:
            dict with
            ``'fwdtransforms'``: transform object read from the first entry
            of the registration's forward transform list,
            ``'warpedfixout'`` / ``'warpedmovout'``: the warped images
            returned by ``ants.registration``,
            ``'metrics'``: the larger of the two normalized mutual
            information scores (fixed vs. warped moving, moving vs. warped
            fixed).
        """
        # 'Rigid' was tried as an alternative type_of_transform.
        mytx = ants.registration(
            fixed=fi,
            moving=mv,
            type_of_transform='SyNRA',
        )

        # NOTE(review): only the first entry of the forward transform list is
        # loaded; any further entries are ignored -- confirm this is intended.
        fwdtransforms = ants.read_transform(mytx['fwdtransforms'][0])

        m1 = normalized_mutual_information(
            fi.numpy(), mytx['warpedmovout'].numpy())
        m2 = normalized_mutual_information(
            mv.numpy(), mytx['warpedfixout'].numpy())
        print(m1, m2)

        # No 'invtransforms' entry is produced here: calling
        # ``fwdtransforms.invert()`` failed with "inverseTransform():
        # incompatible function arguments" because the object is an
        # ``ants.lib.AntsTransformDF3`` and inverseTransform only accepts the
        # AntsTransform{F,D}{22,33,44} types.
        return {
            'fwdtransforms': fwdtransforms,
            'warpedfixout': mytx['warpedfixout'],
            'warpedmovout': mytx['warpedmovout'],
            'metrics': max(m1, m2),
        }

    def __init__(self, fi, mv, debug=False):
        """Read both images and register them in both directions.

        Args:
            fi: Fixed image file, passed to ``ants.image_read``.
            mv: Moving image file, passed to ``ants.image_read``.
            debug: When true, pretty-print the selected result and write the
                input and warped images to ``0*.nii.gz`` files in the current
                working directory.
        """
        fixed_image = ants.image_read(fi, dimension=3)
        moving_image = ants.image_read(mv, dimension=3)

        r1 = self.register_aux(fixed_image, moving_image)
        r2 = self.register_aux(moving_image, fixed_image)

        if r1['metrics'] > r2['metrics']:
            # Copy so that self.res never aliases the register_aux result.
            self.res = dict(r1)
        else:
            # The swapped run won: its warped outputs are exchanged so that
            # the keys keep their meaning relative to (fi, mv).
            # NOTE(review): 'fwdtransforms' is NOT replaced here (the inverse
            # could not be computed, see register_aux), so it is still the
            # forward transform of the swapped run, and transform() will
            # apply that one -- confirm this is acceptable.
            self.res = dict(r2)
            self.res.update({
                'invtransforms': r2['fwdtransforms'],
                'warpedfixout': r2['warpedmovout'],
                'warpedmovout': r2['warpedfixout'],
            })

        self.res.update({
            'fix': fixed_image,
            'mov': moving_image,
        })

        if debug:
            pprint(self.res)
            ants.image_write(fixed_image, '0fixed.nii.gz')
            ants.image_write(moving_image, '0moving.nii.gz')
            ants.image_write(r1['warpedfixout'], '0mf1.nii.gz')
            ants.image_write(r1['warpedmovout'], '0fm1.nii.gz')
            ants.image_write(r2['warpedmovout'], '0mf2.nii.gz')
            ants.image_write(r2['warpedfixout'], '0fm2.nii.gz')

    def get_metrics(self):
        """Return the metric of the selected registration."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the selected run's warped moving image to ``out``."""
        ants.image_write(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Apply the stored forward transform to another image file.

        Args:
            moving: Image file to transform, read with ``ants.image_read``.
            output_filename: Destination file for the transformed image.
            is_label: When true, resample with the ``'genericLabel'``
                interpolator and cast the result to ``uint8``.
        """
        # apply_transforms is given a file name, so the in-memory transform is
        # written to a scratch file first.  A TemporaryDirectory guarantees
        # the file is removed even if reading, resampling or writing raises,
        # and avoids creating files in the current working directory.
        with tempfile.TemporaryDirectory() as scratch_dir:
            transform_path = os.path.join(scratch_dir, 'fwdtransform.mat')
            ants.write_transform(self.res['fwdtransforms'], transform_path)

            mi = ants.image_read(moving, dimension=3)
            if is_label:
                transformed = ants.apply_transforms(
                    self.res['fix'],
                    mi,
                    transformlist=[transform_path],
                    interpolator='genericLabel',
                ).astype('uint8')
            else:
                transformed = ants.apply_transforms(
                    self.res['fix'],
                    mi,
                    transformlist=[transform_path],
                )

            ants.image_write(transformed, output_filename)