# ck-preprocess/m6resample.py
# 2025-02-02 06:30:53 +08:00
#
# 544 lines
# 16 KiB
# Python

import json
import os
import shelve
import shutil
import time
import ants
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
# Directory whose sub-directories are handed to check() one by one by main().
PATIENTS_ROOT = '/mnt/1220/Public/dataset2/M6'
# Output root. NOTE: the first assignment is immediately overridden by the
# second one, so only the M7 path is ever used.
OUT_ROOT = '/home/xfr/git9/Taipei-1/0'
OUT_ROOT = '/mnt/1220/Public/dataset2/M7'
# Path of the shelve database in which main() stores the result of check()
# for every processed directory.
SHELVE = os.path.join(OUT_ROOT, '0shelve')
# Upper bound applied in check() to the y size of the external-mask bounding box.
MAX_Y = 256
# Target volume size: defaults of resize_with_pad(); check() also uses them
# to limit / classify the ROI size and to normalise the z extent of moving images.
SIZE_X = 249
SIZE_Y = 249
SIZE_Z = 192
# SIZE_Z = 256
# check() skips a moving image whose z extent divided by SIZE_Z is below this.
MIN_OVERLAP = 0.50
# registration() stops trying further candidates once a score drops below this.
MIN_METRIC = -0.50
# def resize_with_crop_or_pad(image, tx = SIZE_X, ty = SIZE_Y, tz = SIZE_Z):
def resize_with_pad(image, tx = SIZE_X, ty = SIZE_Y, tz = SIZE_Z):
    """Pad a 3-D SimpleITK image with sitk.ConstantPad up to (tx, ty, tz).

    The voxels missing along each axis are split between the lower and the
    upper side; when the difference is odd, the extra voxel goes to the
    upper side.  No cropping is attempted: if the image is larger than the
    target along an axis, the computed pad amounts are negative and are
    passed to sitk.ConstantPad unchanged.
    """
    sx, sy, sz = image.GetSize()
    deficits = (tx - sx, ty - sy, tz - sz)
    lower = [missing // 2 for missing in deficits]
    upper = [missing - low for missing, low in zip(deficits, lower)]
    return sitk.ConstantPad(image, lower, upper)
def draw_sitk(image, d, post):
    """Save a 1x3 preview showing the central slice along each array axis.

    `image` is converted with sitk.GetArrayFromImage; the figure is written
    into directory `d` (created when missing) under the name
    '<n2>x<n1>x<n0>-<post>', where (n0, n1, n2) is the array shape.
    """
    volume = sitk.GetArrayFromImage(image)
    shape = volume.shape
    central_planes = (
        volume[shape[0] // 2, :, :],
        volume[:, shape[1] // 2, :],
        volume[:, :, shape[2] // 2],
    )
    _, panels = plt.subplots(1, 3)
    for idx, plane in enumerate(central_planes):
        panel = panels.flat[idx]
        panel.imshow(plane, cmap='gray')
        panel.axis('off')
        # The second and third panels are displayed with the y axis flipped.
        if idx > 0:
            panel.invert_yaxis()
    plt.tight_layout()
    os.makedirs(d, exist_ok=True)
    file_name = '%dx%dx%d-%s'%(shape[2], shape[1], shape[0], post)
    plt.savefig(os.path.join(d, file_name))
    plt.close()
def bbox2_3D(img):
    """Return the inclusive index bounds of the non-zero region of a 3-D array.

    The result is (min0, max0, min1, max1, min2, max2), one min/max pair per
    array axis.  If the array has no non-zero element, six -1 values are
    returned instead.
    """
    bounds = []
    for axis in range(3):
        # Collapse the two other axes: which indices along `axis` hold data?
        others = tuple(k for k in range(3) if k != axis)
        occupied = np.flatnonzero(np.any(img, axis=others))
        if occupied.size == 0:
            return -1, -1, -1, -1, -1, -1
        bounds.append(occupied[0])
        bounds.append(occupied[-1])
    return tuple(bounds)
def physical_size(a):
    """Return (and print) the per-axis product of array shape and spacing.

    `a` must provide `numpy()` (an array with at least three axes) and
    `spacing` (at least three values); only the first three entries of each
    are used.
    """
    shape = a.numpy().shape
    spacing = a.spacing
    extent = tuple(shape[axis] * spacing[axis] for axis in range(3))
    print(shape, spacing, extent)
    return extent
def registration(ct0, ct1, mr, modality = 'CT'):
    """Register `mr` against the CT images and return the best attempt.

    Every transform type in TYPES is tried against each fixed CT image, in
    both directions: CT fixed / `mr` moving (tagged 'fwd') and `mr` fixed /
    CT moving (tagged 'inv').  Each attempt is scored with Mattes mutual
    information evaluated twice (see 'metricFM' / 'metricMM' below); the
    smaller of the two values is the attempt's score, and the attempt with
    the smallest score wins.  The search stops as soon as any attempt scores
    below MIN_METRIC.

    Args:
        ct0: ANTs image of the CT; also the grid onto which the coverage
            mask is resampled.
        ct1: second ANTs image of the same CT (check() passes the
            intensity-rescaled copy).
        mr: ANTs image to be aligned with the CT.
        modality: 'CT' tries ct1 before ct0; any other value tries ct0 first.

    Returns:
        The dict returned by ants.registration for the selected attempt,
        extended with:
          'mask'      all-ones image of `mr` warped onto the ct0 grid,
          'metricFM'  metric between the fixed image and 'warpedmovout',
          'metricMM'  metric between the moving image and 'warpedfixout',
          'metric'    min(metricFM, metricMM),
          'ct'        index of the fixed CT in the order tried (for modality
                      'CT' index 0 is ct1, otherwise index 0 is ct0),
          'type'      (transform type, 'fwd' or 'inv'),
          'warpedout' `mr` resampled into CT space,
          'ratio'     sum of 'mask' divided by the voxel count of ct0.
    """
    # Other ANTs transform types (Translation, Similarity, Affine,
    # AffineFast, BOLDAffine, TRSAA) are left disabled in this script.
    TYPES = (
        'Rigid',
        'QuickRigid',
        'DenseRigid',
        'BOLDRigid',
    )
    FIXS = (ct1, ct0) if modality == 'CT' else (ct0, ct1)

    # All-ones image on the grid of `mr`; once warped it marks which part of
    # the CT grid is covered by `mr`.  It does not depend on the loop.
    mask1 = mr.new_image_like(np.ones(mr.numpy().shape))

    TX = []

    def good_enough():
        # True once any attempt so far scored below the early-stop threshold.
        return min(t['metric'] for t in TX) < MIN_METRIC

    start = time.time()
    for ff, ct in enumerate(FIXS):
        for typ in TYPES:
            print(typ)

            # --- forward: CT is the fixed image, mr is the moving image ---
            mytx = ants.registration(ct, mr, typ)
            mytx['mask'] = ants.apply_transforms(ct0, mask1, mytx['fwdtransforms'],
                interpolator='genericLabel',
            )
            am = ants.create_ants_metric(ct,
                mytx['warpedmovout'],
                metric_type='MattesMutualInformation',
                moving_mask = mytx['mask'],
            )
            mytx['metricFM'] = am.get_value()
            am = ants.create_ants_metric(mr,
                mytx['warpedfixout'],
                metric_type='MattesMutualInformation')
            mytx['metricMM'] = am.get_value()
            mytx['metric'] = min([mytx['metricFM'], mytx['metricMM']])
            mytx['ct'] = ff
            mytx['type'] = (typ, 'fwd')
            mytx['warpedout'] = mytx['warpedmovout']
            TX.append(mytx)
            print(mytx['metric'], mytx['metricFM'], mytx['metricMM'])

            # --- inverse: mr is the fixed image, CT is the moving image ---
            mytx = ants.registration(mr, ct, typ)
            # The transform is inverted to bring the mr mask onto the CT grid.
            mytx['mask'] = ants.apply_transforms(ct0, mask1, mytx['fwdtransforms'],
                interpolator='genericLabel',
                whichtoinvert=[True],
            )
            am = ants.create_ants_metric(mr,
                mytx['warpedmovout'],
                metric_type='MattesMutualInformation')
            mytx['metricFM'] = am.get_value()
            am = ants.create_ants_metric(ct,
                mytx['warpedfixout'],
                metric_type='MattesMutualInformation',
                fixed_mask = mytx['mask'],
            )
            mytx['metricMM'] = am.get_value()
            mytx['metric'] = min([mytx['metricFM'], mytx['metricMM']])
            mytx['ct'] = ff
            mytx['type'] = (typ, 'inv')
            mytx['warpedout'] = mytx['warpedfixout']
            TX.append(mytx)
            print(mytx['metric'], mytx['metricFM'], mytx['metricMM'])

            if good_enough():
                break
        if good_enough():
            break
    print(time.time()-start, 'seconds')

    # Pick the first attempt with the lowest score among those below 0.
    tx = None
    for t in TX:
        if t['metric'] < (0 if tx is None else tx['metric']):
            tx = t
    if tx is None:
        # No attempt scored below 0.  Previously the placeholder
        # {'metric': 0} was kept and the 'mask' lookup below raised KeyError;
        # fall back to the lowest-scoring attempt so the caller still gets a
        # complete result (its 'metric' shows how poor the match is).
        tx = min(TX, key=lambda t: t['metric'])
    tx['ratio'] = tx['mask'].numpy().sum() / np.prod(ct0.numpy().shape)
    return tx
def _last_int_field(fields):
    """Return the last entry of `fields` that parses as an int, or None."""
    for field in reversed(fields):
        try:
            return int(field)
        except ValueError:
            continue
    return None


def check(epath):
    """Resample one patient directory onto a cropped CT grid and register its images.

    os.walk(epath) is scanned for directories containing both 'RT/EXTERNAL'
    and 'RT/ORGAN'.  For each of them:
      1. the largest file in RT/EXTERNAL is resampled to unit spacing and
         reduced (bounding box, erosion + dilation, second bounding box,
         small padding) to a reference ROI image;
      2. 'RT/ct_image.nii.gz', the union of the 'TV-*' images and the
         external mask are resampled onto that ROI and written under
         OUT_ROOT as ct0.nii.gz, ct1.nii.gz and label.nii.gz, plus a PNG
         preview;
      3. every eligible '*.nii.gz' in the sibling (non-RT) sub-directories is
         registered to the CT with registration() and written out together
         with its mask, transform file and JSON summaries.

    Args:
        epath: directory of one patient.

    Returns:
        The size of the last ROI that was computed, or None when no ROI was
        computed, when RT/ORGAN holds no entry with 'eye' in its name, or
        when the erosion/dilation leaves an empty mask.

    Note:
        Two conditions call exit() and therefore terminate the whole process
        (see the NOTE comments in the body).
    """
    roi=None
    for root, dirs, files in os.walk(epath):
        dirs.sort()
        RT_DIR = os.path.join(root, 'RT')
        EXTERNAL_DIR = os.path.join(RT_DIR, 'EXTERNAL')
        if not os.path.isdir(EXTERNAL_DIR):
            continue
        ORGAN_DIR = os.path.join(RT_DIR, 'ORGAN')
        if not os.path.isdir(ORGAN_DIR):
            continue

        # if there is no eye, it's not a brain image
        organs = sorted(os.scandir(ORGAN_DIR), key=lambda e: e.name)
        if not any('eye' in o.name.lower() for o in organs):
            print('no eye... skip', root)
            return None

        CT_IMAGE = os.path.join(RT_DIR, 'ct_image.nii.gz')

        # Use the largest file in EXTERNAL as the external contour mask.
        externals = sorted(os.scandir(EXTERNAL_DIR), key=lambda e: -e.stat().st_size)
        print(externals[0].path)
        external = sitk.ReadImage(externals[0].path)

        # Resample the mask to spacing (1, 1, 1) over the same physical extent.
        sz = external.GetSize()
        sp = external.GetSpacing()
        outputSize = (
            round(sz[0]*sp[0]),
            round(sz[1]*sp[1]),
            round(sz[2]*sp[2]),
        )
        outputOrigin = external.GetOrigin()
        outputSpacing = (1., 1., 1.)
        outputDirection = external.GetDirection()
        mask = sitk.Resample(external, outputSize, sitk.Transform(), sitk.sitkNearestNeighbor, outputOrigin, outputSpacing, outputDirection)

        lsif = sitk.LabelShapeStatisticsImageFilter()
        lsif.Execute(mask)
        if 1 not in lsif.GetLabels():
            # NOTE(review): the message says "skip" but exit() ends the whole
            # run -- confirm this is intended.
            print('strange mask... skip', root)
            exit()
        boundingBox = lsif.GetBoundingBox(1)
        print(boundingBox)
        ix, iy, iz, sx, sy, sz = boundingBox
        sy = min(sy, MAX_Y)
        sz0 = min(sz, SIZE_Z-2) # will pad later
        # Keep the sz0 slices at the high-index end of the bounding box in z.
        iz0 = iz+sz-sz0
        roi = sitk.RegionOfInterest(mask, (sx, sy, sz0), (ix, iy, iz0))

        # remove frame: erode then dilate with the same radius.
        # Earlier radii left in the source were 10, 29, 44 and 40.
        ER = 37 # remove most of it
        roi = sitk.BinaryErode(roi, (ER,ER,ER))
        roi2 = sitk.BinaryDilate(roi, (ER,ER,ER))

        # Now no shoulder, get BB again
        lsif.Execute(roi2)
        # Something wrong? 6L63PXNV
        if 1 not in lsif.GetLabels():
            return None
        boundingBox = lsif.GetBoundingBox(1)
        print(boundingBox)
        ix2, iy2, iz2, sx2, sy2, sz2 = boundingBox
        # y always starts at index 0; the y size is the mean of the first
        # (capped) and the second bounding-box size.
        iy2 = 0
        sy2 = (sy+sy2)//2
        print((sx2, sy2, sz2), (ix2, iy2, iz2))
        roi = sitk.RegionOfInterest(roi2, (sx2, sy2, sz0), (ix2, iy2, 0))
        # One voxel of padding in x and z, up to two voxels in y.
        ypad = min(2, (SIZE_Y-sy2)//2)
        roi = sitk.ConstantPad(roi, (1,ypad,1), (1,ypad,1))
        print(external.GetSize())
        print(roi.GetSize())

        # Resample the CT onto the ROI grid; the second copy is clamped to
        # [0, 80] as UInt8 and intensity-rescaled.
        ct = sitk.ReadImage(CT_IMAGE)
        CTresampled = sitk.Resample(ct, roi)
        rescaled = sitk.RescaleIntensity(sitk.Clamp(CTresampled,
            sitk.sitkUInt8,
            0, 80))

        # Voxel-wise maximum of all 'TV-*' images in RT, on the ROI grid.
        TV = None
        for e in sorted(os.scandir(RT_DIR), key=lambda e: e.name):
            if not e.name.startswith('TV-'):
                continue
            tv = sitk.ReadImage(e.path)
            TVresampled = sitk.Resample(tv, roi,
                sitk.Transform(),
                sitk.sitkNearestNeighbor,
            )
            if TV is None:
                TV = TVresampled
            else:
                TV = sitk.Maximum(TV, TVresampled)
        if TV is None:
            continue

        EXresampled = sitk.Resample(external, roi,
            sitk.Transform(),
            sitk.sitkNearestNeighbor,
        )

        relpath = os.path.relpath(root, PATIENTS_ROOT)
        OUT_DIR = os.path.join(OUT_ROOT, relpath)
        os.makedirs(OUT_DIR, exist_ok=True)
        CT0 = os.path.join(OUT_DIR, 'ct0.nii.gz')
        CT1 = os.path.join(OUT_DIR, 'ct1.nii.gz')
        print(CT0)
        sitk.WriteImage(CTresampled, CT0)
        sitk.WriteImage(rescaled, CT1)
        # The label image is the sum of the TV union and the external mask.
        sitk.WriteImage(TV+EXresampled, os.path.join(OUT_DIR, 'label.nii.gz'))
        draw_sitk(rescaled, os.path.join(OUT_ROOT, '0'), '%s.png'%relpath.replace('/', '-'))

        if max(roi.GetSize()) >= SIZE_X+64: #BODY
            continue
        if max(roi.GetSize()) > SIZE_X:
            # NOTE(review): an ROI slightly larger than SIZE_X ends the whole
            # run -- confirm this is intended.
            exit()

        ct0 = ants.image_read(CT0)
        ct1 = ants.image_read(CT1)

        # Register the images found in every sub-directory except RT.
        for root2, dirs2, files2 in os.walk(root):
            dirs2.sort()
            skip = (root2==root) or ('RT' in root2.split('/'))
            if skip:
                continue
            modality = 'CT' if root2.endswith('CT') else 'other'
            print(skip, root2, modality)
            relpath = os.path.relpath(root2, PATIENTS_ROOT)
            out_dir = os.path.join(OUT_ROOT, relpath)
            for mv in sorted(os.scandir(root2), key=lambda e: e.name):
                if not mv.name.endswith('nii.gz'):
                    continue
                if mv.name.startswith('dose'):
                    continue
                if '_DTI_' in mv.name:
                    continue
                if '_RTDOSE_' in mv.name:
                    continue
                if '_ROI1.' in mv.name:
                    continue

                # Already written by a previous run.
                OUT_IMG = os.path.join(out_dir, mv.name)
                if os.path.isfile(OUT_IMG):
                    print('skip', OUT_IMG)
                    continue

                # Series number: last '_'-separated field of the name (without
                # '.nii.gz') that is an integer.  The previous loop used a bare
                # `except:` and never terminated when no field was numeric.
                series = _last_int_field(mv.name[:-7].split('_'))
                print(mv.name, series)

                moving = ants.image_read(mv.path)

                # skip thin images
                _, _, _, _, zmin, zmax = bbox2_3D(moving.numpy())
                mv_height = moving.spacing[2] * (zmax-zmin)
                z_ratio = mv_height / SIZE_Z
                if z_ratio < MIN_OVERLAP:
                    print("skip", mv.name, mv_height, z_ratio)
                    continue
                print(mv.path, moving.view().shape, mv_height, z_ratio)
                if moving.max() == 0.0:
                    print('Total Mass of the image was zero. Aborting here to prevent division by zero later on.')
                    continue

                mytx = registration(ct0, ct1, moving, modality)
                print(mytx['ratio'], mytx['metric'])

                os.makedirs(out_dir, exist_ok=True)
                OUT_MSK = OUT_IMG.replace('.nii.gz', '.mask.nii.gz')
                OUT_JSON = OUT_IMG.replace('.nii.gz', '.json')
                OUT_MAT = OUT_IMG.replace('.nii.gz', '.mat')
                jj = {
                    'ct' : mytx['ct'],
                    'type' : mytx['type'],
                    'metric': mytx['metric'],
                    'ratio' : mytx['ratio'],
                }
                with open(OUT_JSON, 'w') as f:
                    json.dump(jj, f, indent=1)
                shutil.copy(mytx['fwdtransforms'][0], OUT_MAT)
                ants.image_write(mytx['mask'], OUT_MSK)
                ants.image_write(mytx['warpedout'], OUT_IMG)

                # Second summary under OUT_ROOT/1, with the metric as file-name prefix.
                metric = mytx['metric']
                metric_dir = os.path.join(OUT_ROOT, "1")
                OUT_TXT = os.path.join(metric_dir, '%f-%s'%(metric, mv.name.replace('.nii.gz', '.txt')))
                jj['dir'] = out_dir
                os.makedirs(metric_dir, exist_ok=True)
                with open(OUT_TXT, 'w') as f:
                    json.dump(jj, f, indent=1)
    if roi is None:
        return None
    return roi.GetSize()
def main():
    """Run check() once for every directory directly under PATIENTS_ROOT.

    The return value of check() is stored in the shelve database SHELVE
    under the directory name; directories already present in the database
    (or listed in EXCLUDE) are skipped, so an interrupted run can be resumed.
    The database is opened with a context manager so it is closed even when
    check() raises or calls exit().
    """
    # check('/mnt/1220/Public/dataset2/M6/QT6QVY34') # Lung
    # check('/mnt/1220/Public/dataset2/M6/24V62QQM') # index limit
    # check('/mnt/1220/Public/dataset2/M6/WD34XIQO') # index limit
    # check('/mnt/1220/Public/dataset2/M6/26TYQI3C') # bad registration, need DenseRigid for 1 CT only
    # check('/mnt/1220/Public/dataset2/M6/5BXKANST') # low score
    # check('/mnt/1220/Public/dataset2/M6/BQ773ST6') # bbox problem
    # check('/mnt/1220/Public/dataset2/M6/OJVMKLLO') # series naming conventions
    # check('/mnt/1220/Public/dataset2/M6/PBH6IPXE') # TOF
    # check('/mnt/1220/Public/dataset2/M6/SYPFKLU4') # strange frame, eclipse doses
    # check('/mnt/1220/Public/dataset2/M6/TCZXFBY3') # bad registration
    # exit()

    # 'LLUQJUY4' (cervical) was listed here at one point; the exclusion list
    # is currently empty.
    EXCLUDE = ()

    os.makedirs(OUT_ROOT, exist_ok=True)
    # shelve.open: the file may get a suffix added by the low-level library.
    with shelve.open(SHELVE) as d:
        for e in sorted(os.scandir(PATIENTS_ROOT), key=lambda e: e.name):
            if not e.is_dir():
                continue
            if e.name in d or e.name in EXCLUDE:
                print('skip', e.name)
                continue
            d[e.name] = check(e.path)
            # Flush after every directory so progress survives a crash.
            d.sync()
# Run the batch only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()