# Listing metadata from the source viewer: 79 lines, 2.6 KiB, Python.
'''
Environment setup:

    conda deactivate
    conda create -y -n fireants -c conda-forge 'numpy<2' simpleitk=2.2.1
    conda activate fireants
    pip install fireants

Open issue: running this script has failed with "CUDA out of memory."
The cause has not been determined yet.
'''
from time import perf_counter, time

import matplotlib.pyplot as plt
import SimpleITK as sitk
import torch
from fireants.io import BatchedImages, Image
from fireants.registration import AffineRegistration, GreedyRegistration

# --- Inputs -----------------------------------------------------------------
# image1 (CT) is passed as the fixed image and image2 (MR) as the moving image
# in the registrations below.
# Example-data alternatives:
#   image1 = Image.load_file("atlas_2mm_1000_3.nii.gz")
#   image2 = Image.load_file("atlas_2mm_1001_3.nii.gz")
fixed_image_path = "/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/CT/1.1_CyberKnife_head(MAR)_20230728111920_3.nii.gz"
moving_image_path = "/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/MR/3D_SAG_T1_MPRAGE_+C_MPR_Tra_20230728143005_14.nii.gz"

image1 = Image.load_file(fixed_image_path)
image2 = Image.load_file(moving_image_path)

# Wrap each volume in a batch of one (BatchedImages accepts several images).
batch1, batch2 = BatchedImages([image1]), BatchedImages([image2])

# Sanity check: report which device batch1's data lives on.
print(batch1().device)

# --- Affine pre-alignment ---------------------------------------------------
# Multi-resolution schedule: one entry per level, coarse to fine.
scales = [4, 2, 1]  # scales at which to perform registration
iterations = [200, 100, 50]  # optimizer steps at each scale
optim = 'Adam'
lr = 3e-3

# loss_type='mi' (mutual information) is used for this cross-modality CT/MR
# pair. NOTE(review): cc_kernel_size presumably only affects the 'cc' loss;
# it is kept to match the original call.
affine = AffineRegistration(
    scales, iterations, batch1, batch2,
    optimizer=optim, optimizer_lr=lr,
    loss_type='mi',
    cc_kernel_size=5,
)

# Run the registration. The per-scale transformed images used to be requested
# (save_transformed=True) but were never used. Not requesting them avoids
# holding extra volumes in memory on the compute device, which matters because
# this script has run out of CUDA memory.
# perf_counter() is a monotonic clock intended for measuring elapsed time,
# unlike time(), which can jump when the system clock is adjusted.
start = perf_counter()
affine.optimize(save_transformed=False)
end = perf_counter()

print("Runtime", end - start, "seconds")

# --- Greedy deformable registration, initialised from the affine result ------
reg = GreedyRegistration(
    scales=[4, 2, 1], iterations=[200, 100, 25],
    fixed_images=batch1, moving_images=batch2,
    cc_kernel_size=5, deformation_type='compositive',
    smooth_grad_sigma=1,
    loss_type='mi',
    optimizer='adam', optimizer_lr=0.5,
    # detach(): pass the affine estimate as a plain tensor, cut off from the
    # affine optimizer's autograd graph.
    init_affine=affine.get_affine_matrix().detach(),
)

# Time the optimization with a monotonic clock meant for elapsed-time measurement.
start = perf_counter()
reg.optimize(save_transformed=False)
end = perf_counter()

print("Runtime", end - start, "seconds")

# --- Resample the moving image with the final transform and save it ----------
# Evaluation only: no gradients are needed, so don't build an autograd graph
# (saves memory; the result is detached again below either way).
with torch.no_grad():
    moved = reg.evaluate(batch1, batch2)

# The fixed image (CT) supplies the physical-space metadata for the output.
# reference_img = sitk.ReadImage("atlas_2mm_1000_3.nii.gz")
reference_img = sitk.ReadImage("/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/CT/1.1_CyberKnife_head(MAR)_20230728111920_3.nii.gz")

# Volumes are typically stored in tensors with dimensions
# [Batch, Channels, Depth, Height, Width], so take the single 3-D volume
# and convert it to a NumPy array for SimpleITK.
moved_image_np = moved[0, 0].detach().cpu().numpy()
moved_sitk_image = sitk.GetImageFromArray(moved_image_np)

# CopyInformation sets origin, spacing and direction from the reference, and
# raises if the two image sizes differ. Without that check, a volume whose
# grid does not match its header would be written silently.
moved_sitk_image.CopyInformation(reference_img)

sitk.WriteImage(moved_sitk_image, 'reslice_deform_atlas_2mm_1000_3.nii.gz')