from pprint import pprint
from time import time
import argparse
import os
import tempfile

import ants
from skimage.metrics import normalized_mutual_information
from fireants.io import Image, BatchedImages
from fireants.registration import RigidRegistration
import SimpleITK as sitk


def _invert_flags(transform_files):
    """Return ``whichtoinvert`` flags for applying an ANTs *inverse* transform list.

    ``ants.apply_transforms`` can invert matrix (``.mat``) entries on the fly,
    but warp-field entries must be passed with ``False`` (they are already the
    inverse field), so only ``.mat`` files are flagged.
    """
    return [str(path).endswith('.mat') for path in transform_files]


class fireants_reg:
    """Register two 3-D images with ANTsPy in both directions and keep the better run.

    Despite its name, this class drives ANTsPy (``ants``), not FireANTs.
    After construction, ``self.res`` holds the winning run's outputs expressed
    in the caller's fixed/moving roles, plus the loaded ``'fix'`` and ``'mov'``
    images and the moving->fixed transform chain used by :meth:`transform`.
    """

    def register_aux(self, fi, mv, type_of_transform='SyNRA'):
        """Run one ``ants.registration`` of *mv* onto *fi* and score it.

        Args:
            fi: fixed ANTs image.
            mv: moving ANTs image.
            type_of_transform: forwarded unchanged to ``ants.registration``.

        Returns:
            dict with keys:
              'fwdtransforms': transform object read from the first forward
                  transform file (kept for backward compatibility; for a
                  SyN-family type this is only the deformation field, not the
                  full mapping),
              'fwdtransform_files' / 'invtransform_files': the transform file
                  lists returned by ``ants.registration`` (in THIS call's roles),
              'warpedfixout' / 'warpedmovout': warped images from the call,
              'metrics': the larger of the two NMI scores.
        """
        mytx = ants.registration(
            fixed=fi,
            moving=mv,
            type_of_transform=type_of_transform,
        )
        # Backward-compatible field only. Calling .invert() on it fails for
        # displacement-field transforms (ants.lib.AntsTransformDF3), which is
        # why the moving->fixed chain is tracked via the file lists instead.
        fwdtransforms = ants.read_transform(mytx['fwdtransforms'][0])
        # Score each warped output against the image it was warped onto.
        m1 = normalized_mutual_information(fi.numpy(), mytx['warpedmovout'].numpy())
        m2 = normalized_mutual_information(mv.numpy(), mytx['warpedfixout'].numpy())
        print(m1, m2)
        return {
            'fwdtransforms': fwdtransforms,
            'fwdtransform_files': list(mytx['fwdtransforms']),
            'invtransform_files': list(mytx['invtransforms']),
            'warpedfixout': mytx['warpedfixout'],
            'warpedmovout': mytx['warpedmovout'],
            'metrics': max(m1, m2),
        }

    def __init__(self, fi, mv, debug=False, type_of_transform='SyNRA'):
        """Load two 3-D images, register them both ways, and keep the better result.

        Args:
            fi: path of the fixed image.
            mv: path of the moving image.
            debug: if true, pprint ``self.res`` and write the inputs and both
                runs' warped outputs as ``0*.nii.gz`` files in the current
                working directory.
            type_of_transform: ANTs transform type used for both runs.
        """
        fixed_image = ants.image_read(fi, dimension=3)
        moving_image = ants.image_read(mv, dimension=3)
        r1 = self.register_aux(fixed_image, moving_image, type_of_transform)
        r2 = self.register_aux(moving_image, fixed_image, type_of_transform)
        if r1['metrics'] > r2['metrics']:
            # r1 registered moving onto fixed: its forward files already map
            # moving -> fixed and are applied without inversion.
            self.res = dict(r1)
            self.res.update({
                'moving_to_fixed_files': r1['fwdtransform_files'],
                'moving_to_fixed_invert': [False] * len(r1['fwdtransform_files']),
            })
        else:
            # r2 registered fixed onto moving, so swap its warped outputs back
            # into the caller's roles; the moving -> fixed mapping is r2's
            # INVERSE transform list.
            # NOTE: 'fwdtransforms' here is still r2's own (fixed -> moving)
            # object, as before; transform() no longer relies on it.
            self.res = dict(r2)
            self.res.update({
                'invtransforms': r2['fwdtransforms'],
                'warpedfixout': r2['warpedmovout'],
                'warpedmovout': r2['warpedfixout'],
                'moving_to_fixed_files': r2['invtransform_files'],
                'moving_to_fixed_invert': _invert_flags(r2['invtransform_files']),
            })
        self.res.update({
            'fix': fixed_image,
            'mov': moving_image,
        })
        if debug:
            pprint(self.res)
            ants.image_write(fixed_image, '0fixed.nii.gz')
            ants.image_write(moving_image, '0moving.nii.gz')
            ants.image_write(r1['warpedfixout'], '0mf1.nii.gz')
            ants.image_write(r1['warpedmovout'], '0fm1.nii.gz')
            ants.image_write(r2['warpedmovout'], '0mf2.nii.gz')
            ants.image_write(r2['warpedfixout'], '0fm2.nii.gz')

    def get_metrics(self):
        """Return the NMI score of the winning registration run."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the moving image warped into fixed space to path *out*."""
        ants.image_write(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Resample the image at path *moving* onto the fixed image's grid and write it.

        Applies the complete moving -> fixed transform file chain chosen in
        ``__init__`` (every file ``ants.registration`` returned, not only the
        first), so no temporary transform file is needed.

        Args:
            moving: path of a 3-D image living in the moving image's space.
            output_filename: path to write the resampled image to.
            is_label: if true, use ``genericLabel`` interpolation and cast the
                result to uint8.
        """
        moving_img = ants.image_read(moving, dimension=3)
        # NOTE(review): these paths point at the files ants.registration wrote;
        # they must still exist when this runs — confirm nothing cleans them up.
        apply_kwargs = {
            'transformlist': self.res['moving_to_fixed_files'],
            'whichtoinvert': self.res['moving_to_fixed_invert'],
        }
        if is_label:
            transformed = ants.apply_transforms(
                self.res['fix'], moving_img, interpolator='genericLabel', **apply_kwargs
            ).astype('uint8')  # label values above 255 do not survive this cast
        else:
            transformed = ants.apply_transforms(self.res['fix'], moving_img, **apply_kwargs)
        ants.image_write(transformed, output_filename)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Rigid FireANTs registration of a moving volume onto a fixed volume.'
    )
    parser.add_argument(
        'fixed',
        nargs='?',
        default='/nn/7295866/20250127/nii/a_1.1_CyberKnife_head(MAR)_20250127111447_5.nii.gz',
        help='fixed image path',
    )
    parser.add_argument(
        'moving',
        nargs='?',
        default='/nn/7295866/20250127/nii/7_3D_SAG_T1_MPRAGE_+C_20250127132612_100.nii.gz',
        help='moving image path',
    )
    parser.add_argument('-o', '--output', default='tmp.nii.gz', help='where to write the moved image')
    args = parser.parse_args()
    fi = args.fixed
    mv = args.moving

    # load the images
    image1 = Image.load_file(fi)
    image2 = Image.load_file(mv)
    # batchify them (we only have a single image per batch, but we can pass multiple images)
    fixed_batch = BatchedImages([image1])
    moving_batch = BatchedImages([image2])

    # rigid registration: scales at which to register, with matching iteration counts
    # (alternative schedule: scales=[4, 2, 1], iterations=[200, 100, 50])
    scales = [4, 2]
    iterations = [200, 100]
    lr = 3e-4
    # other RigidRegistration options previously toggled here:
    # mi_kernel_type='gaussian', optimizer='Adam', cc_kernel_size=5
    rigid_reg = RigidRegistration(
        scales, iterations, fixed_batch, moving_batch,
        loss_type='mi',
        optimizer_lr=lr,
    )
    start = time()
    rigid_reg.optimize(save_transformed=False)
    end = time()
    print("Runtime", end - start, "seconds")

    moved = rigid_reg.evaluate(fixed_batch, moving_batch)
    reference_img = sitk.ReadImage(fi)
    # moved is indexed [batch, channel, ...]; take the first volume for NIfTI output
    moved_image_np = moved[0, 0].detach().cpu().numpy()
    moved_sitk_image = sitk.GetImageFromArray(moved_image_np)
    # presumably the result is on the fixed image's grid (fixed_batch is passed
    # first to evaluate) — hence copying the fixed image's geometry
    moved_sitk_image.SetOrigin(reference_img.GetOrigin())
    moved_sitk_image.SetSpacing(reference_img.GetSpacing())
    moved_sitk_image.SetDirection(reference_img.GetDirection())
    sitk.WriteImage(moved_sitk_image, args.output)