"""Register an MR volume onto a CT volume with FireANTs (affine, then greedy).

The moving image (MR) is first aligned to the fixed image (CT) with an affine
registration using a mutual-information loss. The resulting affine matrix
initialises a greedy deformable registration (also mutual information). The
warped moving image is then written to NIfTI with the fixed image's origin,
spacing and direction.

Environment setup used for this script:

    conda deactivate
    conda create -y -n fireants -c conda-forge 'numpy<2' simpleitk=2.2.1
    conda activate fireants
    pip install fireants

NOTE: "CUDA out of memory" was observed with these inputs. This script avoids
keeping tensors it never uses: it does not save the affine stage's transformed
images, it frees the affine stage before the deformable one, and it runs the
final warp without autograd.
TODO(review): if OOM persists, coarser scales or fewer/lighter stages may be
needed (untested).
"""
from time import time

import matplotlib.pyplot as plt  # unused below; kept from the original script
import SimpleITK as sitk
import torch
from fireants.io import Image, BatchedImages
from fireants.registration import AffineRegistration, GreedyRegistration

# Input volumes. The fixed image also supplies the output geometry, so the same
# constant is used for registration and as the reference when writing.
# Earlier example inputs:
#   FIXED_PATH = "atlas_2mm_1000_3.nii.gz"
#   MOVING_PATH = "atlas_2mm_1001_3.nii.gz"
FIXED_PATH = "/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/CT/1.1_CyberKnife_head(MAR)_20230728111920_3.nii.gz"
MOVING_PATH = "/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/MR/3D_SAG_T1_MPRAGE_+C_MPR_Tra_20230728143005_14.nii.gz"
OUTPUT_PATH = 'reslice_deform_atlas_2mm_1000_3.nii.gz'

# load the images
image1 = Image.load_file(FIXED_PATH)   # fixed (CT)
image2 = Image.load_file(MOVING_PATH)  # moving (MR)

# batchify them (we only have a single image per batch, but we can pass multiple images)
batch1 = BatchedImages([image1])
batch2 = BatchedImages([image2])

# check device name
print(batch1().device)

# ---- Affine stage -----------------------------------------------------------
scales = [4, 2, 1]  # scales at which to perform registration
iterations = [200, 100, 50]  # iterations at each scale
optim = 'Adam'
lr = 3e-3

affine = AffineRegistration(scales, iterations, batch1, batch2,
                            optimizer=optim, optimizer_lr=lr,
                            loss_type='mi', cc_kernel_size=5)

start = time()
# The per-scale transformed images were never used; not saving them keeps them
# from holding memory for the rest of the run.
affine.optimize(save_transformed=False)
end = time()
print("Runtime", end - start, "seconds")

# Keep only the matrix needed to initialise the deformable stage, then drop the
# affine registration object and return cached GPU blocks to the allocator.
init_affine = affine.get_affine_matrix().detach()
del affine
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# ---- Greedy deformable stage ------------------------------------------------
reg = GreedyRegistration(scales=[4, 2, 1], iterations=[200, 100, 25],
                         fixed_images=batch1, moving_images=batch2,
                         cc_kernel_size=5, deformation_type='compositive',
                         smooth_grad_sigma=1,
                         loss_type='mi',
                         optimizer='adam', optimizer_lr=0.5,
                         init_affine=init_affine)
start = time()
reg.optimize(save_transformed=False)
end = time()
print("Runtime", end - start, "seconds")

# Warp the moving image with the optimised transform. No gradients are needed
# for this final resampling, so skip building the autograd graph.
with torch.no_grad():
    moved = reg.evaluate(batch1, batch2)

# ---- Write the result on the fixed image's grid -----------------------------
reference_img = sitk.ReadImage(FIXED_PATH)

# Take batch item 0, channel 0. Assumes the remaining axes are
# (depth, height, width), i.e. the order sitk.GetImageFromArray expects.
# TODO(review): confirm against fireants' tensor layout.
moved_image_np = moved[0, 0].detach().cpu().numpy()
moved_sitk_image = sitk.GetImageFromArray(moved_image_np)
# Copy origin, spacing and direction from the fixed image. CopyInformation also
# requires the sizes to match, so an axis-order mistake raises here instead of
# silently producing a mis-registered file.
moved_sitk_image.CopyInformation(reference_img)
sitk.WriteImage(moved_sitk_image, OUTPUT_PATH)