# 123/registration/best_reg.py
# 2025-02-01 15:57:22 +08:00
#
# 114 lines
# 3 KiB
# Python
# Executable file

import multiprocessing
import os
import shutil
import tempfile
import time
import uuid

from .ants_reg import ants_reg
from .elastix_reg import elastix_reg
# Module-level result queue filled by ``registration``.  It is kept for
# callers that read results with ``dump_queue(Q)``; ``reg_transform`` and
# ``reg_only`` no longer depend on it (see ``_best_registration``).
Q = multiprocessing.Queue()


def _register(method, fix, mov_image, mov_label=None):
    """Run one registration method in the current process.

    ``method`` is called as ``method(fix, mov_image)``.  If ``mov_label`` is
    None the warped moving image is written (``write_warpedmovout``);
    otherwise ``mov_label`` is transformed with ``is_label=True``.

    Returns a dict with keys:
        name        -- class name of the object returned by ``method``
        metrics     -- its ``get_metrics()`` value (``_best_registration``
                       treats higher as better)
        transformed -- path of the written ``.nii.gz`` file in the system
                       temp directory; the caller must delete it
        time        -- seconds spent registering and writing the output
    """
    start = time.time()
    # Random name in the temp dir rather than the CWD (and without the
    # private tempfile._get_candidate_names); uuid4 draws from os.urandom,
    # so names stay unique across forked pool workers.
    transformed = os.path.join(tempfile.gettempdir(),
                               uuid.uuid4().hex + '.nii.gz')
    try:
        r = method(fix, mov_image)
        if mov_label is None:
            r.write_warpedmovout(transformed)
        else:
            r.transform(mov_label, transformed, is_label=True)
    except BaseException:
        # Don't leave a partially written output behind; re-raise the error.
        if os.path.exists(transformed):
            os.remove(transformed)
        raise
    end = time.time()
    return {
        'name': r.__class__.__name__,
        'metrics': r.get_metrics(),  # called once, not twice
        'transformed': transformed,
        'time': end - start,
    }


def registration(method, fix, mov_image, mov_label=None):
    """Register ``mov_image`` to ``fix`` with ``method`` and publish the result.

    If ``mov_label`` is not None the transformed label is written, otherwise
    the transformed moving image.  The result dict (keys ``name``,
    ``metrics``, ``transformed``, ``time``) is put on the module-level queue
    ``Q``; whoever reads it owns the temporary file named by ``transformed``.

    Returns the method's ``get_metrics()`` value.
    """
    res = _register(method, fix, mov_image, mov_label)
    Q.put(res)
    return res['metrics']


def dump_queue(q):
    """Drain ``q`` and return its items as a list, in arrival order.

    A ``None`` sentinel is enqueued and items are read until it comes back.
    The read blocks instead of using a tiny timeout: ``multiprocessing.Queue``
    hands items to a background feeder thread, so a 10 microsecond timeout
    could raise ``queue.Empty`` before even our own sentinel arrived.

    Because ``None`` is the sentinel, the queue must not carry ``None``
    items.  Items that other processes enqueue after the sentinel stay in
    the queue.
    """
    q.put(None)
    return list(iter(q.get, None))


def _best_registration(fix, mov_image, mov_label, out_path, methods,
                       processes):
    """Run all ``methods`` in parallel and copy the best output to ``out_path``.

    ``methods`` is a sequence of registration callables; None means
    ``[ants_reg, elastix_reg]``.  The best result is the one with the
    highest ``metrics``; ties go to the method listed first.  The list of
    metrics is printed, and every temporary output is deleted afterwards,
    even if the copy fails.

    Raises:
        ValueError: if ``methods`` is empty.
    """
    if methods is None:
        methods = [ants_reg, elastix_reg]
    if not methods:
        raise ValueError('at least one registration method is required')
    inputs = [(m, fix, mov_image, mov_label) for m in methods]
    # Take results from starmap's return value (ordered like ``methods``)
    # instead of the global queue ``Q``: a module-level Queue only reaches
    # fork-started workers, and its feeder thread may deliver an item after
    # starmap has already returned.  The with-block also terminates the pool.
    with multiprocessing.Pool(min(processes, len(methods))) as pool:
        results = pool.starmap(_register, inputs)
    try:
        print([r['metrics'] for r in results])
        best = max(results, key=lambda r: r['metrics'])
        shutil.copy(best['transformed'], out_path)
    finally:
        for r in results:
            os.remove(r['transformed'])


def reg_transform(fix, mov_image, mov_label, out_label, *, methods=None,
                  processes=4):
    """Register ``mov_image`` to ``fix`` with every method, transform
    ``mov_label`` with each, and write the best-scoring label to ``out_label``.

    ``methods`` (default ``[ants_reg, elastix_reg]``) and ``processes``
    (maximum pool size) are optional keyword-only overrides.
    """
    _best_registration(fix, mov_image, mov_label, out_label, methods,
                       processes)


def reg_only(fix, mov_image, out_image, *, methods=None, processes=4):
    """Register ``mov_image`` to ``fix`` with every method and write the
    best-scoring warped image to ``out_image``.

    ``methods`` (default ``[ants_reg, elastix_reg]``) and ``processes``
    (maximum pool size) are optional keyword-only overrides.
    """
    _best_registration(fix, mov_image, None, out_image, methods, processes)
# Sample inputs used when this module is run directly.
# mv_lab appears to belong to an earlier case (patient 2896833) whose
# fixed/moving images were:
#   fi = '/nn/2896833/20220506/nii/b_C+MAR_20220506155936_301.nii.gz'
#   mv_img = '/nn/2896833/20220506/nii/9_3D_fl3d_mt_FS_+_c_MPR_Tra_20220506142416_15.nii.gz'
# Pair it with those paths, not with the fi/mv_img below, when calling
# reg_transform.
mv_lab = '/nn/2896833/20220506/output/9_3D_fl3d_mt_FS_+_c_MPR_Tra_20220506142416.nii.gz'
fi = '/nn/7295866/20250127/nii/a_1.1_CyberKnife_head(MAR)_20250127111447_5.nii.gz'
mv_img = '/nn/7295866/20250127/nii/7_3D_SAG_T1_MPRAGE_+C_20250127132612_100.nii.gz'

if __name__ == '__main__':
    import sys

    # Usage: best_reg.py [FIX MOVING OUT [LABEL]]
    # With no arguments the sample case above is registered into tmp.nii.gz.
    args = sys.argv[1:]
    if not args:
        reg_only(fi, mv_img, 'tmp.nii.gz')
    elif len(args) == 3:
        reg_only(args[0], args[1], args[2])
    elif len(args) == 4:
        fix_path, mov_path, out_path, label_path = args
        reg_transform(fix_path, mov_path, label_path, out_path)
    else:
        sys.exit('usage: best_reg.py [FIX MOVING OUT [LABEL]]')