diff --git a/src/g4synthmorph.py b/src/g4synthmorph.py index 9b2ee9e..ac1ea6b 100644 --- a/src/g4synthmorph.py +++ b/src/g4synthmorph.py @@ -30,7 +30,7 @@ from skimage.metrics import normalized_mutual_information import filelock import matplotlib.pyplot as plt import numpy as np -import SimpleITK as sitk +# import SimpleITK as sitk from mri_synthmorph.synthmorph import registration # from synthmorph import registration @@ -63,40 +63,6 @@ logging.basicConfig( logger = logging.getLogger(__name__) -# def resize_with_crop_or_pad(image, tx = SIZE_X, ty = SIZE_Y, tz = SIZE_Z): -def resize_with_pad(image, tx = SIZE_X, ty = SIZE_Y, tz = SIZE_Z): - sx, sy, sz = image.GetSize() - l = [(tx-sx)//2, - (ty-sy)//2, - (tz-sz)//2,] - u = [tx-sx-l[0], - ty-sy-l[1], - tz-sz-l[2], - ] - # print (l, u) - return sitk.ConstantPad(image, l, u) - - -def draw_sitk(image, d, post): - a = sitk.GetArrayFromImage(image) - s = a.shape - - fig, axs = plt.subplots(1, 3) - # fig.suptitle('%dx%dx%d'%(s[2], s[1], s[0])) - axs.flat[0].imshow(a[s[0]//2,:,:], cmap='gray') - axs.flat[1].imshow(a[:,s[1]//2,:], cmap='gray') - axs.flat[2].imshow(a[:,:,s[2]//2], cmap='gray') - axs.flat[0].axis('off') - axs.flat[1].axis('off') - axs.flat[2].axis('off') - axs.flat[1].invert_yaxis() - axs.flat[2].invert_yaxis() - plt.tight_layout() - os.makedirs(d, exist_ok=True) - plt.savefig(os.path.join(d, '%dx%dx%d-%s'%(s[2],s[1],s[0],post))) - plt.close() - # exit() - def bbox2_3D(img): r = np.any(img, axis=(1, 2)) @@ -132,6 +98,8 @@ def register(ct0, ct1, moving, out_root): logger.info(' '.join((ct0, ct1, moving, str(out_root)))) orig = sf.load_volume(moving) + + base = sf.load_volume(ct0) if modality == 'CT': clipped = out_root/'clipped.nii.gz' @@ -201,6 +169,11 @@ def register(ct0, ct1, moving, out_root): os.environ["FREESURFER_HOME"] = FREESURFER_HOME os.environ["XLA_FLAGS"] = '--xla_gpu_cuda_data_dir=%s'% os.environ["CONDA_PREFIX"] + fill = orig.min() + # print(fill) + # exit() + + for m in MODELS: default['model'] = m default['out_dir'] = out_root/m @@ -231,16 +204,18 @@ def register(ct0, ct1, moving, out_root): out = out_root/('%s.nii.gz'%m) if m in ['affine', 'rigid']: trans = sf.load_affine(default['out_dir']/'tra_1.lta') - prop = dict(method='linear', resample=True, fill=0) - orig.transform(trans, **prop).save(out) + prop = dict(method='linear', resample=True, fill=fill) + orig.transform(trans, **prop).resample_like(base, fill=fill).save(out) logger.info('transformed %s'%out) + # print(prop) + # exit() else: # need to resample before transform in warp, too complicated, just copy it # trans1 = default['out_dir']/'tra_1.nii.gz' # trans = sf.load_warp(trans1) - shutil.copy(default['out_dir']/'out_1.nii.gz', out) - logger.info('copied %s'% out) + sf.load_volume(default['out_dir']/'out_1.nii.gz').resample_like(base, fill=fill).save(out) + logger.info('resampled %s'% out) with open(out_root/'metric.txt', 'w') as f_metrics: for m in MODELS: @@ -295,7 +270,7 @@ def check(epath): if root2.endswith('RT'): modality = 'RT' logger.info('copying %s %s' %(root2, outdir)) - shutil.copytree(root2, outdir) + shutil.copytree(root2, outdir, dirs_exist_ok=True) # exit() continue skip = (root2==root) or ('RT' in root2.split('/'))