# ck-preprocess/m6register.py
#
# (Header captured from a source-viewer page: 381 lines, 12 KiB, Python,
#  timestamp 2025-02-01 22:30:53 +00:00.)
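'''
Register each patient's NIfTI image series to the CT in RT/ct_image.nii.gz
using ANTs rigid registrations, and write the warped image, mask, transform
and a JSON summary for every registered series.

Registration quality is scored with ANTs' Mattes mutual-information metric,
where lower (more negative) is better: a metric > -0.27 indicates a bad
registration.
'''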
import json
import os
# import pathlib
import shelve
import shutil
import time
import ants
import filelock
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk
# import skimage
# --- Input / output locations -------------------------------------------
# Earlier input tree (superseded): '/mnt/1220/Public/dataset2/M6'
PATIENTS_ROOT = '/mnt/1220/Public/dataset2/M6_24Q3/nii/'
# Earlier output tree (superseded): '/mnt/1220/Public/dataset2/N6'
OUT_ROOT = '/mnt/1220/Public/dataset2/N6_24Q3.1'
# shelve database mapping patient directory name -> check() result.
SHELVE = os.path.join(OUT_ROOT, '0shelve')

# --- Geometry -----------------------------------------------------------
MAX_Y = 256  # NOTE(review): not referenced in the code visible here.
# Default target size (voxels) used by resize_with_pad().
SIZE_X = SIZE_Y = 249
SIZE_Z = 192  # alternative previously tried: 256

# --- Registration thresholds -------------------------------------------
MIN_OVERLAP = 0.50  # NOTE(review): not referenced in the code visible here.
# registration() stops trying further candidates once any candidate's
# metric falls below this value.
MIN_METRIC = -0.50
def resize_with_pad(image, tx=SIZE_X, ty=SIZE_Y, tz=SIZE_Z):
    """Pad a 3-D SimpleITK image so its size becomes (tx, ty, tz), centered.

    Padding uses ``sitk.ConstantPad`` with its default constant.  For each
    axis the missing voxel count is split in half; when it is odd, the extra
    voxel goes on the upper side.

    NOTE(review): assumes the image is no larger than the target on every
    axis; a negative pad amount is presumably rejected by ConstantPad.
    """
    sx, sy, sz = image.GetSize()
    deficits = (tx - sx, ty - sy, tz - sz)
    lower = [gap // 2 for gap in deficits]
    upper = [gap - low for gap, low in zip(deficits, lower)]
    return sitk.ConstantPad(image, lower, upper)
def draw_sitk(image, d, post):
    """Save the three central orthogonal slices of a SimpleITK image.

    The image is converted with ``sitk.GetArrayFromImage`` and the middle
    slice along each of the first three array axes is drawn in grayscale,
    one panel each, with axes hidden; panels 2 and 3 get an inverted y-axis.
    The figure is saved in directory ``d`` (created if missing) under the
    name ``'<s2>x<s1>x<s0>-<post>'`` built from the array shape (presumably
    reading as the image's x-by-y-by-z size), and then closed.
    """
    arr = sitk.GetArrayFromImage(image)
    shape = arr.shape
    _, panels = plt.subplots(1, 3)
    views = (
        arr[shape[0] // 2, :, :],
        arr[:, shape[1] // 2, :],
        arr[:, :, shape[2] // 2],
    )
    for ax, view in zip(panels.flat, views):
        ax.imshow(view, cmap='gray')
    for ax in panels.flat:
        ax.axis('off')
    for ax in list(panels.flat)[1:]:
        ax.invert_yaxis()
    plt.tight_layout()
    os.makedirs(d, exist_ok=True)
    file_name = '%dx%dx%d-%s' % (shape[2], shape[1], shape[0], post)
    plt.savefig(os.path.join(d, file_name))
    plt.close()
def bbox2_3D(img):
    """Return the inclusive bounding box of the truthy voxels of a 3-D array.

    The result is ``(rmin, rmax, cmin, cmax, zmin, zmax)``: the first and
    last index holding a truthy value along axis 0, axis 1 and axis 2 in
    turn.  An array without any truthy value yields six ``-1`` entries.
    """
    bounds = []
    for other_axes in ((1, 2), (0, 2), (0, 1)):
        hits = np.flatnonzero(np.any(img, axis=other_axes))
        if hits.size == 0:
            # Only reachable on the first projection: if any voxel is truthy,
            # every projection has at least one hit.
            return -1, -1, -1, -1, -1, -1
        bounds.extend((hits[0], hits[-1]))
    return tuple(bounds)
def _mattes_mi(fixed, moving, **masks):
    """Return the ANTs Mattes mutual-information metric of ``fixed`` vs ``moving``.

    ``masks`` (``fixed_mask=`` / ``moving_mask=``) is forwarded unchanged to
    ``ants.create_ants_metric``.  registration() treats lower values as better.
    """
    return ants.create_ants_metric(
        fixed,
        moving,
        metric_type='MattesMutualInformation',
        **masks,
    ).get_value()


def _register_forward(ct0, ct1, ct, mr, typ, mr_ones):
    """Register ``mr`` (moving) onto ``ct`` (fixed) with transform type ``typ``.

    ``mr_ones`` is an all-ones image on ``mr``'s grid; warping it onto ``ct0``
    with label interpolation gives the CT-space mask of the MR coverage.
    Adds 'mask', the three metrics, 'metric' (their minimum), 'type' and
    'warpedout' to the ANTs result dict and returns it.
    """
    tx = ants.registration(ct, mr, typ)
    tx['mask'] = ants.apply_transforms(
        ct0, mr_ones, tx['fwdtransforms'],
        interpolator='genericLabel',
    )
    tx['metric0M'] = _mattes_mi(ct0, tx['warpedmovout'], moving_mask=tx['mask'])
    tx['metric1M'] = _mattes_mi(ct1, tx['warpedmovout'], moving_mask=tx['mask'])
    tx['metricMM'] = _mattes_mi(mr, tx['warpedfixout'])
    # Score by the worst-case (least negative) of the three comparisons.
    tx['metric'] = min([tx['metric0M'], tx['metric1M'], tx['metricMM']])
    tx['type'] = (typ, 'fwd')
    tx['warpedout'] = tx['warpedmovout']
    print(tx['metric'], tx['metric0M'], tx['metric1M'], tx['metricMM'])
    return tx


def _register_inverse(ct0, ct1, ct, mr, typ, mr_ones):
    """Register ``ct`` (moving) onto ``mr`` (fixed) with transform type ``typ``.

    The MR-grid ones-mask is brought onto ``ct0``'s grid by applying the
    forward transform inverted.  Adds the same keys as _register_forward().
    """
    tx = ants.registration(mr, ct, typ)
    tx['mask'] = ants.apply_transforms(
        ct0, mr_ones, tx['fwdtransforms'],
        interpolator='genericLabel',
        whichtoinvert=[True],
    )
    tx['metricMM'] = _mattes_mi(mr, tx['warpedmovout'])
    tx['metric0M'] = _mattes_mi(ct0, tx['warpedfixout'], fixed_mask=tx['mask'])
    tx['metric1M'] = _mattes_mi(ct1, tx['warpedfixout'], fixed_mask=tx['mask'])
    tx['metric'] = min([tx['metricMM'], tx['metric0M'], tx['metric1M']])
    tx['type'] = (typ, 'inv')
    tx['warpedout'] = tx['warpedfixout']
    print(tx['metric'], tx['metricMM'], tx['metric0M'], tx['metric1M'])
    return tx


def registration(ct0, ct1, mr):
    """Rigidly register image ``mr`` to a CT and return the best-scoring result.

    Args:
        ct0: ANTs image of the CT (check() passes the original CT).
        ct1: ANTs image on the same grid as ``ct0`` (check() passes a copy
            clamped to [0, 80]); used both as an alternative registration
            target and as an extra metric reference.
        mr: ANTs image of the series to register.

    Returns:
        A dict with at least 'fwdtransforms', 'warpedout', 'mask', 'ct'
        (0 or 1: which CT was the target), 'type', 'metric' and 'ratio'
        (fraction of ``ct0`` voxels covered by the mask).  If ``mr`` equals
        ``ct0`` voxel-for-voxel, an identity result is returned.
        Returns None when ``mr`` has a different shape and some dimension is
        smaller than 4, or when no candidate reached a negative metric.
    """
    ct0n = ct0.numpy()
    mrn = mr.numpy()
    if np.array_equal(ct0n, mrn):
        print('EQUAL')
        return {
            'fwdtransforms': [],
            'warpedfixout': ct0,
            'warpedmovout': ct0,
            # check() writes 'warpedout' and 'mask'; provide them so the
            # identity result has the same keys as a real registration.
            'warpedout': ct0,
            'mask': ct0.new_image_like(np.ones(ct0n.shape)),
            'ct': 0,
            'type': 'Identity',
            'metric': -2,
            'ratio': 1,
        }
    if ct0n.shape == mrn.shape:
        print('SAME SHAPE')
    else:
        print('others', mrn.shape)
        if min(mrn.shape) < 4:
            print('skip')
            return None

    cts = (ct0, ct1)
    types = (
        # 'Translation',
        'Rigid',
        'QuickRigid',
        'DenseRigid',
        'BOLDRigid',
    )
    # Loop-invariant: all-ones image on mr's grid, warped into CT space as the mask.
    mr_ones = mr.new_image_like(np.ones(mrn.shape))

    candidates = []
    start = time.time()
    for m, ct in enumerate(cts):
        for typ in types:
            print(typ)
            for register in (_register_forward, _register_inverse):
                tx = register(ct0, ct1, ct, mr, typ, mr_ones)
                tx['ct'] = m
                candidates.append(tx)
            # Stop early as soon as any candidate is good enough.
            if min(t['metric'] for t in candidates) < MIN_METRIC:
                break
        if min(t['metric'] for t in candidates) < MIN_METRIC:
            break
    print(time.time() - start, 'seconds')

    # Only negative metrics are accepted; min() keeps the first of equal
    # minima, matching a strict '<' scan.
    accepted = [t for t in candidates if t['metric'] < 0]
    if not accepted:
        print('no negative metric... skip')
        return None
    best = min(accepted, key=lambda t: t['metric'])
    best['ratio'] = best['mask'].numpy().sum() / np.prod(ct0n.shape)
    return best
# File-name fragments of series that check() does not register.
_SKIP_MARKERS = ('_RTDOSE_', '_DTI_', '_ROI1.')
# Extension of the images check() reads and writes.
_NII_EXT = '.nii.gz'


def _has_eye(organ_dir):
    """Return True if any entry of ``organ_dir`` has 'eye' in its lower-cased name."""
    with os.scandir(organ_dir) as entries:
        return any('eye' in entry.name.lower() for entry in entries)


def _prepare_ct(root, rt_dir):
    """Stage the CT of ``rt_dir`` in the output tree and load it with ANTs.

    Copies ``<rt_dir>/ct_image.nii.gz`` to ``ct0.nii.gz`` and writes a copy
    clamped to [0, 80] (UInt8) to ``ct1.nii.gz``, both in the output
    directory mirroring ``root``.  Returns the two files re-read with
    ``ants.image_read(..., reorient=True)``.
    """
    ct_image = os.path.join(rt_dir, 'ct_image.nii.gz')
    ct0 = sitk.ReadImage(ct_image)
    ct1 = sitk.Clamp(ct0, sitk.sitkUInt8, 0, 80)
    print(ct_image, ct0.GetSize())
    outdir = os.path.join(OUT_ROOT, os.path.relpath(root, PATIENTS_ROOT))
    print(outdir)
    os.makedirs(outdir, exist_ok=True)
    ct0_nii = os.path.join(outdir, 'ct0.nii.gz')
    ct1_nii = os.path.join(outdir, 'ct1.nii.gz')
    shutil.copy(ct_image, ct0_nii)
    sitk.WriteImage(ct1, ct1_nii)
    return (ants.image_read(ct0_nii, reorient=True),
            ants.image_read(ct1_nii, reorient=True))


def _save_registration(mytx, entry, out_img, outdir):
    """Write the outputs of one registration, ending with a copy of the source.

    ``out_img`` ends with ``_NII_EXT``; sibling outputs share its stem: a
    ``.json`` summary, a ``.mat`` transform (only when one exists), and
    ``.mask.nii.gz`` / ``.warp.nii.gz`` images when present in ``mytx``.
    The summary plus ``dir`` is also written to
    ``OUT_ROOT/1/<metric>-<series stem>.txt``.
    """
    # Strip only the trailing extension (str.replace would also rewrite any
    # '.nii.gz' occurring earlier in the path).
    base = out_img[:-len(_NII_EXT)]
    jj = {
        'ct': mytx['ct'],
        'type': mytx['type'],
        'metric': mytx['metric'],
        'ratio': mytx['ratio'],
    }
    with open(base + '.json', 'w') as f:
        json.dump(jj, f, indent=1)
    if mytx['fwdtransforms']:
        shutil.copy(mytx['fwdtransforms'][0], base + '.mat')
    # Identity results from some registration() versions lack these keys.
    if 'mask' in mytx:
        ants.image_write(mytx['mask'], base + '.mask.nii.gz')
    if 'warpedout' in mytx:
        ants.image_write(mytx['warpedout'], base + '.warp.nii.gz')
    metric_dir = os.path.join(OUT_ROOT, "1")
    txt_name = entry.name[:-len(_NII_EXT)] + '.txt'
    out_txt = os.path.join(metric_dir, '%f-%s' % (mytx['metric'], txt_name))
    jj['dir'] = outdir
    os.makedirs(metric_dir, exist_ok=True)
    with open(out_txt, 'w') as f:
        json.dump(jj, f, indent=1)
    # Written last: check() treats an existing out_img as "already done".
    shutil.copy(entry.path, out_img)


def check(epath):
    """Register every image series of one patient directory to its RT CT.

    Walks ``epath`` in sorted order.  Each directory containing
    ``RT/ORGAN`` is processed: if no organ entry mentions 'eye' the patient
    is not treated as a brain image and None is returned immediately.
    Otherwise the CT is staged and every ``.nii.gz`` under that directory
    (outside any ``RT`` component, excluding names containing
    ``_SKIP_MARKERS``) is registered and saved, unless its output copy
    already exists.

    Returns:
        The number of series registered by this call, or None when a
        directory without an eye structure is encountered.
    """
    registered = 0
    for root, dirs, files in os.walk(epath):
        dirs.sort()
        rt_dir = os.path.join(root, 'RT')
        organ_dir = os.path.join(rt_dir, 'ORGAN')
        if not os.path.isdir(organ_dir):
            continue
        # if there is no eye, it's not a brain image
        if not _has_eye(organ_dir):
            print('no eye... skip', root)
            return None
        ct0, ct1 = _prepare_ct(root, rt_dir)
        for root2, dirs2, files2 in os.walk(root):
            dirs2.sort()
            skip = (root2 == root) or ('RT' in root2.split('/'))
            if skip:
                continue
            modality = 'CT' if root2.endswith('CT') else 'other'
            print(skip, root2, modality)
            outdir = os.path.join(OUT_ROOT, os.path.relpath(root2, PATIENTS_ROOT))
            os.makedirs(outdir, exist_ok=True)
            for e in sorted(os.scandir(root2), key=lambda e: e.name):
                if not e.name.endswith(_NII_EXT):
                    continue
                if any(marker in e.name for marker in _SKIP_MARKERS):
                    continue
                out_img = os.path.join(outdir, e.name)
                if os.path.isfile(out_img):
                    print('skip', out_img)
                    continue
                print(e.name, e.path)
                fix = ants.image_read(e.path, reorient=True)
                mytx = registration(ct0, ct1, fix)
                if mytx is None:
                    continue
                print(mytx['ratio'], mytx['metric'])
                registered += 1
                _save_registration(mytx, e, out_img, outdir)
    return registered
def main():
    """Run check() once per patient directory under PATIENTS_ROOT.

    Designed for several concurrent workers: each patient is claimed with a
    non-blocking (1 s timeout) file lock in ``OUT_ROOT/0lock``; patients
    already recorded in the SHELVE database, or listed in EXCLUDE, are
    skipped.  The check() result is recorded in SHELVE under the patient's
    directory name before the patient lock is released.
    """
    # Individual patients used while debugging (note = issue seen):
    # check('/mnt/1220/Public/dataset2/M6/QT6QVY34') # Lung
    # check('/mnt/1220/Public/dataset2/M6/24V62QQM') # index limit
    # check('/mnt/1220/Public/dataset2/M6/WD34XIQO') # index limit
    # check('/mnt/1220/Public/dataset2/M6/26TYQI3C') # bad registration, need DenseRigid for 1 CT only
    # check('/mnt/1220/Public/dataset2/M6/5BXKANST') # low score
    # check('/mnt/1220/Public/dataset2/M6/BQ773ST6') # bbox problem
    # check('/mnt/1220/Public/dataset2/M6/OJVMKLLO') # series naming conventions
    # check('/mnt/1220/Public/dataset2/M6/PBH6IPXE') # TOF
    # check('/mnt/1220/Public/dataset2/M6/SYPFKLU4') # strange frame, eclipse doses
    # check('/mnt/1220/Public/dataset2/M6/TCZXFBY3') # bad registration
    EXCLUDE = (
        # 'LLUQJUY4', #cervical
    )
    os.makedirs(OUT_ROOT, exist_ok=True)
    lock_dir = os.path.join(OUT_ROOT, '0lock')
    os.makedirs(lock_dir, exist_ok=True)
    # Serializes access to the shared shelve file across concurrent workers;
    # the default FileLock timeout blocks until the lock is free.
    shelve_lock = filelock.FileLock(SHELVE + '.lock')

    def is_done(name):
        with shelve_lock, shelve.open(SHELVE) as d:
            return name in d

    for e in sorted(os.scandir(PATIENTS_ROOT), key=lambda e: e.name):
        if not e.is_dir():
            continue
        if e.name in EXCLUDE or is_done(e.name):
            print('skip', e.name)
            continue
        lock_path = os.path.join(lock_dir, '%s.lock' % e.name)
        lock = filelock.FileLock(lock_path, timeout=1)
        try:
            lock.acquire()
        except filelock.Timeout:
            print(lock_path, 'locked')
            continue
        try:
            # Another worker may have finished this patient between the
            # check above and acquiring the lock.
            if is_done(e.name):
                print('skip', e.name)
                continue
            ret = check(e.path)
            with shelve_lock, shelve.open(SHELVE) as d:
                d[e.name] = ret
        finally:
            lock.release()


if __name__ == '__main__':
    main()