# Python source · 157 lines · 5.6 KiB
import os
import tempfile
from pprint import pprint
from time import perf_counter, time

import SimpleITK as sitk
from fireants.io import Image, BatchedImages
from fireants.registration import RigidRegistration
from skimage.metrics import normalized_mutual_information

# NOTE(review): class fireants_reg below calls ANTsPy (`ants.*`) throughout,
# but `import ants` is missing from this file; add it for that class to run.

class fireants_reg:
    """Register a moving 3-D volume onto a fixed one with ANTsPy SyNRA, trying both directions.

    Despite its name, this class does not use FireANTs: every registration,
    resampling and I/O call goes through ANTsPy (``ants.*``).

    NOTE(review): ``ants`` is referenced by every method below but is not
    imported anywhere in this file; ``import ants`` must be added at module
    level before this class can run.

    The pair is registered twice (moving onto fixed, and fixed onto moving).
    The run with the higher normalized-mutual-information score is kept, and
    its outputs are re-oriented so that in ``self.res`` "forward" always means
    moving -> fixed:

    - ``'fwdtransformfiles'`` / ``'invtransformfiles'``: ANTs transform file
      lists mapping moving -> fixed / fixed -> moving, in the order
      ``ants.apply_transforms`` expects.
    - ``'warpedmovout'``: moving volume resampled onto the fixed grid.
    - ``'warpedfixout'``: fixed volume resampled onto the moving grid.
    - ``'metrics'``: score of the winning run.
    - ``'fix'`` / ``'mov'``: the loaded fixed / moving images.
    - ``'fwdtransforms'`` (and ``'invtransforms'`` when the reverse run wins):
      legacy single-transform objects, kept only for backward compatibility;
      see ``register_aux``.
    """

    def register_aux(self, fi, mv):
        """Register ``mv`` onto ``fi`` once with ANTsPy's ``'SyNRA'`` transform.

        Args:
            fi: fixed image (``ANTsImage``).
            mv: moving image (``ANTsImage``).

        Returns:
            dict with keys
            ``'fwdtransforms'`` (transform object read from the first file of
            ANTs' forward list; legacy, see below),
            ``'fwdtransformfiles'`` (full forward file list, ``mv`` -> ``fi``),
            ``'invtransformfiles'`` (full inverse file list, ``fi`` -> ``mv``),
            ``'warpedfixout'`` / ``'warpedmovout'`` (ANTs' resampled outputs) and
            ``'metrics'`` (larger of the two normalized mutual informations).
        """
        mytx = ants.registration(
            fixed=fi,
            moving=mv,
            type_of_transform='SyNRA',  # rigid + affine + SyN deformable stages
        )

        # Legacy field. For SyNRA the first forward file is a displacement
        # field (ants.lib.AntsTransformDF3 for 3-D data): AntsTransform
        # inversion rejects it, and it is only one element of the forward
        # chain. Resample with 'fwdtransformfiles' instead.
        fwdtransforms = ants.read_transform(mytx['fwdtransforms'][0])

        # Each score compares two images on the same grid: the fixed image vs.
        # the moving image warped into fixed space, and the moving image vs.
        # the fixed image warped into moving space.
        m1 = normalized_mutual_information(fi.numpy(), mytx['warpedmovout'].numpy())
        m2 = normalized_mutual_information(mv.numpy(), mytx['warpedfixout'].numpy())
        print(m1, m2)

        return {
            'fwdtransforms': fwdtransforms,
            'fwdtransformfiles': list(mytx['fwdtransforms']),
            'invtransformfiles': list(mytx['invtransforms']),
            'warpedfixout': mytx['warpedfixout'],
            'warpedmovout': mytx['warpedmovout'],
            'metrics': max(m1, m2),
        }

    @staticmethod
    def _orient_result(r1, r2):
        """Return (a copy of) the better-scoring run, oriented moving -> fixed.

        ``r1`` comes from ``register_aux(fixed, moving)`` and is already
        oriented moving -> fixed. ``r2`` comes from
        ``register_aux(moving, fixed)``, so its forward/inverse file lists and
        its two warped outputs are swapped. A tie selects ``r2`` (strict ``>``).
        """
        if r1['metrics'] > r2['metrics']:
            return dict(r1)
        res = dict(r2)
        res.update({
            # r2's legacy forward object maps fixed -> moving, i.e. it is the
            # inverse direction here. It cannot be inverted in place (see
            # register_aux), so 'fwdtransforms' still holds r2's object;
            # use 'fwdtransformfiles' for moving -> fixed resampling.
            'invtransforms': r2['fwdtransforms'],
            'fwdtransformfiles': r2['invtransformfiles'],
            'invtransformfiles': r2['fwdtransformfiles'],
            'warpedfixout': r2['warpedmovout'],
            'warpedmovout': r2['warpedfixout'],
        })
        return res

    def __init__(self, fi, mv, debug=False):
        """Load both volumes, register them in both directions and keep the better run.

        Args:
            fi: path to the fixed image (read as 3-D).
            mv: path to the moving image (read as 3-D).
            debug: if True, pretty-print ``self.res`` and write the inputs and
                both runs' warped outputs as ``0*.nii.gz`` files in the current
                working directory.
        """
        fixed_image = ants.image_read(fi, dimension=3)
        moving_image = ants.image_read(mv, dimension=3)

        r1 = self.register_aux(fixed_image, moving_image)  # moving onto fixed
        r2 = self.register_aux(moving_image, fixed_image)  # fixed onto moving

        self.res = self._orient_result(r1, r2)
        self.res.update({
            'fix': fixed_image,
            'mov': moving_image,
        })

        if debug:
            pprint(self.res)
            ants.image_write(fixed_image, '0fixed.nii.gz')
            ants.image_write(moving_image, '0moving.nii.gz')
            ants.image_write(r1['warpedfixout'], '0mf1.nii.gz')
            ants.image_write(r1['warpedmovout'], '0fm1.nii.gz')
            ants.image_write(r2['warpedmovout'], '0mf2.nii.gz')
            ants.image_write(r2['warpedfixout'], '0fm2.nii.gz')

    def get_metrics(self):
        """Return the similarity score of the selected registration run."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the moving volume resampled onto the fixed grid to path ``out``."""
        ants.image_write(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Resample an image from the moving volume's space onto the fixed grid and write it.

        Args:
            moving: path to a 3-D image, expected to lie in the moving
                volume's space.
            output_filename: destination path for the resampled image.
            is_label: if True, resample with ANTs' ``'genericLabel'``
                interpolator and cast to ``uint8`` (for label maps); otherwise
                ``ants.apply_transforms``' default interpolator is used.
        """
        mi = ants.image_read(moving, dimension=3)
        # Pass the transform files written by ants.registration directly: this
        # applies the complete moving -> fixed chain and needs no temporary file.
        if is_label:
            transformed = ants.apply_transforms(
                self.res['fix'], mi,
                transformlist=self.res['fwdtransformfiles'],
                interpolator='genericLabel',
            ).astype('uint8')
        else:
            transformed = ants.apply_transforms(
                self.res['fix'], mi,
                transformlist=self.res['fwdtransformfiles'],
            )
        ants.image_write(transformed, output_filename)

def main(
    fi='/nn/7295866/20250127/nii/a_1.1_CyberKnife_head(MAR)_20250127111447_5.nii.gz',
    mv='/nn/7295866/20250127/nii/7_3D_SAG_T1_MPRAGE_+C_20250127132612_100.nii.gz',
    out='tmp.nii.gz',
):
    """Rigidly register ``mv`` onto ``fi`` with FireANTs and write the moved image.

    Args:
        fi: path to the fixed image (default: the development case used so far).
        mv: path to the moving image (default: the development case used so far).
        out: output path; the moved image is written with ``fi``'s origin,
            spacing and direction.
    """
    # Load the images and batch them (one image per batch here, although
    # BatchedImages accepts a list of several).
    fixed_batch = BatchedImages([Image.load_file(fi)])
    moving_batch = BatchedImages([Image.load_file(mv)])

    # Two-scale rigid registration with a mutual-information loss.
    # NOTE: a finer schedule (scales [4, 2, 1], iterations [200, 100, 50]) used
    # to be assigned and then immediately overwritten, so it never ran.
    scales = [4, 2]
    iterations = [200, 100]
    rigid_reg = RigidRegistration(
        scales, iterations, fixed_batch, moving_batch,
        loss_type='mi',
    )

    start = perf_counter()  # monotonic clock, suitable for measuring durations
    rigid_reg.optimize(save_transformed=False)
    print("Runtime", perf_counter() - start, "seconds")

    moved = rigid_reg.evaluate(fixed_batch, moving_batch)

    # The fixed image supplies the output geometry so the result overlays it.
    reference_img = sitk.ReadImage(fi)

    # `moved` is indexed [batch, channel, ...]; keep the first image's first
    # channel. Assumes the remaining axes are already in SimpleITK's array
    # order (z, y, x) -- TODO confirm against fireants.io.Image.
    moved_image_np = moved[0, 0].detach().cpu().numpy()
    moved_sitk_image = sitk.GetImageFromArray(moved_image_np)
    moved_sitk_image.SetOrigin(reference_img.GetOrigin())
    moved_sitk_image.SetSpacing(reference_img.GetSpacing())
    moved_sitk_image.SetDirection(reference_img.GetDirection())
    sitk.WriteImage(moved_sitk_image, out)


if __name__ == '__main__':
    main()