''' Use SynthMorph to register M6 images.

References:
    https://download-directory.github.io/
    https://github.com/freesurfer/freesurfer/tree/dev/mri_synthmorph

Run:
    CUDA_VISIBLE_DEVICES=3 python m6synthmorph.py

Equivalent command-line registrations:
    XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/xfr/.conda/envs/25reg time ./mri_synthmorph -m affine -o ../test.nii.gz -g '/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/MR/3D_SAG_T1_MPRAGE_+C_MPR_Tra_20230728143005_14.nii.gz' '/mnt/1218/Public/dataset2/M6/ZYRGTRKJ/20230728/CT/1.1_CyberKnife_head(MAR)_20230728111920_3.nii.gz'
    XLA_FLAGS=--xla_gpu_cuda_data_dir=/home/xfr/.conda/envs/25reg time mri_synthmorph/mri_synthmorph -m affine -o affine.nii.gz -g moving.nii.gz clipped.nii.gz

Inspecting the metrics files written by register():
    find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics.json -exec grep -H "1.*," {} ";"|sort -k 2 -n|head
    find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics1.json -exec grep -H ":" {} ";"|sort -k 3 -n|head
    find /mnt/1218/Public/dataset2/G4-synthmorph/ -iname metrics1.json -exec grep -H ":" {} ";"|grep joint|sort -k 3 -n|head -n 20
    NOTE(review): these commands point at the G4 output tree; substitute
    OUT_ROOT when inspecting the M6 run.

A registration is considered bad if metric1 (its value in metrics1.json) < 1.09.
'''
from pathlib import Path
import argparse
import logging
import json
import os
# import pathlib
import shelve
import shutil
import time

from skimage.metrics import normalized_mutual_information
import filelock
import matplotlib.pyplot as plt
import numpy as np
# import SimpleITK as sitk

from mri_synthmorph.synthmorph import registration
# from synthmorph import registration
import surfa as sf

### Need NFS for lock: the per-patient lock files and the shelve live under
### OUT_ROOT, presumably shared between the machines running this script.
PATIENTS_ROOT = '/mnt/1220/Public/dataset2/M6'
OUT_ROOT = '/mnt/1220/Public/dataset2/M6-synthmorph'
SHELVE = os.path.join(OUT_ROOT, '0shelve')

MAX_Y = 256
SIZE_X = 249
SIZE_Y = 249
SIZE_Z = 192
# SIZE_Z = 256
MIN_OVERLAP = 0.50
MIN_METRIC = -0.50

# NOTE(review): the log file is still named after the G4 run; confirm this is
# intended for the M6 run.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('g4synthmorph.log'),
    ],
)
logger = logging.getLogger(__name__)


def bbox2_3D(img):
    """Return the inclusive index bounds of the truthy voxels of a 3-D array.

    Returns:
        (rmin, rmax, cmin, cmax, zmin, zmax): the first and last index along
        axes 0, 1 and 2 at which any voxel is non-zero, or six ``-1`` values
        when ``img`` has no non-zero voxel.
    """
    r = np.any(img, axis=(1, 2))
    c = np.any(img, axis=(0, 2))
    z = np.any(img, axis=(0, 1))
    if not np.any(r):
        return -1, -1, -1, -1, -1, -1
    rmin, rmax = np.where(r)[0][[0, -1]]
    cmin, cmax = np.where(c)[0][[0, -1]]
    zmin, zmax = np.where(z)[0][[0, -1]]
    return rmin, rmax, cmin, cmax, zmin, zmax


# Example of the argparse.Namespace the mri_synthmorph CLI builds for a
# 'register' command; register() below builds the same shape by hand:
# Namespace(command='register',
#           moving='/nn/7295866/20250127/nii/7_3D_SAG_T1_MPRAGE_+C_20250127132612_100.nii.gz',
#           fixed='/123/onlylian/0/tmpgp96622o/clipped.nii.gz', model='joint',
#           out_moving='/123/onlylian/0/tmpgp96622o/joint.nii.gz',
#           out_fixed='/123/onlylian/0/tmpgp96622o/out_fixed-joint.nii.gz',
#           header_only=False,
#           trans='/123/onlylian/0/tmpgp96622o/moving_to_fixed-joint.nii.gz',
#           inverse='/123/onlylian/0/tmpgp96622o/fixed_to_moving-joint.nii.gz',
#           init=None, mid_space=False, threads=None, gpu=True, hyper=0.5,
#           steps=7, extent=256, weights=None, verbose=False, out_dir=None)


def register(ct0, ct1, moving, out_root):
    """Register one moving image to the clipped planning CT with SynthMorph.

    Args:
        ct0: path of the unclipped planning CT (``RT/ct_image.nii.gz`` when
            called from check()); its grid is the space results are saved in.
        ct1: path of the [0, 80]-clipped planning CT, used as the fixed image.
        moving: path of the ``.nii.gz`` image to register.
        out_root: output directory of the moving image's series. Its basename
            is used as the modality label ('CT', 'XA', ...).

    Returns:
        The per-image output directory ``out_root/<moving stem>``, or None
        when the series is XA (not registered).

    Side effects:
        Sets FREESURFER_HOME and XLA_FLAGS in ``os.environ`` (XLA_FLAGS is
        derived from CONDA_PREFIX, which must be set). In the per-image
        directory it writes the SynthMorph outputs under ``<model>/``, one
        ``<model>.nii.gz`` resampled onto the ct0 grid per model, and
        ``metrics0.json`` / ``metrics1.json`` with normalized mutual
        information scores.
    """
    FREESURFER_HOME = '/mnt/1218/Public/packages/freesurfer-8.0.0-beta/'
    # The series directory name doubles as the modality label.
    modality = os.path.basename(out_root)
    if modality == 'XA':
        # XA series are not registered. This used to call exit(), which aborted
        # the whole batch (and every later patient) on the first XA series.
        logger.info('skip XA %s', moving)
        return None

    out_root = Path(out_root) / os.path.basename(moving).replace('.nii.gz', '')
    out_root.mkdir(exist_ok=True)
    logger.info(' '.join((modality, ct0, ct1, moving, str(out_root))))

    orig = sf.load_volume(moving)
    base = sf.load_volume(ct0)
    base1 = sf.load_volume(ct1)

    if modality == 'CT':
        # Clip the moving CT to the same [0, 80] window check() applies to the
        # fixed CT; only the rigid model is run for CT-to-CT.
        clipped = out_root / 'clipped.nii.gz'
        orig.clip(0, 80).save(clipped)
        MODELS = ['rigid']
    else:
        clipped = moving
        MODELS = ['rigid', 'affine', 'joint']

    # Hand-built equivalent of the CLI's argparse.Namespace; 'model' and
    # 'out_dir' are filled in per model below.
    default = {
        'command': 'register',
        'header_only': False,
        'init': None,
        'mid_space': False,
        'threads': None,
        'gpu': True,  # set to False to run on the CPU
        'verbose': False,
        'hyper': 0.5,
        'steps': 7,
        'extent': 256,
        'weights': None,
        # 'weights': str(Path(__file__).resolve().parent/'mri_synthmorph/models/synthmorph.affine.2.h5'),
        'out_fixed': None,
        'out_moving': None,
        'trans': None,
        'inverse': None,
        'moving': clipped,
        'fixed': ct1,
    }
    os.environ["FREESURFER_HOME"] = FREESURFER_HOME
    os.environ["XLA_FLAGS"] = '--xla_gpu_cuda_data_dir=%s' % os.environ["CONDA_PREFIX"]

    # Pad resampled voxels with the moving image's minimum intensity.
    fill = orig.min()
    METRICS0 = {}
    METRICS1 = {}
    inp1 = None
    inp2 = None
    for m in MODELS:
        out_dir = out_root / m
        default['model'] = m
        default['out_dir'] = out_dir
        arg = argparse.Namespace(**default)
        logger.info('registering %s', m)
        registration.register(arg)
        logger.info('registered %s', m)

        # The network inputs are loaded only once, from the first model's
        # out_dir (presumably identical for every model). Use `is None`: `==`
        # on a loaded volume may compare element-wise.
        if inp1 is None:
            inp1 = sf.load_volume(out_dir / 'inp_1.nii.gz')
        if inp2 is None:
            inp2 = sf.load_volume(out_dir / 'inp_2.nii.gz')
        out1 = sf.load_volume(out_dir / 'out_1.nii.gz')
        out2 = sf.load_volume(out_dir / 'out_2.nii.gz')

        out = out_root / ('%s.nii.gz' % m)
        if m in ('affine', 'rigid'):
            # Linear models: apply the estimated transform to the original
            # (unclipped) moving image, then resample onto the ct0 grid.
            trans = sf.load_affine(out_dir / 'tra_1.lta')
            prop = dict(method='linear', resample=True, fill=fill)
            resampled = orig.transform(trans, **prop).resample_like(base, fill=fill)
            logger.info('transformed %s', out)
        else:
            # Deformable model: applying the warp would require resampling
            # first, so reuse SynthMorph's own moved output instead.
            resampled = out1.resample_like(base, fill=fill)
            logger.info('resampled %s', out)
        resampled.save(out)

        inp1_out2 = normalized_mutual_information(inp1.data, out2.data)
        inp2_out1 = normalized_mutual_information(inp2.data, out1.data)
        m0 = normalized_mutual_information(base.data, resampled.data)
        m1 = normalized_mutual_information(base1.data, resampled.data)
        METRICS0[m] = (inp1_out2, inp2_out1, m0, m1)
        METRICS1[m] = max(inp1_out2, inp2_out1, m0, m1)

    # metrics1.json is written last; check() uses it as the completion marker.
    with open(out_root / 'metrics0.json', 'w') as f_metrics:
        json.dump(METRICS0, f_metrics, indent=1)
    with open(out_root / 'metrics1.json', 'w') as f_metrics:
        json.dump(METRICS1, f_metrics, indent=1)
    return out_root


def check(epath):
    """Register every usable series of one patient directory to its planning CT.

    For each directory under ``epath`` that has an ``RT/ORGAN`` subdirectory,
    the planning CT ``RT/ct_image.nii.gz`` is clipped to [0, 80] and saved as
    ``clipped.nii.gz`` in the mirrored output directory. ``RT`` directories
    are copied to the output tree, and every other ``.nii.gz`` series (except
    RTDOSE, DTI and ROI1 files) is passed to register().

    Returns:
        The number of images registered in this call, or None as soon as a
        study's ORGAN directory has no entry whose name contains 'eye'
        (treated as not a brain image).
    """
    registered = 0
    for root, dirs, _files in os.walk(epath):
        dirs.sort()
        rt_dir = os.path.join(root, 'RT')
        organ_dir = os.path.join(rt_dir, 'ORGAN')
        if not os.path.isdir(organ_dir):
            continue

        # If there is no eye structure, it is not a brain image.
        with os.scandir(organ_dir) as organs:
            has_eye = any('eye' in o.name.lower() for o in organs)
        if not has_eye:
            logger.info('no eye... skip ' + root)
            return None

        ct_image = os.path.join(rt_dir, 'ct_image.nii.gz')
        outdir = os.path.join(OUT_ROOT, os.path.relpath(root, PATIENTS_ROOT))
        logger.info(outdir)
        os.makedirs(outdir, exist_ok=True)
        ct1_nii = os.path.join(outdir, 'clipped.nii.gz')
        sf.load_volume(ct_image).clip(0, 80).save(ct1_nii)

        for root2, dirs2, _files2 in os.walk(root):
            dirs2.sort()
            outdir = os.path.join(OUT_ROOT, os.path.relpath(root2, PATIENTS_ROOT))
            if root2.endswith('RT'):
                logger.info('copying %s %s', root2, outdir)
                shutil.copytree(root2, outdir, dirs_exist_ok=True)
                continue
            # Skip the study directory itself and anything inside RT/.
            skip = (root2 == root) or ('RT' in root2.split('/'))
            if skip:
                continue
            modality = 'CT' if root2.endswith('CT') else 'other'
            logger.info(' '.join([str(skip), root2, modality]))
            os.makedirs(outdir, exist_ok=True)

            for e in sorted(os.scandir(root2), key=lambda e: e.name):
                if not e.name.endswith('.nii.gz'):
                    continue
                if any(tag in e.name for tag in ('_RTDOSE_', '_DTI_', '_ROI1.')):
                    continue
                OUT_IMG = os.path.join(outdir, e.name)
                # register() writes into outdir/<stem>/ and finishes with
                # metrics1.json, so that file marks a completed image. The
                # OUT_IMG test alone never matched anything this script writes.
                done_marker = os.path.join(
                    outdir, e.name.replace('.nii.gz', ''), 'metrics1.json')
                if os.path.isfile(OUT_IMG) or os.path.isfile(done_marker):
                    logger.info('skip ' + OUT_IMG)
                    continue
                logger.info(' '.join([e.name, e.path]))
                if register(ct_image, ct1_nii, e.path, outdir) is not None:
                    registered += 1
    return registered


def main():
    """Run check() once for every patient directory under PATIENTS_ROOT.

    Patients already recorded in the SHELVE database, or listed in EXCLUDE,
    are skipped. Each patient is guarded by a per-patient file lock in
    ``OUT_ROOT/0lock`` so several workers can share the work; a patient whose
    lock is held elsewhere is skipped. check()'s return value is stored in
    the shelve under the patient's directory name.
    """
    # Single-patient debugging runs:
    # check('/mnt/1218/Public/dataset2/G4/22M5LAGD')  # first case
    # check('/mnt/1218/Public/dataset2/G4/2FHZOOLU')  # bad registration - cervical
    # check('/mnt/1218/Public/dataset2/G4/2EL6U5TF')  # bad registration
    EXCLUDE = (
        # 'LLUQJUY4',  # cervical
        '2XYU7UHB',  # I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INVALID_ARGUMENT: Input is not invertible.
    )
    os.makedirs(OUT_ROOT, exist_ok=True)
    LOCK_DIR = os.path.join(OUT_ROOT, '0lock')
    os.makedirs(LOCK_DIR, exist_ok=True)
    # shelve does not support concurrent access, and all workers share SHELVE,
    # so every open of it is serialized through one extra lock file.
    shelve_lock = filelock.FileLock(SHELVE + '.lock')

    for e in sorted(os.scandir(PATIENTS_ROOT), key=lambda e: e.name):
        if not e.is_dir():
            continue
        with shelve_lock, shelve.open(SHELVE) as d:
            already_done = e.name in d
        if already_done or e.name in EXCLUDE:
            logger.info('skip ' + e.name)
            continue

        lock_path = os.path.join(LOCK_DIR, '%s.lock' % e.name)
        lock = filelock.FileLock(lock_path, timeout=1)
        try:
            lock.acquire()
        except filelock.Timeout:
            # Another worker is processing this patient.
            logger.info(lock_path + ' locked')
            continue
        try:
            ret = check(e.path)
        finally:
            lock.release()

        with shelve_lock, shelve.open(SHELVE) as d:
            d[e.name] = ret


if __name__ == '__main__':
    main()