313 lines
11 KiB
Python
313 lines
11 KiB
Python
import os
|
|
import h5py
|
|
import numpy as np
|
|
import surfa as sf
|
|
import tensorflow as tf
|
|
import voxelmorph as vxm
|
|
|
|
|
|
# Settings. Default weight files for each registration model, in the order
# in which they are loaded.
weights = dict(
    joint=('synthmorph.affine.2.h5', 'synthmorph.deform.3.h5'),
    deform=('synthmorph.deform.3.h5',),
    affine=('synthmorph.affine.2.h5',),
    rigid=('synthmorph.rigid.1.h5',),
)
|
|
|
|
|
|
def network_space(im, shape, center=None):
    """Build the transforms between network space and an image's voxel space.

    The network operates in an indexed voxel space (not world space) with
    isotropic 1-mm voxels, left-inferior-anterior (LIA) orientation, and no
    shear. That space is centered on the field of view of `im`, unless a
    reference image is passed as `center`, in which case it is centered on
    the reference instead. The returned transforms map coordinates between
    the network space and the zero-based index space of `im`.

    Parameters
    ----------
    im : surfa.Volume
        Image whose voxel space the transforms connect to.
    shape : (3,) array-like
        Spatial shape of the network space.
    center : surfa.Volume, optional
        Reference image whose center becomes the center of the network
        space. None means the center of `im` is used.

    Returns
    -------
    out : tuple of NumPy arrays
        Matrix taking network coordinates to voxel coordinates of `im`,
        followed by the matrix of the opposite direction.

    """
    src_geom = im.geom

    # Pick the world-space center the network space is anchored to.
    if center is None:
        anchor = src_geom.center
    else:
        anchor = center.geom.center

    net_geom = sf.ImageGeometry(
        shape=shape,
        voxsize=1,
        rotation='LIA',
        center=anchor,
        shear=None,
    )

    # Both directions go through world space, which the geometries share.
    forward = src_geom.world2vox @ net_geom.vox2world
    backward = net_geom.world2vox @ src_geom.vox2world
    return forward.matrix, backward.matrix
|
|
|
|
|
|
def transform(im, trans, shape=None, normalize=False, batch=False):
    """Apply a spatial transform to the voxel data of a 3D image.

    Applies a transformation matrix operating in zero-based index space or a
    displacement field to an image buffer. Voxels sampled outside the input
    are filled with zeros.

    Parameters
    ----------
    im : surfa.Volume or NumPy array or TensorFlow tensor
        Input image to transform, without batch dimension.
    trans : array-like
        Transform to apply to the image. A matrix of shape (3, 4), a matrix
        of shape (4, 4), or a displacement field of shape (*space, 3),
        without batch dimension.
    shape : (3,) array-like, optional
        Output shape used for converting matrices to dense transforms. None
        means the shape of the input image will be used.
    normalize : bool, optional
        Min-max normalize the image intensities into the interval [0, 1]. A
        constant image maps to all zeros instead of NaN.
    batch : bool, optional
        Prepend a singleton batch dimension to the output tensor.

    Returns
    -------
    out : float TensorFlow tensor
        Transformed image with a trailing feature dimension.

    """
    # Add singleton feature dimension if needed.
    if tf.rank(im) == 3:
        im = im[..., tf.newaxis]

    out = vxm.utils.transform(
        im, trans, fill_value=0, shift_center=False, shape=shape,
    )

    if normalize:
        out -= tf.reduce_min(out)

        # After the shift the maximum equals the intensity range. A constant
        # image has zero range, and dividing by it would fill the output
        # with NaN, so divide by one in that case and keep the zeros.
        peak = tf.reduce_max(out)
        out /= tf.where(peak > 0, peak, tf.ones_like(peak))

    if batch:
        out = out[tf.newaxis, ...]

    return out
|
|
|
|
|
|
def load_weights(model, weights):
    """Load weights into model or submodel.

    Attempts to load (all) weights into a model or one of its submodels. If
    that fails, `model` may be a submodel of what we got weights for, and we
    attempt to load the weights of a submodel (layer) into `model`.

    Parameters
    ----------
    model : TensorFlow model
        Model to initialize.
    weights : str or pathlib.Path
        Path to weights file.

    Raises
    ------
    ValueError
        If unsuccessful at loading any weights, or if the file holds no
        layer with weights.

    """
    # Extract submodels. The list grows while we walk it, so that nested
    # submodels at any depth are visited.
    models = [model]
    i = 0
    while i < len(models):
        models.extend([f for f in models[i].layers if isinstance(f, tf.keras.Model)])
        i += 1

    # Add models wrapping a single model in case this was done in training.
    # The wrappers are built as a complete list before extending, or the
    # loop would iterate over its own additions and never finish.
    models.extend([tf.keras.Model(m.inputs, m(m.inputs)) for m in models])

    # Attempt to load all weights into one of the models. A `ValueError`
    # only means this candidate does not fit, so move on to the next one.
    for mod in models:
        try:
            mod.load_weights(weights)
        except ValueError:
            continue
        return

    # Assume `model` is a submodel of what we got weights for.
    with h5py.File(weights, mode='r') as h5:
        names = h5.attrs['layer_names']
        stored = [(lay, list(h5[lay].attrs['weight_names'])) for lay in names]

        # Layers with weights.
        stored = [(lay, wei) for lay, wei in stored if wei]
        if not stored:
            raise ValueError(f'no layer with weights found in {weights}')

        # Attempt loading the weights of each stored layer into `model`,
        # remembering the most recent failure to report it if none fits.
        error = None
        for lay, wei in stored:
            try:
                model.set_weights([h5[lay][w] for w in wei])
            except ValueError as e:
                error = e
                continue
            return

        raise error
|
|
|
|
|
|
def register(arg):
    """Register a moving image to a fixed image and save requested outputs.

    Loads both images, resamples them into the network space, runs the
    selected model, converts the predicted transforms between network spaces
    into transforms between the native voxel spaces, and saves the
    transforms and moved images that were requested.

    Parameters
    ----------
    arg : object
        Settings container, presumably the parsed command-line arguments
        (confirm against the caller). Attributes read are `extent`, `model`,
        `threads`, `moving`, `fixed`, `init`, `mid_space`, `steps`, `hyper`,
        `weights`, `trans`, `inverse`, `out_moving`, `out_fixed`,
        `header_only`, and `out_dir`. If `weights` is empty, it is replaced
        in place with the default weight files for `model`, located in the
        `models` folder under the `FREESURFER_HOME` environment variable.

    """
    # Parse arguments.
    in_shape = (arg.extent,) * 3
    is_mat = arg.model in ('affine', 'rigid')

    # Threading.
    if arg.threads:
        tf.config.threading.set_inter_op_parallelism_threads(arg.threads)
        tf.config.threading.set_intra_op_parallelism_threads(arg.threads)

    # Input data.
    mov = sf.load_volume(arg.moving)
    fix = sf.load_volume(arg.fixed)
    if not len(mov.shape) == len(fix.shape) == 3:
        sf.system.fatal('input images are not single-frame volumes')

    # Transforms between native voxel and network coordinates. Voxel and
    # network spaces differ for each image. The networks expect isotropic 1-mm
    # LIA spaces. Center these on the original images, except in the deformable
    # case: it assumes prior affine registration, so we center the moving
    # network space on the fixed image, to take into account affine transforms
    # via resampling, updating the header, or passed on the command line alike.
    center = fix if arg.model == 'deform' else None
    net_to_mov, mov_to_net = network_space(mov, shape=in_shape, center=center)
    net_to_fix, fix_to_net = network_space(fix, shape=in_shape)

    # Coordinate transforms to world space. There is only one world. These
    # are only needed for the network-space outputs saved at the end.
    mov_to_ras = mov.geom.vox2world.matrix
    fix_to_ras = fix.geom.vox2world.matrix

    # Incorporate an initial matrix transform from moving to fixed coordinates,
    # as LTAs store the inverse. For mid-space initialization, compute the
    # square root of the transform between fixed and moving network space.
    if arg.init:
        init = sf.load_affine(arg.init).convert(space='voxel')
        if (
            init.ndim != 3
            or not sf.transform.image_geometry_equal(mov.geom, init.source, tol=1e-3)
            or not sf.transform.image_geometry_equal(fix.geom, init.target, tol=1e-3)
        ):
            sf.system.fatal('initial transform geometry does not match images')

        init = fix_to_net @ init @ net_to_mov
        if arg.mid_space:
            init = tf.linalg.sqrtm(init)
            if np.any(np.isnan(init)):
                sf.system.fatal(f'cannot compute matrix square root of {arg.init}')
            net_to_fix = net_to_fix @ init
            fix_to_net = np.linalg.inv(net_to_fix)

        net_to_mov = net_to_mov @ tf.linalg.inv(init)
        mov_to_net = np.linalg.inv(net_to_mov)

    # Take the input images to network space. When saving the moving image with
    # the correct voxel-to-RAS matrix after incorporating an initial transform,
    # an image viewer taking this matrix into account will show an unchanged
    # image. The networks only see the voxel data, which have been moved.
    inputs = (
        transform(mov, net_to_mov, shape=in_shape, normalize=True, batch=True),
        transform(fix, net_to_fix, shape=in_shape, normalize=True, batch=True),
    )

    # Network. For deformable-only registration, `HyperVxmJoint` ignores the
    # `mid_space` argument, and the initialization will determine the space.
    prop = dict(in_shape=in_shape, bidir=True)
    if is_mat:
        prop.update(make_dense=False, rigid=arg.model == 'rigid')
        model = vxm.networks.VxmAffineFeatureDetector(**prop)

    else:
        prop.update(mid_space=True, int_steps=arg.steps, skip_affine=arg.model == 'deform')
        model = vxm.networks.HyperVxmJoint(**prop)

        # The hypernetwork takes the hyperparameter as an additional, first
        # input. The images remain the last two inputs.
        inputs = (tf.constant([arg.hyper]), *inputs)

    # Weights.
    if not arg.weights:
        fs = os.environ.get('FREESURFER_HOME')
        if not fs:
            sf.system.fatal('set environment variable FREESURFER_HOME or weights')
        arg.weights = [os.path.join(fs, 'models', f) for f in weights[arg.model]]

    for f in arg.weights:
        load_weights(model, weights=f)

    # Inference. The first transform maps from the moving to the fixed image,
    # or equivalently, from fixed to moving coordinates. The second is the
    # inverse. Convert transforms between moving and fixed network spaces to
    # transforms between the original voxel spaces.
    pred = tuple(map(tf.squeeze, model(inputs)))
    fw, bw = pred
    fw = vxm.utils.compose((net_to_mov, fw, fix_to_net), shift_center=False, shape=fix.shape)
    bw = vxm.utils.compose((net_to_fix, bw, mov_to_net), shift_center=False, shape=mov.shape)

    # Associate image geometries with the transforms. LTAs store the inverse.
    # `out_format` holds the keyword arguments for converting a transform
    # to the format it is saved in.
    if is_mat:
        fw, bw = bw, fw
        fw = sf.Affine(fw, source=mov, target=fix, space='voxel')
        bw = sf.Affine(bw, source=fix, target=mov, space='voxel')
        out_format = dict(space='world')

    else:
        fw = sf.Warp(fw, source=mov, target=fix, format=sf.Warp.Format.disp_crs)
        bw = sf.Warp(bw, source=fix, target=mov, format=sf.Warp.Format.disp_crs)
        out_format = dict(format=sf.Warp.Format.disp_ras)

    # Output transforms.
    if arg.trans:
        fw.convert(**out_format).save(arg.trans)

    if arg.inverse:
        bw.convert(**out_format).save(arg.inverse)

    # Moved images.
    if arg.out_moving:
        mov.transform(fw, resample=not arg.header_only).save(arg.out_moving)

    if arg.out_fixed:
        fix.transform(bw, resample=not arg.header_only).save(arg.out_fixed)

    # Outputs in network space.
    if arg.out_dir:
        arg.out_dir.mkdir(exist_ok=True)

        # Input images. From here on, `mov` and `fix` refer to the images in
        # network space, first to their geometries and then to the volumes.
        mov = sf.ImageGeometry(in_shape, vox2world=mov_to_ras @ net_to_mov)
        fix = sf.ImageGeometry(in_shape, vox2world=fix_to_ras @ net_to_fix)
        mov = sf.Volume(inputs[-2][0], geometry=fix if arg.init else mov)
        fix = sf.Volume(inputs[-1][0], geometry=fix)
        mov.save(filename=arg.out_dir / 'inp_1.nii.gz')
        fix.save(filename=arg.out_dir / 'inp_2.nii.gz')

        # Start over from the raw predictions, which operate in network space.
        fw, bw = pred
        if is_mat:
            fw, bw = bw, fw
            fw = sf.Affine(fw, source=mov, target=fix, space='voxel')
            bw = sf.Affine(bw, source=fix, target=mov, space='voxel')
            ext = 'lta'

        else:
            fw = sf.Warp(fw, source=mov, target=fix, format=sf.Warp.Format.disp_crs)
            bw = sf.Warp(bw, source=fix, target=mov, format=sf.Warp.Format.disp_crs)
            ext = 'nii.gz'

        # Transforms.
        fw.convert(**out_format).save(filename=arg.out_dir / f'tra_1.{ext}')
        bw.convert(**out_format).save(filename=arg.out_dir / f'tra_2.{ext}')

        # Moved images.
        mov.transform(fw).save(filename=arg.out_dir / 'out_1.nii.gz')
        fix.transform(bw).save(filename=arg.out_dir / 'out_2.nii.gz')

    # Report peak memory usage, if available.
    vmpeak = sf.system.vmpeak()
    if vmpeak is not None:
        print(f'#@# mri_synthmorph: {arg.model}, threads: {arg.threads}, VmPeak: {vmpeak}')
|