ck-preprocess/mri_synthmorph/synthmorph/registration.py
2025-02-02 12:29:19 +08:00

313 lines
11 KiB
Python

import os
import h5py
import numpy as np
import surfa as sf
import tensorflow as tf
import voxelmorph as vxm
# Default weight files for each registration model, looked up under
# $FREESURFER_HOME/models when no weights are passed explicitly. Joint
# registration loads the affine weights first, then the deformable weights.
weights = dict(
    joint=('synthmorph.affine.2.h5', 'synthmorph.deform.3.h5'),
    deform=('synthmorph.deform.3.h5',),
    affine=('synthmorph.affine.2.h5',),
    rigid=('synthmorph.rigid.1.h5',),
)
def network_space(im, shape, center=None):
    """Construct transform from network space to the voxel space of an image.

    Constructs a coordinate transform from the space the network will operate
    in to the zero-based image index space. The network space has isotropic
    1-mm voxels, left-inferior-anterior (LIA) orientation, and no shear. It is
    centered on the field of view, or that of a reference image. This space is
    an indexed voxel space, not world space.

    Parameters
    ----------
    im : surfa.Volume
        Input image to construct the transform for.
    shape : (3,) array-like
        Spatial shape of the network space.
    center : surfa.Volume, optional
        Center the network space on the center of a reference image instead
        of on `im`.

    Returns
    -------
    out : tuple of (4, 4) arrays
        Transform from network to input-image space and its inverse, both
        operating on zero-based voxel indices (not world coordinates).
    """
    old = im.geom

    # The network grid is centered on the reference image if one is given,
    # otherwise on the input image itself.
    ref = old if center is None else center.geom
    new = sf.ImageGeometry(
        shape=shape,
        voxsize=1,
        rotation='LIA',
        center=ref.center,
        shear=None,
    )

    # Compose through world space: network voxels -> world -> image voxels,
    # and the reverse.
    net_to_vox = old.world2vox @ new.vox2world
    vox_to_net = new.world2vox @ old.vox2world
    return net_to_vox.matrix, vox_to_net.matrix
def transform(im, trans, shape=None, normalize=False, batch=False):
    """Apply a spatial transform to the voxel data of a 3D image.

    Applies a transformation matrix operating in zero-based index space or a
    displacement field to an image buffer, using a fill value of 0.

    Parameters
    ----------
    im : surfa.Volume or NumPy array or TensorFlow tensor
        Input image to transform, without batch dimension.
    trans : array-like
        Transform to apply to the image. A matrix of shape (3, 4), a matrix
        of shape (4, 4), or a displacement field of shape (*space, 3),
        without batch dimension.
    shape : (3,) array-like, optional
        Output shape used for converting matrices to dense transforms. None
        means the shape of the input image will be used.
    normalize : bool, optional
        Min-max normalize the image intensities into the interval [0, 1]. A
        constant output image is mapped to all zeros rather than NaN.
    batch : bool, optional
        Prepend a singleton batch dimension to the output tensor.

    Returns
    -------
    out : float TensorFlow tensor
        Transformed image with a trailing feature dimension.
    """
    # Add singleton feature dimension if needed.
    if tf.rank(im) == 3:
        im = im[..., tf.newaxis]

    out = vxm.utils.transform(
        im, trans, fill_value=0, shift_center=False, shape=shape,
    )

    if normalize:
        out -= tf.reduce_min(out)
        # After subtracting the minimum, a constant image (e.g. only fill
        # values) has a maximum of zero. Plain division would yield 0/0 = NaN
        # and poison everything downstream, so return zeros instead.
        out = tf.math.divide_no_nan(out, tf.reduce_max(out))

    if batch:
        out = out[tf.newaxis, ...]

    return out
def load_weights(model, weights):
    """Load weights into model or submodel.

    Attempts to load (all) weights into a model or one of its submodels. If
    that fails, `model` may be a submodel of what we got weights for, and we
    attempt to load the weights of a submodel (layer) into `model`.

    Parameters
    ----------
    model : TensorFlow model
        Model to initialize.
    weights : str or pathlib.Path
        Path to weights file.

    Raises
    ------
    ValueError
        If unsuccessful at loading any weights, including when the file holds
        no layer with stored weights.
    """
    # Collect `model` and all nested Keras submodels, breadth first. The list
    # grows while we walk it, hence the index-based loop.
    models = [model]
    i = 0
    while i < len(models):
        layers = [f for f in models[i].layers if isinstance(f, tf.keras.Model)]
        models.extend(layers)
        i += 1

    # Add models wrapping a single model in case this was done in training.
    # Build the list first: extending `models` from a generator over itself
    # would never terminate.
    models.extend([tf.keras.Model(m.inputs, m(m.inputs)) for m in models])

    # Attempt to load all weights into one of the models. A ValueError means
    # the structure did not match; move on to the next candidate.
    for mod in models:
        try:
            mod.load_weights(weights)
            return
        except ValueError:
            continue

    # Assume `model` is a submodel of what we got weights for, and try the
    # stored weights of each layer in turn.
    with h5py.File(weights, mode='r') as h5:
        candidates = [
            (lay, list(h5[lay].attrs['weight_names']))
            for lay in h5.attrs['layer_names']
        ]

        # Only layers that actually store weights can be loaded.
        candidates = [(lay, names) for lay, names in candidates if names]
        if not candidates:
            raise ValueError(f'no layer weights found in {weights}')

        error = None
        for lay, names in candidates:
            try:
                model.set_weights([h5[lay][w] for w in names])
                return
            except ValueError as e:
                error = e

        # Every candidate layer failed: surface the last error.
        raise error
def _wrap_transforms(fw, bw, mov, fix, is_mat):
    """Attach source/target image geometries to a transform pair.

    Parameters
    ----------
    fw, bw : array-like
        Forward (moving to fixed) and backward transforms in voxel space:
        matrices if `is_mat`, otherwise displacement fields.
    mov, fix : surfa.Volume
        Moving and fixed images whose geometries the transforms refer to.
    is_mat : bool
        Whether the transforms are matrices (affine/rigid models).

    Returns
    -------
    fw, bw : surfa.Affine or surfa.Warp
        Transforms with geometries attached.
    conv : dict
        Keyword arguments for `convert` before saving: world space for
        matrices, RAS displacements for warps.
    ext : str
        File extension for saving the transforms ('lta' or 'nii.gz').
    """
    if is_mat:
        # LTAs store the inverse, so swap the matrices before wrapping.
        fw, bw = bw, fw
        fw = sf.Affine(fw, source=mov, target=fix, space='voxel')
        bw = sf.Affine(bw, source=fix, target=mov, space='voxel')
        return fw, bw, dict(space='world'), 'lta'

    fw = sf.Warp(fw, source=mov, target=fix, format=sf.Warp.Format.disp_crs)
    bw = sf.Warp(bw, source=fix, target=mov, format=sf.Warp.Format.disp_crs)
    return fw, bw, dict(format=sf.Warp.Format.disp_ras), 'nii.gz'


def register(arg):
    """Register a moving to a fixed image with SynthMorph and save outputs.

    Terminates via `sf.system.fatal` on invalid inputs (multi-frame images,
    mismatched initial transform, failed matrix square root, or missing
    weights location).

    Parameters
    ----------
    arg : object
        Parsed options; presumably an `argparse.Namespace` from the command
        line (verify against caller). Attributes read: `extent`, `model` (a
        key of the module-level `weights` dict), `threads`, `moving`, `fixed`,
        `init`, `mid_space`, `steps`, `hyper`, `weights`, `trans`, `inverse`,
        `out_moving`, `out_fixed`, `header_only`, and `out_dir` (path-like,
        supporting `.mkdir` and `/`). An empty `weights` is replaced in place
        with the default files under $FREESURFER_HOME/models.
    """
    # Parse arguments.
    in_shape = (arg.extent,) * 3
    is_mat = arg.model in ('affine', 'rigid')

    # Threading.
    if arg.threads:
        tf.config.threading.set_inter_op_parallelism_threads(arg.threads)
        tf.config.threading.set_intra_op_parallelism_threads(arg.threads)

    # Input data.
    mov = sf.load_volume(arg.moving)
    fix = sf.load_volume(arg.fixed)
    if not len(mov.shape) == len(fix.shape) == 3:
        sf.system.fatal('input images are not single-frame volumes')

    # Transforms between native voxel and network coordinates. Voxel and
    # network spaces differ for each image. The networks expect isotropic 1-mm
    # LIA spaces. Center these on the original images, except in the deformable
    # case: it assumes prior affine registration, so we center the moving
    # network space on the fixed image, to take into account affine transforms
    # via resampling, updating the header, or passed on the command line alike.
    center = fix if arg.model == 'deform' else None
    net_to_mov, mov_to_net = network_space(mov, shape=in_shape, center=center)
    net_to_fix, fix_to_net = network_space(fix, shape=in_shape)

    # Voxel-to-world matrices, used for the geometry of network-space outputs.
    # There is only one world.
    mov_to_ras = mov.geom.vox2world.matrix
    fix_to_ras = fix.geom.vox2world.matrix

    # Incorporate an initial matrix transform from moving to fixed coordinates,
    # as LTAs store the inverse. For mid-space initialization, compute the
    # square root of the transform between fixed and moving network space.
    if arg.init:
        init = sf.load_affine(arg.init).convert(space='voxel')
        if (
            init.ndim != 3
            or not sf.transform.image_geometry_equal(mov.geom, init.source, tol=1e-3)
            or not sf.transform.image_geometry_equal(fix.geom, init.target, tol=1e-3)
        ):
            sf.system.fatal('initial transform geometry does not match images')

        init = fix_to_net @ init @ net_to_mov
        if arg.mid_space:
            # Split the initial transform evenly between both images.
            init = tf.linalg.sqrtm(init)
            if np.any(np.isnan(init)):
                sf.system.fatal(f'cannot compute matrix square root of {arg.init}')
            net_to_fix = net_to_fix @ init
            fix_to_net = np.linalg.inv(net_to_fix)

        # NOTE(review): from here on the network matrices may be TensorFlow
        # tensors (results of `tf.linalg`) rather than NumPy arrays.
        net_to_mov = net_to_mov @ tf.linalg.inv(init)
        mov_to_net = np.linalg.inv(net_to_mov)

    # Take the input images to network space. When saving the moving image with
    # the correct voxel-to-RAS matrix after incorporating an initial transform,
    # an image viewer taking this matrix into account will show an unchanged
    # image. The networks only see the voxel data, which have been moved.
    inputs = (
        transform(mov, net_to_mov, shape=in_shape, normalize=True, batch=True),
        transform(fix, net_to_fix, shape=in_shape, normalize=True, batch=True),
    )

    # Network. For deformable-only registration, `HyperVxmJoint` ignores the
    # `mid_space` argument, and the initialization will determine the space.
    prop = dict(in_shape=in_shape, bidir=True)
    if is_mat:
        prop.update(make_dense=False, rigid=arg.model == 'rigid')
        model = vxm.networks.VxmAffineFeatureDetector(**prop)
    else:
        prop.update(mid_space=True, int_steps=arg.steps, skip_affine=arg.model == 'deform')
        model = vxm.networks.HyperVxmJoint(**prop)
        # The hypernetwork takes the regularization hyperparameter as input.
        inputs = (tf.constant([arg.hyper]), *inputs)

    # Weights.
    if not arg.weights:
        fs = os.environ.get('FREESURFER_HOME')
        if not fs:
            sf.system.fatal('set environment variable FREESURFER_HOME or weights')
        arg.weights = [os.path.join(fs, 'models', f) for f in weights[arg.model]]

    for f in arg.weights:
        load_weights(model, weights=f)

    # Inference. The first transform maps from the moving to the fixed image,
    # or equivalently, from fixed to moving coordinates. The second is the
    # inverse. Convert transforms between moving and fixed network spaces to
    # transforms between the original voxel spaces.
    pred = tuple(map(tf.squeeze, model(inputs)))
    net_fw, net_bw = pred
    fw = vxm.utils.compose((net_to_mov, net_fw, fix_to_net), shift_center=False, shape=fix.shape)
    bw = vxm.utils.compose((net_to_fix, net_bw, mov_to_net), shift_center=False, shape=mov.shape)

    # Associate image geometries with the transforms. LTAs store the inverse.
    fw, bw, conv, _ = _wrap_transforms(fw, bw, mov, fix, is_mat)

    # Output transforms.
    if arg.trans:
        fw.convert(**conv).save(arg.trans)

    if arg.inverse:
        bw.convert(**conv).save(arg.inverse)

    # Moved images.
    if arg.out_moving:
        mov.transform(fw, resample=not arg.header_only).save(arg.out_moving)

    if arg.out_fixed:
        fix.transform(bw, resample=not arg.header_only).save(arg.out_fixed)

    # Outputs in network space.
    if arg.out_dir:
        arg.out_dir.mkdir(exist_ok=True)

        # Input images.
        mov = sf.ImageGeometry(in_shape, vox2world=mov_to_ras @ net_to_mov)
        fix = sf.ImageGeometry(in_shape, vox2world=fix_to_ras @ net_to_fix)
        mov = sf.Volume(inputs[-2][0], geometry=fix if arg.init else mov)
        fix = sf.Volume(inputs[-1][0], geometry=fix)
        mov.save(filename=arg.out_dir / 'inp_1.nii.gz')
        fix.save(filename=arg.out_dir / 'inp_2.nii.gz')

        # Transforms, as predicted between the network-space images.
        fw, bw, conv, ext = _wrap_transforms(net_fw, net_bw, mov, fix, is_mat)
        fw.convert(**conv).save(filename=arg.out_dir / f'tra_1.{ext}')
        bw.convert(**conv).save(filename=arg.out_dir / f'tra_2.{ext}')

        # Moved images.
        mov.transform(fw).save(filename=arg.out_dir / 'out_1.nii.gz')
        fix.transform(bw).save(filename=arg.out_dir / 'out_2.nii.gz')

    vmpeak = sf.system.vmpeak()
    if vmpeak is not None:
        print(f'#@# mri_synthmorph: {arg.model}, threads: {arg.threads}, VmPeak: {vmpeak}')