"""Rigidly register every secondary image of each patient to its planning CT.

For every patient directory under PATIENTS_ROOT, each folder that holds an
``RT/ORGAN`` structure directory containing an "eye" entry (used as a
"this is a brain scan" heuristic) is processed:

* the planning CT ``RT/ct_image.nii.gz`` is copied to the mirrored output
  folder under OUT_ROOT as ``ct0.nii.gz``, together with ``ct1.nii.gz``, a
  copy clamped to 0..80 and stored as UInt8;
* every other ``*.nii.gz`` image below that folder (outside ``RT`` and
  excluding RTDOSE / DTI / ROI1 files) is registered against both CT
  versions with several ANTs rigid variants, in both directions;
* the candidate with the lowest Mattes mutual-information metric is kept,
  and its transform, overlap mask, warped image and a JSON summary are
  written next to a copy of the input image under OUT_ROOT.

Empirical note from the original author: a final metric > -0.27 usually
indicates a bad registration.
"""
import json
import os
import shelve
import shutil
import time

import ants
import filelock
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk

PATIENTS_ROOT = '/mnt/1220/Public/dataset2/G4'
OUT_ROOT = '/mnt/1220/Public/dataset2/H4'
# Persistent record of finished patients: patient dir name -> check() result.
SHELVE = os.path.join(OUT_ROOT, '0shelve')

MAX_Y = 256
SIZE_X = 249
SIZE_Y = 249
SIZE_Z = 192
# SIZE_Z = 256

MIN_OVERLAP = 0.50  # not used in this module
# Stop trying further registration variants as soon as any candidate scores
# below this value (the code treats lower, i.e. more negative, as better).
MIN_METRIC = -0.50

# Transform types passed to ants.registration, tried in this order.
# ('Translation' and 'Rigid' were disabled in earlier experiments.)
REGISTRATION_TYPES = (
    'QuickRigid',
    'DenseRigid',
    'BOLDRigid',
)

# Input images whose file name contains any of these are not registered.
_SKIPPED_MARKERS = ('_RTDOSE_', '_DTI_', '_ROI1.')


def resize_with_pad(image, tx=SIZE_X, ty=SIZE_Y, tz=SIZE_Z):
    """Centre *image* in a (tx, ty, tz) voxel grid.

    Axes smaller than the target are zero-padded; when the difference is odd
    the extra voxel goes on the upper side. Axes larger than the target are
    centre-cropped instead (``sitk.ConstantPad`` cannot take negative
    sizes); when the surplus is odd the extra voxel is removed from the
    lower side.
    """
    lower_pad, upper_pad, lower_crop, upper_crop = [], [], [], []
    for target, size in zip((tx, ty, tz), image.GetSize()):
        diff = target - size
        lo = diff // 2
        hi = diff - lo
        lower_pad.append(max(lo, 0))
        upper_pad.append(max(hi, 0))
        lower_crop.append(max(-lo, 0))
        upper_crop.append(max(-hi, 0))
    if any(lower_crop) or any(upper_crop):
        image = sitk.Crop(image, lower_crop, upper_crop)
    return sitk.ConstantPad(image, lower_pad, upper_pad)


def draw_sitk(image, d, post):
    """Save the three central orthogonal slices of *image* as one figure.

    The figure is written into directory *d* (created if needed). Its file
    name is ``<X>x<Y>x<Z>-<post>``, taken from the image size, with no
    explicit extension, so matplotlib's default savefig format is used.
    """
    a = sitk.GetArrayFromImage(image)  # numpy index order is (z, y, x)
    s = a.shape
    slices = (a[s[0] // 2, :, :], a[:, s[1] // 2, :], a[:, :, s[2] // 2])
    fig, axs = plt.subplots(1, 3)
    for i, (ax, sl) in enumerate(zip(axs.flat, slices)):
        ax.imshow(sl, cmap='gray')
        ax.axis('off')
        if i > 0:
            # z is the vertical axis of these two views; flip so it points up.
            ax.invert_yaxis()
    plt.tight_layout()
    os.makedirs(d, exist_ok=True)
    plt.savefig(os.path.join(d, '%dx%dx%d-%s' % (s[2], s[1], s[0], post)))
    plt.close(fig)


def bbox2_3D(img):
    """Return the inclusive bounding box of the non-zero voxels of a 3-D array.

    Returns ``(min0, max0, min1, max1, min2, max2)`` as indices along axes
    0, 1 and 2. Returns six ``-1`` values if *img* has no non-zero voxel.
    """
    r = np.any(img, axis=(1, 2))
    c = np.any(img, axis=(0, 2))
    z = np.any(img, axis=(0, 1))
    if not np.any(r):
        return -1, -1, -1, -1, -1, -1
    rmin, rmax = np.where(r)[0][[0, -1]]
    cmin, cmax = np.where(c)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]
    return rmin, rmax, cmin, cmax, zmin, zmax


def _mattes_mi(fixed, moving, **masks):
    """Mattes mutual-information value between two ANTs images (optional masks)."""
    return ants.create_ants_metric(
        fixed, moving, metric_type='MattesMutualInformation', **masks,
    ).get_value()


def _register_fwd(ct, ct0, ct1, mr, mr_ones, typ, ct_index):
    """Register with *ct* fixed and *mr* moving; score the MR warped into CT space."""
    tx = ants.registration(ct, mr, typ)
    tx['mask'] = ants.apply_transforms(
        ct0, mr_ones, tx['fwdtransforms'], interpolator='genericLabel',
    )
    tx['metric0M'] = _mattes_mi(ct0, tx['warpedmovout'], moving_mask=tx['mask'])
    tx['metric1M'] = _mattes_mi(ct1, tx['warpedmovout'], moving_mask=tx['mask'])
    tx['metricMM'] = _mattes_mi(mr, tx['warpedfixout'])
    tx['metric'] = min(tx['metric0M'], tx['metric1M'], tx['metricMM'])
    tx['ct'] = ct_index
    tx['type'] = (typ, 'fwd')
    tx['warpedout'] = tx['warpedmovout']
    print(tx['metric'], tx['metric0M'], tx['metric1M'], tx['metricMM'])
    return tx


def _register_inv(ct, ct0, ct1, mr, mr_ones, typ, ct_index):
    """Register with *mr* fixed and *ct* moving; score the MR warped into CT space."""
    tx = ants.registration(mr, ct, typ)
    # The forward transform maps MR space to CT space here, so it is inverted
    # to bring the MR field-of-view mask onto ct0. A single flag is passed,
    # which presumes one transform file (the affine of a rigid type).
    tx['mask'] = ants.apply_transforms(
        ct0, mr_ones, tx['fwdtransforms'], interpolator='genericLabel',
        whichtoinvert=[True],
    )
    tx['metricMM'] = _mattes_mi(mr, tx['warpedmovout'])
    tx['metric0M'] = _mattes_mi(ct0, tx['warpedfixout'], fixed_mask=tx['mask'])
    tx['metric1M'] = _mattes_mi(ct1, tx['warpedfixout'], fixed_mask=tx['mask'])
    tx['metric'] = min(tx['metricMM'], tx['metric0M'], tx['metric1M'])
    tx['ct'] = ct_index
    tx['type'] = (typ, 'inv')
    tx['warpedout'] = tx['warpedfixout']
    print(tx['metric'], tx['metricMM'], tx['metric0M'], tx['metric1M'])
    return tx


def _best(candidates):
    """Return the candidate with the lowest metric; NaN scores rank last."""
    return min(
        candidates,
        key=lambda t: np.inf if np.isnan(t['metric']) else t['metric'],
    )


def registration(ct0, ct1, mr):
    """Register *mr* against the planning CT and return the best candidate.

    Every type in REGISTRATION_TYPES is tried against ``ct0`` and then
    ``ct1``, in two directions:

    * 'fwd': the CT is the fixed image;
    * 'inv': *mr* is the fixed image.

    Each candidate's score is the minimum of three Mattes MI values: MR vs
    ct0, MR vs ct1, and MR vs the CT warped into MR space. Lower is better.
    The search stops once any candidate scores below MIN_METRIC.

    Args:
        ct0: planning CT as an ANTs image.
        ct1: intensity-clamped copy of the planning CT on the same grid.
        mr: image to register (any modality, despite the name).

    Returns:
        The ``ants.registration`` result dict of the best candidate, extended
        with these keys:

        * 'mask': the MR field of view resampled onto ct0;
        * 'metric0M', 'metric1M', 'metricMM' and 'metric': the scores;
        * 'ct': 0 for ct0, 1 for ct1;
        * 'type': a tuple ``(type, 'fwd'|'inv')``;
        * 'warpedout': the image warped into CT space;
        * 'ratio': the fraction of ct0 voxels covered by 'mask'.

        When *mr* is voxel-identical to ct0, an identity result with metric
        -2 is returned. Returns None for an image whose shape differs from
        ct0 and that has fewer than 4 voxels along some axis.
    """
    ct0n = ct0.numpy()
    mrn = mr.numpy()
    if np.array_equal(ct0n, mrn):
        print('EQUAL')
        return {
            'fwdtransforms': [],
            'warpedfixout': ct0,
            'warpedmovout': ct0,
            # Provided so callers can save this like any other candidate.
            'warpedout': ct0,
            'mask': ct0.new_image_like(np.ones(ct0n.shape)),
            'ct': 0,
            'type': 'Identity',
            'metric': -2,
            'ratio': 1,
        }

    if ct0n.shape == mrn.shape:
        print('SAME SHAPE')
    else:
        print('others', mrn.shape)
        if min(mrn.shape) < 4:
            # Presumably a localizer / near-2-D series: too thin to register.
            print('skip')
            return None

    # All-ones image on the MR grid. Resampled onto ct0, it marks the CT
    # voxels covered by the MR field of view.
    mr_ones = mr.new_image_like(np.ones(mrn.shape))

    attempts = [
        (ct_index, ct, typ)
        for ct_index, ct in enumerate((ct0, ct1))
        for typ in REGISTRATION_TYPES
    ]
    candidates = []
    start = time.time()
    for ct_index, ct, typ in attempts:
        print(typ)
        candidates.append(_register_fwd(ct, ct0, ct1, mr, mr_ones, typ, ct_index))
        candidates.append(_register_inv(ct, ct0, ct1, mr, mr_ones, typ, ct_index))
        if _best(candidates)['metric'] < MIN_METRIC:
            break  # good enough: skip the remaining types and CT variants
    print(time.time() - start, 'seconds')

    best = _best(candidates)
    best['ratio'] = best['mask'].numpy().sum() / np.prod(ct0n.shape)
    return best


def _has_eye(organ_dir):
    """True if any entry of *organ_dir* has 'eye' in its name (case-insensitive)."""
    return any('eye' in name.lower() for name in os.listdir(organ_dir))


def _prepare_cts(root, rt_dir):
    """Export the planning CT and a clamped copy, then load both with ANTs.

    Writes ``ct0.nii.gz`` (a verbatim copy) and ``ct1.nii.gz``. The latter
    is the CT clamped to 0..80 and stored as UInt8, presumably a HU
    soft-tissue window. Both go into the output folder that mirrors *root*.
    Returns ``(ct0, ct1)`` as ANTs images.
    """
    ct_image = os.path.join(rt_dir, 'ct_image.nii.gz')
    ct0 = sitk.ReadImage(ct_image)
    ct1 = sitk.Clamp(ct0, sitk.sitkUInt8, 0, 80)
    print(ct_image, ct0.GetSize())
    outdir = os.path.join(OUT_ROOT, os.path.relpath(root, PATIENTS_ROOT))
    print(outdir)
    os.makedirs(outdir, exist_ok=True)
    ct0_nii = os.path.join(outdir, 'ct0.nii.gz')
    ct1_nii = os.path.join(outdir, 'ct1.nii.gz')
    shutil.copy(ct_image, ct0_nii)
    sitk.WriteImage(ct1, ct1_nii)
    # Re-read through ANTs with reorient=True, exactly like the images that
    # are registered against them in check().
    return (
        ants.image_read(ct0_nii, reorient=True),
        ants.image_read(ct1_nii, reorient=True),
    )


def _save_result(mytx, e, outdir, out_img):
    """Write the transform, mask, warped image and summaries for one input image.

    The input image itself is copied to *out_img* last. check() treats an
    existing *out_img* as "already done", so it must only appear after
    everything else has been written.
    """
    os.makedirs(outdir, exist_ok=True)
    out_warp = out_img.replace('.nii.gz', '.warp.nii.gz')
    out_msk = out_img.replace('.nii.gz', '.mask.nii.gz')
    out_json = out_img.replace('.nii.gz', '.json')
    out_mat = out_img.replace('.nii.gz', '.mat')
    jj = {
        'ct': mytx['ct'],
        'type': mytx['type'],
        # float() so numpy scalars cannot break json.dump.
        'metric': float(mytx['metric']),
        'ratio': float(mytx['ratio']),
    }
    with open(out_json, 'w') as f:
        json.dump(jj, f, indent=1)
    if mytx['fwdtransforms']:
        # Only the first transform file is kept. For 'inv' results it maps
        # the opposite direction; the JSON 'type' records which one it is.
        shutil.copy(mytx['fwdtransforms'][0], out_mat)
    ants.image_write(mytx['mask'], out_msk)
    ants.image_write(mytx['warpedout'], out_warp)

    # Flat directory of summaries whose names start with the metric,
    # presumably so results can be browsed sorted by registration quality.
    metric_dir = os.path.join(OUT_ROOT, "1")
    out_txt = os.path.join(
        metric_dir, '%f-%s' % (jj['metric'], e.name.replace('.nii.gz', '.txt')),
    )
    jj['dir'] = outdir
    os.makedirs(metric_dir, exist_ok=True)
    with open(out_txt, 'w') as f:
        json.dump(jj, f, indent=1)
    shutil.copy(e.path, out_img)


def check(epath):
    """Register all secondary images of one patient directory to its planning CT.

    Every folder below *epath* that has an ``RT/ORGAN`` sub-directory is
    treated as a study. Its ``RT/ct_image.nii.gz`` is the reference for all
    ``*.nii.gz`` images in the study's other sub-directories. Images whose
    output already exists are skipped, so the function can resume after an
    interruption.

    Returns:
        The number of images registered in this call. Returns None as soon
        as a study's ORGAN directory has no "eye" entry (treated as "not a
        brain image").
    """
    registered = 0
    for root, dirs, _files in os.walk(epath):
        dirs.sort()  # deterministic traversal order
        rt_dir = os.path.join(root, 'RT')
        organ_dir = os.path.join(rt_dir, 'ORGAN')
        if not os.path.isdir(organ_dir):
            continue
        # If there is no eye, it's not a brain image.
        if not _has_eye(organ_dir):
            print('no eye... skip', root)
            return None

        ct0, ct1 = _prepare_cts(root, rt_dir)

        for root2, dirs2, _files2 in os.walk(root):
            dirs2.sort()
            skip = (root2 == root) or ('RT' in root2.split('/'))
            if skip:
                continue
            modality = 'CT' if root2.endswith('CT') else 'other'
            print(skip, root2, modality)
            outdir = os.path.join(OUT_ROOT, os.path.relpath(root2, PATIENTS_ROOT))
            os.makedirs(outdir, exist_ok=True)
            for e in sorted(os.scandir(root2), key=lambda e: e.name):
                if not e.name.endswith('.nii.gz'):
                    continue
                if any(marker in e.name for marker in _SKIPPED_MARKERS):
                    continue
                out_img = os.path.join(outdir, e.name)
                if os.path.isfile(out_img):
                    print('skip', out_img)
                    continue
                print(e.name, e.path)
                fix = ants.image_read(e.path, reorient=True)
                mytx = registration(ct0, ct1, fix)
                if mytx is None:
                    continue
                print(mytx['ratio'], mytx['metric'])
                registered += 1
                _save_result(mytx, e, outdir, out_img)
    return registered


def main():
    """Run check() once for every patient directory under PATIENTS_ROOT.

    Finished patients are recorded in the SHELVE database together with the
    check() result. A per-patient file lock lets several copies of this
    script share the work without processing the same patient twice.
    """
    exclude = (
        # 'LLUQJUY4',  # cervical
    )
    os.makedirs(OUT_ROOT, exist_ok=True)
    lock_dir = os.path.join(OUT_ROOT, '0lock')
    os.makedirs(lock_dir, exist_ok=True)
    for e in sorted(os.scandir(PATIENTS_ROOT), key=lambda e: e.name):
        if not e.is_dir():
            continue
        with shelve.open(SHELVE) as done:
            finished = e.name in done
        if finished or e.name in exclude:
            print('skip', e.name)
            continue

        lock_path = os.path.join(lock_dir, '%s.lock' % e.name)
        lock = filelock.FileLock(lock_path, timeout=1)
        try:
            lock.acquire()
        except filelock.Timeout:
            print(lock_path, 'locked')
            continue
        try:
            # Re-check under the lock: another worker may have finished this
            # patient between the check above and acquiring the lock.
            with shelve.open(SHELVE) as done:
                finished = e.name in done
            if finished:
                print('skip', e.name)
                continue
            ret = check(e.path)
            # Record completion before releasing the lock.
            # NOTE(review): shelve/dbm is not safe for simultaneous writers;
            # concurrent workers finishing at the same moment may collide.
            with shelve.open(SHELVE) as done:
                done[e.name] = ret
        finally:
            lock.release()


if __name__ == '__main__':
    main()