123/registration/ants_reg.py

100 lines
3.7 KiB
Python
Raw Normal View History

2023-08-08 22:04:06 +00:00
from pprint import pprint
import os
import tempfile
from skimage.metrics import normalized_mutual_information
import ants
class ants_reg:
    """Register two 3-D images with ANTs in both directions and keep the better one.

    Registration is run both as moving->fixed and as fixed->moving. The
    direction whose normalized mutual information is higher is kept in
    ``self.res``. Either way, ``self.res['transformlist']`` and
    ``self.res['whichtoinvert']`` describe the mapping of the moving image
    space onto the fixed image grid; :meth:`transform` uses them.
    """

    def register_aux(self, fi, mv):
        """Register ``mv`` onto ``fi`` with ANTs ``SyNRA`` and score the result.

        Parameters
        ----------
        fi : ANTsImage
            Fixed (reference) image.
        mv : ANTsImage
            Moving image.

        Returns
        -------
        dict
            ``fwdtransforms``: the first forward transform read into memory.
            It is kept for backward compatibility only; with SyNRA it is just
            the deformation field, without the rigid/affine stage.
            ``fwdtransformlist`` / ``invtransformlist``: the complete forward
            and inverse transform lists returned by ``ants.registration``.
            ``warpedfixout`` / ``warpedmovout``: the warped images.
            ``metrics``: the larger of the two NMI scores.
        """
        mytx = ants.registration(
            fixed=fi,
            moving=mv,
            type_of_transform='SyNRA',
        )
        # Backward compatibility only: element 0 of a SyNRA forward list is
        # the deformation field, so this object alone omits the affine part.
        fwdtransforms = ants.read_transform(mytx['fwdtransforms'][0])
        # NMI of fixed vs. moving-warped-into-fixed, and of moving vs.
        # fixed-warped-into-moving.
        m1 = normalized_mutual_information(fi.numpy(), mytx['warpedmovout'].numpy())
        m2 = normalized_mutual_information(mv.numpy(), mytx['warpedfixout'].numpy())
        print(m1, m2)
        return {
            'fwdtransforms': fwdtransforms,
            # A displacement-field transform (AntsTransformDF3) cannot be
            # inverted with ANTsTransform.invert(). Keep the full file lists
            # instead and let apply_transforms invert matrices via
            # `whichtoinvert`.
            'fwdtransformlist': list(mytx['fwdtransforms']),
            'invtransformlist': list(mytx['invtransforms']),
            'warpedfixout': mytx['warpedfixout'],
            'warpedmovout': mytx['warpedmovout'],
            'metrics': max(m1, m2),
        }

    @staticmethod
    def _choose_direction(r1, r2):
        """Return the result dict of the better-scoring registration direction.

        ``r1`` is the moving->fixed result and ``r2`` the fixed->moving
        result, both as produced by :meth:`register_aux`. Ties go to ``r2``.
        The returned dict always describes moving->fixed: ``transformlist``
        and ``whichtoinvert`` are ready for ``ants.apply_transforms`` with the
        fixed image as reference.
        """
        if r1['metrics'] > r2['metrics']:
            res = r1
            res['transformlist'] = list(r1['fwdtransformlist'])
            res['whichtoinvert'] = [False] * len(res['transformlist'])
            return res
        res = dict(r2)
        # r2 maps fixed -> moving, so moving -> fixed is r2's inverse chain.
        # Affine matrices (.mat) must be inverted at apply time. The inverse
        # warp field written by ANTs is already inverted and must not be.
        transformlist = list(r2['invtransformlist'])
        res.update({
            # NOTE: 'fwdtransforms' still holds r2's fixed->moving object,
            # since a displacement field cannot be inverted in memory. Use
            # 'transformlist'/'whichtoinvert' for the moving->fixed mapping.
            'invtransforms': r2['fwdtransforms'],
            'warpedfixout': r2['warpedmovout'],
            'warpedmovout': r2['warpedfixout'],
            'transformlist': transformlist,
            'whichtoinvert': [p.endswith('.mat') for p in transformlist],
        })
        return res

    def __init__(self, fi, mv, debug=False):
        """Read both images, register them both ways and keep the better result.

        Parameters
        ----------
        fi : str
            Path of the fixed image, read as a 3-D ANTs image.
        mv : str
            Path of the moving image, read as a 3-D ANTs image.
        debug : bool
            If true, pretty-print the result and write the inputs and the
            warped outputs of both runs to ``0*.nii.gz`` files in the current
            working directory.
        """
        fixed_image = ants.image_read(fi, dimension=3)
        moving_image = ants.image_read(mv, dimension=3)
        r1 = self.register_aux(fixed_image, moving_image)
        r2 = self.register_aux(moving_image, fixed_image)
        self.res = self._choose_direction(r1, r2)
        self.res.update({
            'fix': fixed_image,
            'mov': moving_image,
        })
        if debug:
            pprint(self.res)
            ants.image_write(fixed_image, '0fixed.nii.gz')
            ants.image_write(moving_image, '0moving.nii.gz')
            ants.image_write(r1['warpedfixout'], '0mf1.nii.gz')
            ants.image_write(r1['warpedmovout'], '0fm1.nii.gz')
            ants.image_write(r2['warpedmovout'], '0mf2.nii.gz')
            ants.image_write(r2['warpedfixout'], '0fm2.nii.gz')

    def get_metrics(self):
        """Return the NMI score of the selected registration direction."""
        return self.res['metrics']

    def write_warpedmovout(self, out):
        """Write the moving image warped onto the fixed grid to path ``out``."""
        ants.image_write(self.res['warpedmovout'], out)

    def transform(self, moving, output_filename, is_label=False):
        """Resample an image from the moving space onto the fixed image grid.

        The full selected transform chain (rigid/affine plus deformation) is
        applied directly from the files ``ants.registration`` wrote. No
        temporary file is created.

        Parameters
        ----------
        moving : str
            Path of a 3-D image lying in the moving image's space.
        output_filename : str
            Path the resampled image is written to.
        is_label : bool
            If true, use 'genericLabel' interpolation and cast the result to
            uint8. NOTE(review): label values outside 0-255 do not fit uint8.
        """
        mi = ants.image_read(moving, dimension=3)
        if is_label:
            transformed = ants.apply_transforms(
                self.res['fix'], mi,
                transformlist=self.res['transformlist'],
                whichtoinvert=self.res['whichtoinvert'],
                interpolator='genericLabel',
            ).astype('uint8')
        else:
            transformed = ants.apply_transforms(
                self.res['fix'], mi,
                transformlist=self.res['transformlist'],
                whichtoinvert=self.res['whichtoinvert'],
            )
        ants.image_write(transformed, output_filename)