123/registration/best_reg.py
2023-08-08 22:04:06 +00:00

76 lines
2 KiB
Python

import multiprocessing
import os
import shutil
import tempfile
import time
from .ants_reg import ants_reg
from .elastix_reg import elastix_reg
# Result channel: every registration() call puts one record dict here.
# NOTE(review): created at import time; a worker process only writes to the
# parent's queue if it inherited this object (fork) or was handed it.
Q = multiprocessing.Queue()


def registration(method, fix, mov_image, mov_label):
    """Run one registration method and publish where its warped label went.

    Builds ``method(fix, mov_image)``, writes the transformed ``mov_label``
    to a fresh temporary ``.nii.gz`` file, and puts a record on the
    module-level queue ``Q`` with the keys:

    - ``name``: class name of the registration object,
    - ``metrics``: value of its ``get_metrics()``,
    - ``transformed``: path of the temporary output file (the consumer of
      the record owns this file and is responsible for deleting it),
    - ``time``: seconds spent constructing the object and transforming.

    Args:
        method: Callable (e.g. a registration class) invoked as
            ``method(fix, mov_image)``; the result must provide
            ``transform(src, dst, is_label=...)`` and ``get_metrics()``.
        fix: Passed through unchanged as the first argument of ``method``.
        mov_image: Passed through unchanged as the second argument of ``method``.
        mov_label: Source handed to ``transform``.

    Returns:
        The value returned by ``get_metrics()``.

    Raises:
        Whatever ``method``, ``transform`` or ``get_metrics`` raise; the
        temporary output file is deleted before the exception propagates.
    """
    start = time.perf_counter()
    # mkstemp atomically creates a uniquely named file in the temp directory;
    # the former tempfile._get_candidate_names() (private API) only produced
    # a bare name relative to the current working directory.
    fd, transformed = tempfile.mkstemp(suffix='.nii.gz')
    os.close(fd)
    try:
        r = method(fix, mov_image)
        # NOTE(review): both calls write to the same path, so the
        # is_label=False result is overwritten by the is_label=True one.
        # Confirm whether the first call was meant to warp mov_image into a
        # separate file, or can be dropped.
        r.transform(mov_label, transformed, is_label=False)
        r.transform(mov_label, transformed, is_label=True)
        elapsed = time.perf_counter() - start
        metrics = r.get_metrics()  # called once; reused for record and return
    except BaseException:
        # Don't leave an orphaned temp file behind on failure.
        try:
            os.remove(transformed)
        except FileNotFoundError:
            pass
        raise
    res = {
        'name': r.__class__.__name__,
        'metrics': metrics,
        'transformed': transformed,
        'time': elapsed,
    }
    Q.put(res)
    return metrics
def dump_queue(q):
    """Drain the items currently queued on *q* and return them as a list.

    A ``None`` sentinel is appended and items are read until it comes back,
    so ``None`` must not be used as a real item. Call it only after
    producers have stopped putting. Items put by other processes that have
    not yet been flushed into the queue can arrive after the sentinel; they
    are then left in the queue.

    Args:
        q: A ``multiprocessing.Queue`` (or any queue with ``put``/``get``).

    Returns:
        The drained items, in queue order.
    """
    q.put(None)
    # Block until the sentinel arrives. A multiprocessing.Queue hands put()
    # data to a background feeder thread, so the previous near-zero get()
    # timeout could raise queue.Empty before the sentinel (or earlier items)
    # had reached the pipe.
    return list(iter(q.get, None))
def _share_queue(shared_q):
    """Pool initializer: point this worker's module-level ``Q`` at *shared_q*."""
    global Q
    Q = shared_q


def reg_transform(fix, mov_image, mov_label, out_label):
    """Register with every available method and keep the best label warp.

    Runs ``registration`` for ``ants_reg`` and ``elastix_reg`` in parallel
    worker processes. It then copies the transformed label from the run
    with the largest ``metrics`` value to *out_label* and deletes every
    temporary transformed file. The per-method metrics and the ranked
    result records are printed.

    Args:
        fix: Passed through to each registration method.
        mov_image: Passed through to each registration method.
        mov_label: Label file warped by each method.
        out_label: Destination path for the winning transformed label.
    """
    regs = [ants_reg, elastix_reg]
    inputs = [(reg, fix, mov_image, mov_label) for reg in regs]
    # Hand the parent's Q to every worker explicitly. Under the 'spawn' start
    # method each worker re-imports this module and would otherwise write to
    # its own fresh Q, so no record would ever reach this process.
    with multiprocessing.Pool(len(inputs), initializer=_share_queue,
                              initargs=(Q,)) as pool:
        pool_outputs = pool.starmap(registration, inputs)
        # Every successful registration() put exactly one record, so read
        # exactly that many with blocking gets (no sentinel race). Do it
        # before leaving the `with`, which terminates the workers: a record
        # not yet flushed by a worker's queue feeder thread would be lost.
        records = [Q.get() for _ in inputs]
    print(pool_outputs)
    try:
        ranked = sorted(records, key=lambda rec: rec['metrics'], reverse=True)
        print(ranked)
        shutil.copy(ranked[0]['transformed'], out_label)
    finally:
        # Remove the temporary outputs even if ranking or copying fails.
        for rec in records:
            try:
                os.remove(rec['transformed'])
            except FileNotFoundError:
                pass
# Default inputs used when the script is run without arguments.
fi = '/nn/2896833/20220506/nii/b_C+MAR_20220506155936_301.nii.gz'
mv_img = '/nn/2896833/20220506/nii/9_3D_fl3d_mt_FS_+_c_MPR_Tra_20220506142416_15.nii.gz'
mv_lab = '/nn/2896833/20220506/output/9_3D_fl3d_mt_FS_+_c_MPR_Tra_20220506142416.nii.gz'


def parse_args(argv=None):
    """Parse the command-line inputs for ``reg_transform``.

    All four positional arguments are optional. Missing ones fall back to
    the default paths defined above and ``'tmp.nii.gz'`` for the output.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``fix``, ``mov_image``, ``mov_label``
        and ``out_label`` attributes.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Register a moving image to a fixed image with each '
                    'available method and keep the transformed label of the '
                    'run with the highest metric.')
    parser.add_argument('fix', nargs='?', default=fi,
                        help='fixed image (default: %(default)s)')
    parser.add_argument('mov_image', nargs='?', default=mv_img,
                        help='moving image (default: %(default)s)')
    parser.add_argument('mov_label', nargs='?', default=mv_lab,
                        help='moving label to transform (default: %(default)s)')
    parser.add_argument('out_label', nargs='?', default='tmp.nii.gz',
                        help='output path for the best label (default: %(default)s)')
    return parser.parse_args(argv)


if __name__ == '__main__':
    # This module uses relative imports, so run it as part of its package
    # (e.g. `python -m <package>.best_reg`), not as a bare file path.
    args = parse_args()
    reg_transform(args.fix, args.mov_image, args.mov_label, args.out_label)