123/2-infer.py
2022-05-11 11:38:16 +08:00

147 lines
3.9 KiB
Python
Executable file

#!/usr/bin/env python3
import difflib
import os
import shutil
import subprocess
import sys
import time
from nipype.interfaces.dcm2nii import Dcm2niix
from rt_utils import RTStructBuilder
import numpy as np
import SimpleITK as sitk
import itk_elastix
def dcm2nii(source_dir, output_dir):
    """Run nipype's ``Dcm2niix`` interface on one directory.

    Args:
        source_dir: Directory handed to the interface as ``source_dir``.
        output_dir: Directory handed to the interface as ``output_dir``.

    The generated command line is printed before the interface is run.
    The ``compression`` input is deliberately not set here.
    """
    dcm2niix = Dcm2niix()
    settings = dcm2niix.inputs
    settings.source_dir = source_dir
    settings.output_dir = output_dir
    print(dcm2niix.cmdline)
    dcm2niix.run()
def _locate_volumes(nii_dir, ct_name, mr_name):
    """Return ``(ct_path, mr_path)`` of the converted ``.nii.gz`` volumes.

    A file belongs to a series when its name starts with ``<series>_``.
    The CT prefix is tested first, and the last matching directory entry
    wins when several files match the same series.

    Raises:
        FileNotFoundError: If no file is found for either series.
    """
    ct_path = None
    mr_path = None
    for entry in os.scandir(nii_dir):
        if not entry.name.endswith('.nii.gz'):
            continue
        if entry.name.startswith(ct_name + '_'):
            ct_path = entry.path
        elif entry.name.startswith(mr_name + '_'):
            mr_path = entry.path
    if ct_path is None or mr_path is None:
        missing = [name for name, path in ((ct_name, ct_path), (mr_name, mr_path))
                   if path is None]
        raise FileNotFoundError(
            'no converted .nii.gz found in %r for series: %s'
            % (nii_dir, ', '.join(missing)))
    return ct_path, mr_path


def _replace_tail(filename, new_tail):
    """Swap the final ``_<...>`` segment of *filename* for *new_tail*."""
    old_tail = '_' + filename.split('_')[-1]
    return filename.replace(old_tail, new_tail)


def _write_rtstruct(label_file, dcm_ct, rtss_file):
    """Resample *label_file* onto the CT series and save it as an RTSTRUCT.

    Each connected component of the resampled label image is added as a
    separate ROI; the result is written to *rtss_file*.
    """
    reader = sitk.ImageSeriesReader()
    dicom_names = reader.GetGDCMSeriesFileNames(dcm_ct)
    reader.SetFileNames(dicom_names)
    reader.MetaDataDictionaryArrayUpdateOn()
    reader.LoadPrivateTagsOn()
    ct_image = reader.Execute()

    # Put the labels on the CT series' grid using a default-constructed
    # transform (presumably identity - confirm); nearest-neighbour
    # interpolation is requested for the label values.
    labels = sitk.ReadImage(label_file)
    labels = sitk.Resample(labels, ct_image, sitk.Transform(), sitk.sitkNearestNeighbor)

    cc_filter = sitk.ConnectedComponentImageFilter()
    components = cc_filter.Execute(labels)
    component_count = cc_filter.GetObjectCount()

    rtstruct = RTStructBuilder.create_new(dicom_series_path=dcm_ct)
    for index in range(component_count):
        # Component labels are selected as 1..component_count.
        component = sitk.BinaryThreshold(components, index + 1, index + 1)
        mask = sitk.GetArrayFromImage(component).astype(bool)
        # NOTE(review): presumably reorders (z, y, x) into the
        # (rows, cols, slices) layout rt_utils expects - confirm.
        mask = np.transpose(mask, (1, 2, 0))
        if mask.any():
            print(index)
            rtstruct.add_roi(mask=mask)
    print(rtss_file)
    rtstruct.save(rtss_file)


def register(DCM_CT, DCM_MR):
    """Segment an MR series, map the labels onto the CT and save an RTSTRUCT.

    Steps, in order:
      1. Wipe and recreate the ``nii`` and ``input`` scratch directories
         (and wipe ``output``) under the deepest directory shared by both
         series.
      2. Convert both DICOM series to NIfTI with ``dcm2nii``.
      3. Copy the MR volume into ``input`` with a ``_0000.nii.gz`` tail and
         run ``nnUNet_predict`` (task 222, ``3d_lowres``) from ``input``
         to ``output``.
      4. Call ``itk_elastix.register`` on the CT and MR volumes and write
         the transformed prediction as ``<ct>.label.nii.gz``.
      5. Write ``<DCM_CT>-rtss.dcm`` next to the CT series directory.

    Args:
        DCM_CT: Path of the CT DICOM series directory.
        DCM_MR: Path of the MR DICOM series directory.

    Raises:
        FileNotFoundError: If conversion produced no volume for a series.
        subprocess.CalledProcessError: If ``nnUNet_predict`` exits non-zero.
        ValueError: If the two paths have no common ancestor (for example
            different drives on Windows).
    """
    # A trailing separator would make basename()/split() return '' and
    # break both the NIfTI lookup and the RTSTRUCT file name.
    DCM_CT = os.path.normpath(DCM_CT)
    DCM_MR = os.path.normpath(DCM_MR)

    # Deepest shared directory, computed per path component so it is always
    # a real directory prefix regardless of where the names happen to match.
    ROOT_DIR = os.path.commonpath(
        [os.path.abspath(DCM_CT), os.path.abspath(DCM_MR)])
    NII_DIR = os.path.join(ROOT_DIR, 'nii')
    INPUT_DIR = os.path.join(ROOT_DIR, 'input')
    OUTPUT_DIR = os.path.join(ROOT_DIR, 'output')

    head, tail = os.path.split(DCM_CT)
    rtss_file = os.path.join(head, tail + '-rtss.dcm')

    for scratch_dir in (NII_DIR, INPUT_DIR):
        shutil.rmtree(scratch_dir, ignore_errors=True)
        os.makedirs(scratch_dir)
    # The output directory is only removed, not recreated (presumably
    # nnUNet_predict creates it - confirm).
    shutil.rmtree(OUTPUT_DIR, ignore_errors=True)

    nCT = os.path.basename(DCM_CT)
    nMR = os.path.basename(DCM_MR)

    dcm2nii(DCM_CT, NII_DIR)
    dcm2nii(DCM_MR, NII_DIR)
    NII_CT, NII_MR = _locate_volumes(NII_DIR, nCT, nMR)

    mr_name = os.path.basename(NII_MR)
    input_file = os.path.join(INPUT_DIR, _replace_tail(mr_name, '_0000.nii.gz'))
    output_file = os.path.join(OUTPUT_DIR, _replace_tail(mr_name, '.nii.gz'))
    ct_name = os.path.basename(NII_CT)
    label_file = os.path.join(NII_DIR, _replace_tail(ct_name, '.label.nii.gz'))

    shutil.copy(NII_MR, input_file)
    print(NII_CT, NII_MR, input_file)

    # nnUNet_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -t 222 -m 3d_lowres --save_npz
    # check=True: stop here if prediction fails instead of carrying on with
    # a missing output file.
    subprocess.run(
        [
            "nnUNet_predict",
            "-i", INPUT_DIR,
            "-o", OUTPUT_DIR,
            "-t", "222",
            "-m", "3d_lowres",
            "--save_npz",
        ],
        check=True,
    )
    print(output_file)

    # NOTE(review): argument roles (fixed vs. moving) are defined by the
    # itk_elastix helper module - confirm against it.
    r2 = itk_elastix.register(NII_CT, NII_MR)
    itk_elastix.transform_write(output_file, r2['fwdtransforms'], label_file, is_label=True)

    _write_rtstruct(label_file, DCM_CT, rtss_file)
def main():
    """Command-line entry point: ``<script> DCM_CT DCM_MR``.

    Prints a usage line and exits with status 1 unless both directory
    arguments are present; otherwise runs ``register`` on them and prints
    the elapsed wall-clock time in seconds.
    """
    # Two positional arguments are read below, so argv needs three entries.
    if len(sys.argv) < 3:
        print('Usage:', sys.argv[0], 'DCM_CT', 'DCM_MR')
        sys.exit(1)
    print('hello')
    print(sys.argv[0])
    print(sys.argv[1])
    start = time.time()
    register(sys.argv[1], sys.argv[2])
    end = time.time()
    print(end - start, 'seconds')
# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()