ck-preprocess/zz/fireants-test.py

79 lines
2.6 KiB
Python
Raw Normal View History

2025-02-08 00:39:18 +00:00
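'''
Environment setup:
    conda deactivate
    conda create -y -n fireants -c conda-forge 'numpy<2' simpleitk=2.2.1
    conda activate fireants
    pip install fireants

Known issue: running this on the CT/MR pair below has failed with
"CUDA out of memory" (cause not yet investigated).
'''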
from time import time

import matplotlib.pyplot as plt
import SimpleITK as sitk
import torch
from fireants.io import Image, BatchedImages
from fireants.registration import AffineRegistration, GreedyRegistration
# --- Inputs / outputs ---------------------------------------------------------
# Fixed image is the CT volume, moving image is the MR volume (per the
# directory names). For the FireANTs atlas demo, use
# "atlas_2mm_1000_3.nii.gz" (fixed) and "atlas_2mm_1001_3.nii.gz" (moving).
FIXED_PATH = "/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/CT/1.1_CyberKnife_head(MAR)_20230728111920_3.nii.gz"
MOVING_PATH = "/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/MR/3D_SAG_T1_MPRAGE_+C_MPR_Tra_20230728143005_14.nii.gz"
# NOTE(review): output name is left over from the atlas example; the file
# actually holds the warped moving (MR) image.
OUTPUT_PATH = 'reslice_deform_atlas_2mm_1000_3.nii.gz'

# Load the images.
image1 = Image.load_file(FIXED_PATH)
image2 = Image.load_file(MOVING_PATH)
# Batchify them (a single image per batch here, but several can be passed).
batch1 = BatchedImages([image1])
batch2 = BatchedImages([image2])
# Report which device the image tensors were placed on.
print(batch1().device)

# --- Stage 1: affine registration ----------------------------------------------
scales = [4, 2, 1]  # scales at which to perform registration
iterations = [200, 100, 50]  # one entry per scale
optim = 'Adam'
lr = 3e-3
# Mutual information, since fixed and moving are different modalities (CT/MR).
# NOTE(review): cc_kernel_size presumably only affects loss_type='cc'; it is
# kept from the original example.
affine = AffineRegistration(scales, iterations, batch1, batch2, optimizer=optim, optimizer_lr=lr,
                            loss_type='mi',
                            cc_kernel_size=5)
start = time()
# save_transformed=False: the per-scale transformed images were never used,
# and keeping them alive only adds GPU memory pressure (see the OOM note).
affine.optimize(save_transformed=False)
end = time()
print("Runtime", end - start, "seconds")

# --- Stage 2: greedy deformable registration, initialised from the affine ------
reg = GreedyRegistration(scales=[4, 2, 1], iterations=[200, 100, 25],
                         fixed_images=batch1, moving_images=batch2,
                         cc_kernel_size=5, deformation_type='compositive',
                         smooth_grad_sigma=1,
                         loss_type='mi',
                         optimizer='adam', optimizer_lr=0.5,
                         init_affine=affine.get_affine_matrix().detach())
start = time()
reg.optimize(save_transformed=False)
end = time()
print("Runtime", end - start, "seconds")

# --- Warp the moving image and write it out -------------------------------------
# This is inference only. The result can carry autograd history (it is
# detached below), so computing it under no_grad avoids recording a graph
# for a full-resolution volume.
with torch.no_grad():
    moved = reg.evaluate(batch1, batch2)

# Geometry is taken from the fixed CT image.
# Assumes evaluate() resamples onto the fixed image's grid; confirm.
reference_img = sitk.ReadImage(FIXED_PATH)
# Drop the leading [batch, channel] axes to get the single spatial volume.
# Presumably the spatial axes follow SimpleITK's array order (z, y, x); confirm.
moved_image_np = moved[0, 0].detach().cpu().numpy()
moved_sitk_image = sitk.GetImageFromArray(moved_image_np)
# CopyInformation copies origin, spacing and direction. It raises if the
# array's size does not match the reference image, instead of silently
# writing mismatched metadata.
moved_sitk_image.CopyInformation(reference_img)
sitk.WriteImage(moved_sitk_image, OUTPUT_PATH)