diff --git a/src/g4synthmorph.py b/src/g4synthmorph.py index ac1ea6b..37ada25 100644 --- a/src/g4synthmorph.py +++ b/src/g4synthmorph.py @@ -12,6 +12,16 @@ XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/xfr/.conda/envs/25reg time ./mri_synthmo XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/xfr/.conda/envs/25reg time mri_synthmorph/mri_synthmorph -m affine -o affine.nii.gz -g moving.nii.gz clipped.nii.gz + + +find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics.json -exec grep -H "1.*," {} ";"|sort -k 2 -n|head + +find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics1.json -exec grep -H ":" {} ";"|sort -k 3 -n|head + +find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics1.json -exec grep -H ":" {} ";"|grep joint|sort -k 3 -n|head -n 20 + +bad registration if metric1 < 1.09 + ''' from pathlib import Path @@ -37,8 +47,8 @@ from mri_synthmorph.synthmorph import registration import surfa as sf -PATIENTS_ROOT = '/mnt/1220/Public/dataset2/G4' -OUT_ROOT = '/mnt/1220/Public/dataset2/G4-synthmorph' +PATIENTS_ROOT = '/mnt/1218/Public/dataset2/G4' +OUT_ROOT = '/mnt/1218/Public/dataset2/G4-synthmorph' SHELVE = os.path.join(OUT_ROOT, '0shelve') MAX_Y = 256 @@ -98,8 +108,10 @@ def register(ct0, ct1, moving, out_root): logger.info(' '.join((ct0, ct1, moving, str(out_root)))) orig = sf.load_volume(moving) - base = sf.load_volume(ct0) + base1 = sf.load_volume(ct1) + + if modality == 'CT': clipped = out_root/'clipped.nii.gz' @@ -173,6 +185,11 @@ def register(ct0, ct1, moving, out_root): # print(fill) # exit() + METRICS0 = {} + METRICS1 = {} + + inp1 = None + inp2 = None for m in MODELS: default['model'] = m @@ -194,35 +211,47 @@ def register(ct0, ct1, moving, out_root): registration.register(arg) logger.info('registered %s'%m) - if m in ( - 'rigid', - 'affine', - 'joint', - ): - # which = 'affine' if arg.trans.endswith('.lta') else 'warp' - out = out_root/('%s.nii.gz'%m) - if m in ['affine', 'rigid']: - trans = sf.load_affine(default['out_dir']/'tra_1.lta') - prop = dict(method='linear', resample=True, fill=fill) - orig.transform(trans, **prop).resample_like(base, fill=fill).save(out) - logger.info('transformed %s'%out) - # print(prop) - # exit() - else: - # need to resample before transform in warp, too complicated, just copy it - # trans1 = default['out_dir']/'tra_1.nii.gz' - # trans = sf.load_warp(trans1) + if inp1 == None: + inp1 = sf.load_volume(default['out_dir']/'inp_1.nii.gz') + if inp2 == None: + inp2 = sf.load_volume(default['out_dir']/'inp_2.nii.gz') - sf.load_volume(default['out_dir']/'out_1.nii.gz').resample_like(base, fill=fill).save(out) - logger.info('resampled %s'% out) + out1 = sf.load_volume(default['out_dir']/'out_1.nii.gz') + out2 = sf.load_volume(default['out_dir']/'out_2.nii.gz') - with open(out_root/'metric.txt', 'w') as f_metrics: - for m in MODELS: - out1 = sf.load_volume(out_root/m/'out_1.nii.gz').data - inp2 = sf.load_volume(out_root/m/'inp_2.nii.gz').data - met = normalized_mutual_information(out1, inp2) - f_metrics.write('%s\t%f\n'%(m, met)) + + out = out_root/('%s.nii.gz'%m) + if m in ['affine', 'rigid']: + trans = sf.load_affine(default['out_dir']/'tra_1.lta') + prop = dict(method='linear', resample=True, fill=fill) + resampled = orig.transform(trans, **prop).resample_like(base, fill=fill) + logger.info('transformed %s'%out) + # print(prop) + # exit() + else: + # need to resample before transform in warp, too complicated, just copy it + # trans1 = default['out_dir']/'tra_1.nii.gz' + # trans = sf.load_warp(trans1) + + resampled = out1.resample_like(base, fill=fill) + logger.info('resampled %s'% out) + + resampled.save(out) + + inp1_out2 = normalized_mutual_information(inp1.data, out2.data) + inp2_out1 = normalized_mutual_information(inp2.data, out1.data) + m0 = normalized_mutual_information(base.data, resampled.data) + m1 = normalized_mutual_information(base1.data, resampled.data) + + METRICS0[m] = (inp1_out2, inp2_out1, m0, m1) + METRICS1[m] = max(inp1_out2, inp2_out1, m0, m1) + + with open(out_root/'metrics0.json', 'w') as f_metrics: + json.dump(METRICS0, f_metrics, indent=1) + + with open(out_root/'metrics1.json', 'w') as f_metrics: + json.dump(METRICS1, f_metrics, indent=1) return out_root @@ -314,7 +343,9 @@ def check(epath): def main(): - # check('/mnt/1220/Public/dataset2/G4/3L6LOEER') # bad registration + # check('/mnt/1218/Public/dataset2/G4/22M5LAGD') # first case + # check('/mnt/1218/Public/dataset2/G4/2FHZOOLU') # bad registration - cervical + # check('/mnt/1218/Public/dataset2/G4/2EL6U5TF') # bad registration # exit() EXCLUDE = ( diff --git a/src/m6synthmorph.py b/src/m6synthmorph.py new file mode 100644 index 0000000..5950097 --- /dev/null +++ b/src/m6synthmorph.py @@ -0,0 +1,389 @@ + +''' +Use SynthMorph to register M6 images + +https://download-directory.github.io/ +https://github.com/freesurfer/freesurfer/tree/dev/mri_synthmorph + + +CUDA_VISIBLE_DEVICES=3 python m6synthmorph.py + + + +XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/xfr/.conda/envs/25reg time ./mri_synthmorph -m affine -o ../test.nii.gz -g '/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/MR/3D_SAG_T1_MPRAGE_+C_MPR_Tra_20230728143005_14.nii.gz' '/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/CT/1.1_CyberKnife_head(MAR)_20230728111920_3.nii.gz' + +XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/xfr/.conda/envs/25reg time mri_synthmorph/mri_synthmorph -m affine -o affine.nii.gz -g moving.nii.gz clipped.nii.gz + + + + +find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics.json -exec grep -H "1.*," {} ";"|sort -k 2 -n|head + +find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics1.json -exec grep -H ":" {} ";"|sort -k 3 -n|head + +find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics1.json -exec grep -H ":" {} ";"|grep joint|sort -k 3 -n|head -n 20 + +bad registration if metric1 < 1.09 + +''' + +from pathlib import Path + +import argparse +import logging +import json +import os +# import pathlib +import shelve +import shutil +import time + +from skimage.metrics import normalized_mutual_information + +import filelock +import matplotlib.pyplot as plt +import numpy as np +# import SimpleITK as sitk + +from mri_synthmorph.synthmorph import registration +# from synthmorph import registration + +import surfa as sf + +PATIENTS_ROOT = '/mnt/1218/Public/dataset2/M6' +OUT_ROOT = '/mnt/1218/Public/dataset2/M6-synthmorph' +SHELVE = os.path.join(OUT_ROOT, '0shelve') + +MAX_Y = 256 + +SIZE_X = 249 +SIZE_Y = 249 +SIZE_Z = 192 +# SIZE_Z = 256 + +MIN_OVERLAP = 0.50 +MIN_METRIC = -0.50 + + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(levelname)s - %(message)s', + handlers=[ + logging.StreamHandler(), + logging.FileHandler('g4synthmorph.log') + ] +) +logger = logging.getLogger(__name__) + + +def bbox2_3D(img): + + r = np.any(img, axis=(1, 2)) + c = np.any(img, axis=(0, 2)) + z = np.any(img, axis=(0, 1)) + + if not np.any(r): + return -1, -1, -1, -1, -1, -1 + + rmin, rmax = np.where(r)[0][[0, -1]] + cmin, cmax = np.where(c)[0][[0, -1]] + zmin, zmax = np.where(z)[0][[0, -1]] + + return rmin, rmax, cmin, cmax, zmin, zmax + +''' +Namespace(command='register', moving='/nn/7295866/20250127/nii/7_3D_SAG_T1_MPRAGE_+C_20250127132612_100.nii.gz', fixed='/123/onlylian/0/tmpgp96622o/clipped.nii.gz', +model='joint', out_moving='/123/onlylian/0/tmpgp96622o/joint.nii.gz', out_fixed='/123/onlylian/0/tmpgp96622o/out_fixed-joint.nii.gz', +header_only=False, trans='/123/onlylian/0/tmpgp96622o/moving_to_fixed-joint.nii.gz', inverse='/123/onlylian/0/tmpgp96622o/fixed_to_moving-joint.nii.gz', +init=None, mid_space=False, threads=None, gpu=True, hyper=0.5, steps=7, extent=256, weights=None, verbose=False, out_dir=None) +''' +def register(ct0, ct1, moving, out_root): + FREESURFER_HOME = '/mnt/1218/Public/packages/freesurfer-8.0.0-beta/' + # out_root = Path(ct0).resolve().parent/os.path.basename(mr).replace('.nii.gz','') + + # print(out_root) + modality = os.path.basename(out_root) + # exit() + + out_root = Path(out_root)/os.path.basename(moving).replace('.nii.gz','') + out_root.mkdir(exist_ok=True) + + logger.info(' '.join((modality, ct0, ct1, moving, str(out_root)))) + + orig = sf.load_volume(moving) + base = sf.load_volume(ct0) + base1 = sf.load_volume(ct1) + + + + if modality == 'XA': + exit() + + if modality == 'CT': + clipped = out_root/'clipped.nii.gz' + + cl = orig.clip(0, 80) + cl.save(clipped) + + MODELS = [ + 'rigid', + # 'affine', + # 'joint', + ] + + else: + clipped = moving + MODELS = [ + 'rigid', + 'affine', + 'joint', + ] + + + + # exit() + + + + default = { + + 'command': 'register', + 'header_only': False, + 'init': None, + 'mid_space': False, + 'threads': None, + + # 'gpu': False, + 'gpu': True, + + 'verbose': False, + # 'verbose': True, + + 'hyper': 0.5, + 'steps': 7, + 'extent': 256, + + 'weights': None, + + # 'model': 'affine', + + # 'out_dir': None, + # 'out_fixed': 'out_fixed.nii.gz', + # 'out_moving': 'out_moving.nii.gz', + # 'trans': None, + # 'inverse': None, + + 'out_fixed': None, + 'out_moving': None, + 'trans': None, + 'inverse': None, + + 'moving' : clipped, + 'fixed' : ct1, + +# 'weights': str(Path(__file__).resolve().parent/'mri_synthmorph/models/synthmorph.affine.2.h5'), + + } + + os.environ["FREESURFER_HOME"] = FREESURFER_HOME + os.environ["XLA_FLAGS"] = '--xla_gpu_cuda_data_dir=%s'% os.environ["CONDA_PREFIX"] + + fill = orig.min() + # print(fill) + # exit() + + METRICS0 = {} + METRICS1 = {} + + inp1 = None + inp2 = None + + for m in MODELS: + default['model'] = m + default['out_dir'] = out_root/m + + # if m in ('affine', 'rigid'): + # default['trans'] = 'trans.lta' + # default['inverse'] = 'inverse.lta' + # else: + # default['trans'] = 'trans.nii.gz' + # default['inverse'] = 'inverse.nii.gz' + + arg=argparse.Namespace(**default) + +# CONDA_PREFIX=/home/xfr/.conda/envs/25reg +# XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda + + logger.info('registering %s'%m) + registration.register(arg) + logger.info('registered %s'%m) + + + if inp1 == None: + inp1 = sf.load_volume(default['out_dir']/'inp_1.nii.gz') + if inp2 == None: + inp2 = sf.load_volume(default['out_dir']/'inp_2.nii.gz') + + out1 = sf.load_volume(default['out_dir']/'out_1.nii.gz') + out2 = sf.load_volume(default['out_dir']/'out_2.nii.gz') + + + out = out_root/('%s.nii.gz'%m) + if m in ['affine', 'rigid']: + trans = sf.load_affine(default['out_dir']/'tra_1.lta') + prop = dict(method='linear', resample=True, fill=fill) + resampled = orig.transform(trans, **prop).resample_like(base, fill=fill) + logger.info('transformed %s'%out) + # print(prop) + # exit() + else: + # need to resample before transform in warp, too complicated, just copy it + # trans1 = default['out_dir']/'tra_1.nii.gz' + # trans = sf.load_warp(trans1) + + resampled = out1.resample_like(base, fill=fill) + logger.info('resampled %s'% out) + + resampled.save(out) + + inp1_out2 = normalized_mutual_information(inp1.data, out2.data) + inp2_out1 = normalized_mutual_information(inp2.data, out1.data) + m0 = normalized_mutual_information(base.data, resampled.data) + m1 = normalized_mutual_information(base1.data, resampled.data) + + METRICS0[m] = (inp1_out2, inp2_out1, m0, m1) + METRICS1[m] = max(inp1_out2, inp2_out1, m0, m1) + + with open(out_root/'metrics0.json', 'w') as f_metrics: + json.dump(METRICS0, f_metrics, indent=1) + + with open(out_root/'metrics1.json', 'w') as f_metrics: + json.dump(METRICS1, f_metrics, indent=1) + + return out_root + +def check(epath): + registered = 0 + for root, dirs, files in os.walk(epath): + dirs.sort() + + RT_DIR = os.path.join(root, 'RT') + + ORGAN_DIR = os.path.join(RT_DIR, 'ORGAN') + if not os.path.isdir(ORGAN_DIR): + continue + + # if there is no eye, it's no a brain image + eye = None + organs = sorted(os.scandir(ORGAN_DIR), key=lambda e: e.name) + for o in organs: + if 'eye' in o.name.lower(): + eye = o + if eye is None: + logger.info('no eye... skip ' + root) + # exit() + return None + + ct_image = os.path.join(RT_DIR, 'ct_image.nii.gz') + + + + outdir = os.path.join(OUT_ROOT, os.path.relpath(root, PATIENTS_ROOT)) + logger.info(outdir) + + os.makedirs(outdir, exist_ok=True) + # ct0_nii = os.path.join(outdir, 'ct0.nii.gz') + ct1_nii = os.path.join(outdir, 'clipped.nii.gz') + # shutil.copy(ct_image, ct0_nii) + + ct = sf.load_volume(ct_image) + clipped = ct.clip(0, 80) + clipped.save(ct1_nii) + + for root2, dirs2, files2 in os.walk(root): + dirs2.sort() + outdir = os.path.join(OUT_ROOT, os.path.relpath(root2, PATIENTS_ROOT)) + if root2.endswith('RT'): + modality = 'RT' + logger.info('copying %s %s' %(root2, outdir)) + shutil.copytree(root2, outdir, dirs_exist_ok=True) + # exit() + continue + skip = (root2==root) or ('RT' in root2.split('/')) + if skip: + continue + if root2.endswith('CT'): + modality = 'CT' + else: + modality = 'other' + logger.info(' '.join([str(skip), root2, modality])) + outdir = os.path.join(OUT_ROOT, os.path.relpath(root2, PATIENTS_ROOT)) + os.makedirs(outdir, exist_ok=True) + for e in sorted(os.scandir(root2), key=lambda e: e.name): + if not e.name.endswith('.nii.gz'): + continue + if '_RTDOSE_' in e.name: + continue + if '_DTI_' in e.name: + continue + if '_ROI1.' in e.name: + continue + + OUT_IMG = os.path.join(outdir, e.name) + if os.path.isfile(OUT_IMG): + logger.info('skip '+ OUT_IMG) + continue + + logger.info(' '.join([e.name, e.path])) + + moving = e.path + register(ct_image, ct1_nii, moving, outdir) + registered += 1 + # exit() + + + + # exit() + return registered + + + +def main(): + + # check('/mnt/1218/Public/dataset2/G4/22M5LAGD') # first case + # check('/mnt/1218/Public/dataset2/G4/2FHZOOLU') # bad registration - cervical + # check('/mnt/1218/Public/dataset2/G4/2EL6U5TF') # bad registration + # exit() + + EXCLUDE = ( + # 'LLUQJUY4', #cervical + ) + + os.makedirs(OUT_ROOT, exist_ok=True) + + LOCK_DIR = os.path.join(OUT_ROOT, '0lock') + os.makedirs(LOCK_DIR, exist_ok=True) + for e in sorted(os.scandir(PATIENTS_ROOT), key=lambda e: e.name): + if e.is_dir(): + d = shelve.open(SHELVE) + if e.name in d or e.name in EXCLUDE: + logger.info('skip '+ e.name) + d.close() + continue + d.close() + lock_path = os.path.join(LOCK_DIR, '%s.lock'%e.name) + lock = filelock.FileLock(lock_path, timeout=1) + try: + lock.acquire() + except: + logger.info(lock_path + ' locked') + continue + ret = check(e.path) + lock.release() + # exit() + d = shelve.open(SHELVE) + d[e.name] = ret + d.close() + +if __name__ == '__main__': + main()