"""Run several image-registration back-ends in parallel and keep the best result.

Each back-end (``ants_reg``, ``elastix_reg``) is called as
``method(fix, mov_image)``. It must return an object exposing these methods:

* ``write_warpedmovout(path)``
* ``transform(src, dst, is_label=...)``
* ``get_metrics()``

The output whose ``get_metrics()`` value is highest is copied to the caller's
destination path.
"""
import multiprocessing
import os
import shutil
import tempfile
import time
import uuid

from .ants_reg import ants_reg
from .elastix_reg import elastix_reg

# Back-ends tried by ``reg_transform`` / ``reg_only``, in this order.
_METHODS = (ants_reg, elastix_reg)
# Upper bound on worker processes; never more than one per back-end.
_POOL_SIZE = 4

# Legacy result channel: ``registration`` still puts its record here so that
# existing ``registration`` + ``dump_queue(Q)`` callers keep working.
# NOTE: a module-level Queue is only shared with Pool workers under the 'fork'
# start method. Under 'spawn', each worker re-imports this module and creates
# its own Queue, and the parent never sees the records. ``reg_transform`` and
# ``reg_only`` therefore use Pool return values instead.
Q = multiprocessing.Queue()


def _remove_quietly(path):
    """Delete *path*, ignoring a file that does not exist."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass


def _register(method, fix, mov_image, mov_label=None):
    """Run one back-end and write its output to a fresh file in the cwd.

    If *mov_label* is None, the registered (warped) moving image is written.
    Otherwise *mov_label* is resampled with ``is_label=True`` and written.

    Returns:
        dict with keys ``'name'`` (back-end class name), ``'metrics'``
        (``get_metrics()`` result), ``'transformed'`` (output path) and
        ``'time'`` (seconds spent running and writing).
    """
    start = time.perf_counter()
    # Unique name relative to the current working directory, like the original
    # bare candidate name, but without relying on a private tempfile API.
    transformed = uuid.uuid4().hex + '.nii.gz'
    try:
        r = method(fix, mov_image)
        if mov_label is None:
            r.write_warpedmovout(transformed)
        else:
            r.transform(mov_label, transformed, is_label=True)
    except BaseException:
        # Don't leave a partially written output behind.
        _remove_quietly(transformed)
        raise
    elapsed = time.perf_counter() - start
    return {
        'name': r.__class__.__name__,
        'metrics': r.get_metrics(),
        'transformed': transformed,
        'time': elapsed,
    }


def registration(method, fix, mov_image, mov_label=None):
    """Run one back-end, publish its result record on ``Q`` and return its metrics.

    If *mov_label* is not None, the transformed label is produced. Otherwise
    the transformed moving image is produced. The output file is named in the
    record placed on ``Q`` (see ``_register`` for the record's keys); the
    caller is responsible for deleting it.
    """
    res = _register(method, fix, mov_image, mov_label)
    Q.put(res)
    return res['metrics']


def dump_queue(q, timeout=None):
    """Drain *q* and return its items in arrival order.

    A ``None`` sentinel is put on *q*, and items are read until it comes back.

    Args:
        q: queue to drain; it must not contain ``None`` as a real item.
        timeout: seconds to wait for each item. ``None`` (the default) blocks
            until the sentinel arrives. A tiny timeout raises ``queue.Empty``
            whenever a producer's feeder thread has not flushed yet.
    """
    q.put(None)
    return list(iter(lambda: q.get(timeout=timeout), None))


def _run_and_keep_best(fix, mov_image, mov_label, out_path):
    """Run every back-end in parallel and copy the best output to *out_path*.

    The record with the highest ``metrics`` value wins; ties go to the
    earlier back-end in ``_METHODS``. All temporary outputs are deleted
    afterwards, including when a back-end or the copy fails. The first
    back-end error, if any, is re-raised.
    """
    inputs = [(m, fix, mov_image, mov_label) for m in _METHODS]
    results = []
    first_error = None
    with multiprocessing.Pool(min(_POOL_SIZE, len(inputs))) as pool:
        pending = [pool.apply_async(_register, args) for args in inputs]
        # Wait for every task before leaving the pool, so that the outputs of
        # back-ends that succeeded can still be cleaned up if another failed.
        for task in pending:
            try:
                results.append(task.get())
            except Exception as exc:
                if first_error is None:
                    first_error = exc
    try:
        if first_error is not None:
            raise first_error
        print([r['metrics'] for r in results])
        best = max(results, key=lambda r: r['metrics'])
        shutil.copy(best['transformed'], out_path)
    finally:
        for r in results:
            _remove_quietly(r['transformed'])


def reg_transform(fix, mov_image, mov_label, out_label):
    """Register *mov_image* to *fix* and write the best-warped *mov_label*.

    Every back-end in ``_METHODS`` is tried. *mov_label* is resampled with
    ``is_label=True``, and the result with the highest metric is copied to
    *out_label*.
    """
    _run_and_keep_best(fix, mov_image, mov_label, out_label)


def reg_only(fix, mov_image, out_image):
    """Register *mov_image* to *fix* and write the best registered image.

    Every back-end in ``_METHODS`` is tried, and the warped moving image with
    the highest metric is copied to *out_image*.
    """
    _run_and_keep_best(fix, mov_image, None, out_image)


# Example inputs for manual runs. NOTE: ``fi`` and ``mv_img`` are reassigned
# further down in this module, so these particular values are overridden.
fi = '/nn/2896833/20220506/nii/b_C+MAR_20220506155936_301.nii.gz'
mv_img = '/nn/2896833/20220506/nii/9_3D_fl3d_mt_FS_+_c_MPR_Tra_20220506142416_15.nii.gz'
# Moving label map whose series name matches the first example ``mv_img``.
# Kept for manual ``--label`` runs; the default run below does not use it.
mv_lab = '/nn/2896833/20220506/output/9_3D_fl3d_mt_FS_+_c_MPR_Tra_20220506142416.nii.gz'
# Default inputs for a manual run; these rebind any earlier ``fi`` / ``mv_img``.
fi = '/nn/7295866/20250127/nii/a_1.1_CyberKnife_head(MAR)_20250127111447_5.nii.gz'
mv_img = '/nn/7295866/20250127/nii/7_3D_SAG_T1_MPRAGE_+C_20250127132612_100.nii.gz'


if __name__ == '__main__':
    # This module uses relative imports, so run it as
    # ``python -m <package>.<module> [fix] [mov_image] [output] [--label PATH]``.
    # With no arguments it behaves exactly as before:
    # reg_only(fi, mv_img, 'tmp.nii.gz').
    import argparse

    parser = argparse.ArgumentParser(
        description='Register a moving image to a fixed image with every '
                    'available back-end and keep the best result.')
    parser.add_argument('fix', nargs='?', default=fi,
                        help='fixed (reference) image (default: %(default)s)')
    parser.add_argument('mov_image', nargs='?', default=mv_img,
                        help='moving image (default: %(default)s)')
    parser.add_argument('output', nargs='?', default='tmp.nii.gz',
                        help='where to write the best result '
                             '(default: %(default)s)')
    parser.add_argument('--label', metavar='MOV_LABEL', default=None,
                        help='warp this moving label map (reg_transform) '
                             'instead of writing the registered image')
    args = parser.parse_args()

    if args.label is None:
        reg_only(args.fix, args.mov_image, args.output)
    else:
        reg_transform(args.fix, args.mov_image, args.label, args.output)