"""Crop planning CT volumes around the EXTERNAL contour and register other series.

For every patient directory under ``PATIENTS_ROOT`` the script:

1. locates an ``RT`` folder containing ``EXTERNAL`` and ``ORGAN`` sub-folders,
2. derives a cropped 1 mm reference grid from the largest EXTERNAL mask,
3. writes the resampled CT (``ct0``), a windowed CT (``ct1``) and a label
   volume below ``OUT_ROOT``,
4. rigidly registers the remaining ``*.nii.gz`` series to those CTs with
   ANTsPy and stores the warped image, coverage mask, transform and metrics.

Progress is remembered in a ``shelve`` database so finished patients are
skipped on the next run.
"""

import json
import os
import shelve
import shutil
import time

import ants
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk

PATIENTS_ROOT = '/mnt/1220/Public/dataset2/G4'
# Previously used local output root: '/home/xfr/git9/Taipei-1/0'
OUT_ROOT = '/mnt/1220/Public/dataset2/G5'
SHELVE = os.path.join(OUT_ROOT, '0shelve')

MAX_Y = 256
SIZE_X = 249
SIZE_Y = 249
SIZE_Z = 192
# SIZE_Z = 256
MIN_OVERLAP = 0.50
MIN_METRIC = -0.50


# def resize_with_crop_or_pad(image, tx = SIZE_X, ty = SIZE_Y, tz = SIZE_Z):
def resize_with_pad(image, tx=SIZE_X, ty=SIZE_Y, tz=SIZE_Z):
    """Pad a SimpleITK image with a constant border up to ``(tx, ty, tz)``.

    The missing voxels along each axis are split between the lower and upper
    side; when the difference is odd the extra voxel goes to the upper side.
    No cropping is performed, so the image is expected to be no larger than
    the target along every axis.
    """
    sx, sy, sz = image.GetSize()
    lower = [(tx - sx) // 2, (ty - sy) // 2, (tz - sz) // 2]
    upper = [tx - sx - lower[0], ty - sy - lower[1], tz - sz - lower[2]]
    return sitk.ConstantPad(image, lower, upper)


def draw_sitk(image, d, post):
    """Save a three-panel preview (one central slice per array axis) of ``image``.

    The figure is written into directory ``d`` (created when missing). Its
    file name is the array shape in reversed axis order followed by ``post``,
    e.g. ``249x249x192-<post>``.
    """
    a = sitk.GetArrayFromImage(image)
    s = a.shape
    fig, axs = plt.subplots(1, 3)
    # fig.suptitle('%dx%dx%d'%(s[2], s[1], s[0]))
    panels = axs.flat
    panels[0].imshow(a[s[0] // 2, :, :], cmap='gray')
    panels[1].imshow(a[:, s[1] // 2, :], cmap='gray')
    panels[2].imshow(a[:, :, s[2] // 2], cmap='gray')
    for panel in (panels[0], panels[1], panels[2]):
        panel.axis('off')
    # The two slices taken through the first array axis are flipped vertically.
    panels[1].invert_yaxis()
    panels[2].invert_yaxis()
    fig.tight_layout()
    os.makedirs(d, exist_ok=True)
    fig.savefig(os.path.join(d, '%dx%dx%d-%s' % (s[2], s[1], s[0], post)))
    plt.close(fig)


def bbox2_3D(img):
    """Return the inclusive bounding box of the non-zero entries of a 3-D array.

    The result is ``(min0, max0, min1, max1, min2, max2)`` for array axes
    0, 1 and 2. An array without any non-zero entry yields six ``-1`` values.
    """
    r = np.any(img, axis=(1, 2))
    c = np.any(img, axis=(0, 2))
    z = np.any(img, axis=(0, 1))
    if not np.any(r):
        return -1, -1, -1, -1, -1, -1

    rmin, rmax = np.where(r)[0][[0, -1]]
    cmin, cmax = np.where(c)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]
    return rmin, rmax, cmin, cmax, zmin, zmax


def physical_size(a):
    """Print and return ``shape * spacing`` per axis for an ANTs image ``a``."""
    shape = a.numpy().shape
    spacing = a.spacing
    extent = (
        shape[0] * spacing[0],
        shape[1] * spacing[1],
        shape[2] * spacing[2],
    )
    print(shape, spacing, extent)
    return extent


def registration(ct0, ct1, mr, modality='CT'):
    """Register ``mr`` against the CT volumes and return the best attempt.

    Every transform type is tried in both directions (CT as fixed image,
    tagged ``'fwd'``, and ``mr`` as fixed image, tagged ``'inv'``) against
    each CT in turn. The search stops as soon as any attempt scores below
    ``MIN_METRIC``.

    Parameters:
        ct0, ct1: ANTs images of the reference CT (ct0 also defines the grid
            of the returned coverage mask).
        mr: ANTs image to be aligned to the CT.
        modality: ``'CT'`` tries ``ct1`` before ``ct0``; any other value
            tries ``ct0`` before ``ct1``.

    Returns:
        The ``ants.registration`` result dict of the attempt with the lowest
        ``'metric'``, extended with:
        ``'mask'`` (footprint of ``mr`` resampled onto ``ct0``),
        ``'metricFM'`` / ``'metricMM'`` (Mattes mutual information measured
        in the fixed and moving space), ``'metric'`` (the smaller of both),
        ``'ct'`` (index into the CT order actually tried), ``'type'``
        (``(transform_type, 'fwd' | 'inv')``), ``'warpedout'`` (``mr``
        resampled into CT space) and ``'ratio'`` (mask sum divided by the
        voxel count of ``ct0``).
    """
    candidates = []
    # FIXS = (ct1,)
    fixed_images = (ct1, ct0) if modality == 'CT' else (ct0, ct1)
    # Other types that were tried and are currently disabled:
    # 'Translation', 'Similarity', 'Affine', 'AffineFast', 'BOLDAffine', 'TRSAA'
    transform_types = (
        'Rigid',
        'QuickRigid',
        'DenseRigid',
        'BOLDRigid',
    )
    metric_type = 'MattesMutualInformation'

    # All-ones image on the grid of `mr`; it does not depend on the loop
    # variables, so build it once instead of once per attempt.
    mr_footprint = mr.new_image_like(np.ones(mr.numpy().shape))

    def reached_target():
        return min(t['metric'] for t in candidates) < MIN_METRIC

    start = time.time()
    for ff, ct in enumerate(fixed_images):
        for typ in transform_types:
            print(typ)

            # --- CT is the fixed image, mr is moved onto it ---------------
            mytx = ants.registration(ct, mr, typ)
            mytx['mask'] = ants.apply_transforms(
                ct0, mr_footprint, mytx['fwdtransforms'],
                interpolator='genericLabel',
            )
            am = ants.create_ants_metric(
                ct, mytx['warpedmovout'],
                metric_type=metric_type,
                moving_mask=mytx['mask'],
            )
            mytx['metricFM'] = am.get_value()
            am = ants.create_ants_metric(
                mr, mytx['warpedfixout'], metric_type=metric_type)
            mytx['metricMM'] = am.get_value()
            mytx['metric'] = min(mytx['metricFM'], mytx['metricMM'])
            mytx['ct'] = ff
            mytx['type'] = (typ, 'fwd')
            mytx['warpedout'] = mytx['warpedmovout']
            candidates.append(mytx)
            print(mytx['metric'], mytx['metricFM'], mytx['metricMM'])

            # --- mr is the fixed image, CT is moved onto it ---------------
            mytx = ants.registration(mr, ct, typ)
            # The forward transform maps CT -> mr here, so it is inverted to
            # bring the mr footprint onto the CT grid.
            mytx['mask'] = ants.apply_transforms(
                ct0, mr_footprint, mytx['fwdtransforms'],
                interpolator='genericLabel',
                whichtoinvert=[True],
            )
            am = ants.create_ants_metric(
                mr, mytx['warpedmovout'], metric_type=metric_type)
            mytx['metricFM'] = am.get_value()
            am = ants.create_ants_metric(
                ct, mytx['warpedfixout'],
                metric_type=metric_type,
                fixed_mask=mytx['mask'],
            )
            mytx['metricMM'] = am.get_value()
            mytx['metric'] = min(mytx['metricFM'], mytx['metricMM'])
            mytx['ct'] = ff
            mytx['type'] = (typ, 'inv')
            mytx['warpedout'] = mytx['warpedfixout']
            candidates.append(mytx)
            print(mytx['metric'], mytx['metricFM'], mytx['metricMM'])

            if reached_target():
                break
        if reached_target():
            break
    print(time.time() - start, 'seconds')

    # Lowest metric wins. Selecting with min() (instead of comparing against
    # a placeholder metric of 0) still yields a real candidate when every
    # attempt scored >= 0, which previously ended in KeyError: 'mask'.
    tx = min(candidates, key=lambda t: t['metric'])
    # float() keeps the value a plain Python float for json.dump.
    tx['ratio'] = float(tx['mask'].numpy().sum() / np.prod(ct0.numpy().shape))
    return tx


def _series_number(filename):
    """Return the last ``_``-separated integer token of ``filename``, or None.

    The final seven characters (the ``.nii.gz`` suffix) are ignored.
    """
    tokens = filename[:-7].split('_')
    for token in reversed(tokens):
        try:
            return int(token)
        except ValueError:
            continue
    return None


def _is_registration_candidate(name):
    """Tell whether file ``name`` is an image that should be registered."""
    if not name.endswith('nii.gz'):
        return False
    if name.startswith('dose'):
        return False
    # Other patterns that were excluded at some point:
    # '_swi_', '_ph.', '_Doses_', '_phMag.'
    return not any(tag in name for tag in ('_DTI_', '_RTDOSE_', '_ROI1.'))


def check(epath):
    """Process one patient directory tree.

    Walks ``epath``; every directory that has ``RT/EXTERNAL`` and ``RT/ORGAN``
    is treated as a study. For such a study the function crops a reference
    grid from the largest EXTERNAL mask, writes ``ct0.nii.gz``,
    ``ct1.nii.gz``, ``label.nii.gz`` and a PNG preview, then registers the
    images found in the study's other sub-directories.

    Returns:
        ``roi.GetSize()`` of the last reference grid that was built, or
        ``None`` when no grid was built or the patient was rejected (no
        organ whose name contains "eye", or no label 1 in the mask).

    Raises:
        SystemExit: when the reference grid is larger than ``SIZE_X`` but
            smaller than ``SIZE_X + 64`` (unexpected size; stops the run for
            inspection).
    """
    roi = None
    for root, dirs, files in os.walk(epath):
        dirs.sort()
        rt_dir = os.path.join(root, 'RT')
        external_dir = os.path.join(rt_dir, 'EXTERNAL')
        if not os.path.isdir(external_dir):
            continue
        organ_dir = os.path.join(rt_dir, 'ORGAN')
        if not os.path.isdir(organ_dir):
            continue

        # if there is no eye, it's not a brain image
        has_eye = any('eye' in o.name.lower() for o in os.scandir(organ_dir))
        if not has_eye:
            print('no eye... skip', root)
            return None

        ct_image = os.path.join(rt_dir, 'ct_image.nii.gz')

        # Largest file first; that one is used as the external contour.
        externals = sorted(os.scandir(external_dir),
                           key=lambda e: -e.stat().st_size)
        if not externals:
            # An empty EXTERNAL folder used to raise IndexError below.
            print('no external mask... skip', root)
            continue
        print(externals[0].path)
        external = sitk.ReadImage(externals[0].path)

        # Resample the mask to 1 mm isotropic spacing.
        ext_size = external.GetSize()
        ext_spacing = external.GetSpacing()
        output_size = (
            round(ext_size[0] * ext_spacing[0]),
            round(ext_size[1] * ext_spacing[1]),
            round(ext_size[2] * ext_spacing[2]),
        )
        mask = sitk.Resample(external, output_size,
                             sitk.Transform(),
                             sitk.sitkNearestNeighbor,
                             external.GetOrigin(),
                             (1., 1., 1.),
                             external.GetDirection())

        lsif = sitk.LabelShapeStatisticsImageFilter()
        lsif.Execute(mask)
        if 1 not in lsif.GetLabels():
            print('strange mask... skip', root)
            return None
        bounding_box = lsif.GetBoundingBox(1)
        print(bounding_box)

        ix, iy, iz, sx, sy, sz = bounding_box
        sy = min(sy, MAX_Y)
        sz0 = min(sz, SIZE_Z - 2)  # will pad later
        iz0 = iz + sz - sz0  # keep the upper end of the box along z
        roi = sitk.RegionOfInterest(mask, (sx, sy, sz0), (ix, iy, iz0))

        # remove frame
        # Radii tried so far: 10, 29, 44, 40 (removes most of it), 37.
        er = 37  # remove most of it
        roi = sitk.BinaryErode(roi, (er, er, er))
        roi2 = sitk.BinaryDilate(roi, (er, er, er))

        # Now no shoulder, get BB again
        lsif.Execute(roi2)
        # Something wrong? 6L63PXNV
        if 1 not in lsif.GetLabels():
            return None
        bounding_box = lsif.GetBoundingBox(1)
        print(bounding_box)

        ix2, iy2, iz2, sx2, sy2, sz2 = bounding_box
        # sy2 = sy
        iy2 = 0
        # iy2 = min(iy2, roi2.GetSize()[1]-sy2)
        sy2 = (sy + sy2) // 2
        print((sx2, sy2, sz2), (ix2, iy2, iz2))
        roi = sitk.RegionOfInterest(roi2, (sx2, sy2, sz0), (ix2, iy2, 0))
        # roi = resize_with_pad(roi)
        # roi = sitk.FFTPad(roi)
        ypad = min(2, (SIZE_Y - sy2) // 2)
        roi = sitk.ConstantPad(roi, (1, ypad, 1), (1, ypad, 1))
        # roi = sitk.ConstantPad(roi, (0,0,0), (0,SIZE_Y-sy2,0))

        print(external.GetSize())
        print(roi.GetSize())

        # From here on `roi` only serves as the reference grid.
        ct = sitk.ReadImage(ct_image)
        ct_resampled = sitk.Resample(ct, roi)
        rescaled = sitk.RescaleIntensity(
            sitk.Clamp(ct_resampled, sitk.sitkUInt8, 0, 80))

        # Union (voxel-wise maximum) of all 'TV-*' volumes in the RT folder.
        tv_union = None
        for e in sorted(os.scandir(rt_dir), key=lambda e: e.name):
            if not e.name.startswith('TV-'):
                continue
            tv = sitk.ReadImage(e.path)
            tv_resampled = sitk.Resample(tv, roi,
                                         sitk.Transform(),
                                         sitk.sitkNearestNeighbor)
            if tv_union is None:
                tv_union = tv_resampled
            else:
                tv_union = sitk.Maximum(tv_union, tv_resampled)
        if tv_union is None:
            continue

        ex_resampled = sitk.Resample(external, roi,
                                     sitk.Transform(),
                                     sitk.sitkNearestNeighbor)

        relpath = os.path.relpath(root, PATIENTS_ROOT)
        study_out_dir = os.path.join(OUT_ROOT, relpath)
        os.makedirs(study_out_dir, exist_ok=True)
        ct0_path = os.path.join(study_out_dir, 'ct0.nii.gz')
        ct1_path = os.path.join(study_out_dir, 'ct1.nii.gz')
        print(ct0_path)
        sitk.WriteImage(ct_resampled, ct0_path)
        sitk.WriteImage(rescaled, ct1_path)
        sitk.WriteImage(tv_union + ex_resampled,
                        os.path.join(study_out_dir, 'label.nii.gz'))

        draw_sitk(rescaled, os.path.join(OUT_ROOT, '0'),
                  '%s.png' % relpath.replace('/', '-'))

        largest_side = max(roi.GetSize())
        if largest_side >= SIZE_X + 64:  # BODY
            continue
        if largest_side > SIZE_X:
            # Unexpected grid size: stop the whole run so it can be inspected.
            raise SystemExit(
                'unexpected ROI size %s for %s' % (roi.GetSize(), root))

        # continue # skip registration first
        ct0 = ants.image_read(ct0_path)
        ct1 = ants.image_read(ct1_path)

        for root2, dirs2, files2 in os.walk(root):
            dirs2.sort()
            skip = (root2 == root) or ('RT' in root2.split('/'))
            if skip:
                continue
            modality = 'CT' if root2.endswith('CT') else 'other'
            print(skip, root2, modality)

            relpath = os.path.relpath(root2, PATIENTS_ROOT)
            out_dir = os.path.join(OUT_ROOT, relpath)
            for mv in sorted(os.scandir(root2), key=lambda e: e.name):
                if not _is_registration_candidate(mv.name):
                    continue

                out_img = os.path.join(out_dir, mv.name)
                if os.path.isfile(out_img):
                    print('skip', out_img)
                    continue

                # Informational only. The previous retry loop never ended
                # for names without an integer token.
                series = _series_number(mv.name)
                print(mv.name, series)
                # if series > 99:
                #     continue

                moving = ants.image_read(mv.path)

                # skip thin images
                _, _, _, _, zmin, zmax = bbox2_3D(moving.numpy())
                mv_height = moving.spacing[2] * (zmax - zmin)
                z_ratio = mv_height / SIZE_Z
                if z_ratio < MIN_OVERLAP:
                    print("skip", mv.name, mv_height, z_ratio)
                    continue

                print(mv.path, moving.view().shape, mv_height, z_ratio)
                if moving.max() == 0.0:
                    print('Total Mass of the image was zero. Aborting here to prevent division by zero later on.')
                    continue

                mytx = registration(ct0, ct1, moving, modality)
                print(mytx['ratio'], mytx['metric'])
                # if (mytx['metric'] < MIN_METRIC) and (mytx['ratio'] < MIN_OVERLAP):
                #     print("skip", mv.name, mytx['ratio'])
                #     continue

                os.makedirs(out_dir, exist_ok=True)
                out_msk = out_img.replace('.nii.gz', '.mask.nii.gz')
                out_json = out_img.replace('.nii.gz', '.json')
                out_mat = out_img.replace('.nii.gz', '.mat')
                jj = {
                    'ct': mytx['ct'],
                    'type': mytx['type'],
                    'metric': mytx['metric'],
                    'ratio': mytx['ratio'],
                }
                with open(out_json, 'w') as f:
                    json.dump(jj, f, indent=1)
                shutil.copy(mytx['fwdtransforms'][0], out_mat)
                ants.image_write(mytx['mask'], out_msk)
                ants.image_write(mytx['warpedout'], out_img)

                # Second copy of the summary in a flat folder; the metric
                # leads the file name.
                metric = mytx['metric']
                metric_dir = os.path.join(OUT_ROOT, "1")
                out_txt = os.path.join(
                    metric_dir,
                    '%f-%s' % (metric, mv.name.replace('.nii.gz', '.txt')))
                jj['dir'] = out_dir
                os.makedirs(metric_dir, exist_ok=True)
                with open(out_txt, 'w') as f:
                    json.dump(jj, f, indent=1)

    if roi is None:
        return None
    return roi.GetSize()


def main():
    """Run ``check`` on every not-yet-processed patient under ``PATIENTS_ROOT``.

    The result of each patient is stored in the shelf at ``SHELVE`` under the
    directory name; names already present in the shelf are skipped.
    """
    # Cases used for debugging (all under /mnt/1220/Public/dataset2/G4/):
    # check('/mnt/1220/Public/dataset2/G4/2UK7274S')
    # check('/mnt/1220/Public/dataset2/G4/6L63PXNV') # Took MR for base image...:(
    # check('/mnt/1220/Public/dataset2/G4/5IJBRZ4K') # Large frame ... Need erode 29 ?
    # check('/mnt/1220/Public/dataset2/G4/7RC7JOLL') # Large frame ... Need erode 44 ?
    # check('/mnt/1220/Public/dataset2/G4/O7IOGTPD') # Large frame ... Need erode 37 ?
    # check('/mnt/1220/Public/dataset2/G4/2FHZOOLU') # Poor registration
    # check('/mnt/1220/Public/dataset2/G4/3N6JVTD4') # Poor registration
    # check('/mnt/1220/Public/dataset2/G4/3L6LOEER') # Poor registration, metric = 1.058
    # check('/mnt/1220/Public/dataset2/G4/IUS424GK') # MASK zero after resample
    # check('/mnt/1220/Public/dataset2/G4/MCVGFTEI') # CTA as base image :( still need to read base image from XML
    # exit()

    EXCLUDE = (
        '2FHZOOLU',  # CT0 is cervical
        # '2KCGE3UG', # cervical, but acceptable
        '53JAUWVU',  # cervical
        '7PS3CJUR',  # cervical
        'ED4NQNMK',  # cervical
        'IUS424GK',  # strange mask
        'MCVGFTEI',  # CTA as base image :(
    )
    # The exclusion list above is currently disabled.
    EXCLUDE = ()

    os.makedirs(OUT_ROOT, exist_ok=True)
    # open -- file may get suffix added by low-level library.
    # The context manager closes the shelf even when check() raises.
    with shelve.open(SHELVE) as d:
        for e in sorted(os.scandir(PATIENTS_ROOT), key=lambda e: e.name):
            if not e.is_dir():
                continue
            if e.name in d or e.name in EXCLUDE:
                print('skip', e.name)
                continue
            d[e.name] = check(e.path)
            d.sync()


if __name__ == '__main__':
    main()