123/2-infer.py
2023-08-08 22:04:06 +00:00

263 lines
9.2 KiB
Python
Executable file

#!/usr/bin/env python3
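'''
Segment an MR DICOM series with nnU-Net, map the result onto the matching CT
series, save it as an RT Structure Set and send it to a DICOM storage node.

Usage: 2-infer.py DCM_CT DCM_MR

The lines below appear to be the nnU-Net model-selection output for
Task222_ICTS2022 (configuration name followed by its score) together with the
prediction commands it recommends; they are kept here for reference.
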
2d 0.4610997436727624
3d_fullres 0.5022740762294419
3d_lowres 0.5957028945994233
3d_cascade_fullres 0.517286480153028
ensemble_2d__nnUNetTrainerV2__nnUNetPlansv2.1--3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1 0.5243185547220239
ensemble_2d__nnUNetTrainerV2__nnUNetPlansv2.1--3d_lowres__nnUNetTrainerV2__nnUNetPlansv2.1 0.552552255340162
ensemble_2d__nnUNetTrainerV2__nnUNetPlansv2.1--3d_cascade_fullres__nnUNetTrainerV2CascadeFullRes__nnUNetPlansv2.1 0.531701751318307
ensemble_3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1--3d_lowres__nnUNetTrainerV2__nnUNetPlansv2.1 0.6105215496684026
ensemble_3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1--3d_cascade_fullres__nnUNetTrainerV2CascadeFullRes__nnUNetPlansv2.1 0.5343679184080806
ensemble_3d_lowres__nnUNetTrainerV2__nnUNetPlansv2.1--3d_cascade_fullres__nnUNetTrainerV2CascadeFullRes__nnUNetPlansv2.1 0.6000630104223947
Task222_ICTS2022 submit model ensemble_3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1--3d_lowres__nnUNetTrainerV2__nnUNetPlansv2.1 0.6105215496684026
Here is how you should predict test cases. Run in sequential order and replace all input and output folder names with your personalized ones
nnUNet_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr nnUNetTrainerV2 -ctr nnUNetTrainerV2CascadeFullRes -m 3d_fullres -p nnUNetPlansv2.1 -t Task222_ICTS2022
nnUNet_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL2 -tr nnUNetTrainerV2 -ctr nnUNetTrainerV2CascadeFullRes -m 3d_lowres -p nnUNetPlansv2.1 -t Task222_ICTS2022
nnUNet_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER -pp /workspace/nnUNet_trained_models/nnUNet/ensembles/Task222_ICTS2022/ensemble_3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1--3d_lowres__nnUNetTrainerV2__nnUNetPlansv2.1/postprocessing.json
'''
import difflib
import os
import shutil
import subprocess
import sys
import time
from nipype.interfaces.dcm2nii import Dcm2niix
from pydicom import dcmread
from pynetdicom import AE, debug_logger
from pynetdicom.sop_class import CTImageStorage, RTStructureSetStorage
from rt_utils import RTStructBuilder
import numpy as np
import SimpleITK as sitk
# import itk_elastix
from registration.best_reg import reg_transform
def dcm2nii(source_dir, output_dir):
    """Convert the DICOM files in ``source_dir`` to NIfTI inside ``output_dir``.

    Thin wrapper around nipype's ``Dcm2niix`` interface. The command line that
    nipype is about to execute is printed first so the run can be reproduced
    by hand.
    """
    dcm2niix = Dcm2niix()
    dcm2niix.inputs.source_dir = source_dir
    dcm2niix.inputs.output_dir = output_dir
    # The compression level input is intentionally left unset.
    print(dcm2niix.cmdline)
    dcm2niix.run()
def _locate_nifti(nii_dir, ct_name, mr_name):
    """Return ``(ct_path, mr_path)`` of the converted volumes found in ``nii_dir``.

    A converted file belongs to a series when its name starts with the series
    folder name followed by ``_`` and ends with ``.nii.gz``. If several files
    match one series, the last one yielded by ``os.scandir`` is used.

    Raises:
        FileNotFoundError: if no file matches one (or both) of the series.
    """
    ct_path = mr_path = None
    with os.scandir(nii_dir) as entries:
        for entry in entries:
            if not entry.name.endswith('.nii.gz'):
                continue
            if entry.name.startswith(ct_name + '_'):
                ct_path = entry.path
            elif entry.name.startswith(mr_name + '_'):
                mr_path = entry.path
    missing = [name
               for name, path in ((ct_name, ct_path), (mr_name, mr_path))
               if path is None]
    if missing:
        raise FileNotFoundError(
            'no converted .nii.gz file in {0} for series: {1}'.format(
                nii_dir, ', '.join(missing)))
    return ct_path, mr_path


def _run_nnunet_ensemble(input_dir, fullres_dir, lowres_dir, output_dir):
    """Run the 3d_fullres and 3d_lowres predictions and ensemble them.

    The two ``nnUNet_predict`` calls write to ``fullres_dir`` and
    ``lowres_dir``; ``nnUNet_ensemble`` combines both into ``output_dir``.

    Raises:
        subprocess.CalledProcessError: if any of the three commands exits
            with a non-zero status.
    """
    for model, model_dir in (("3d_fullres", fullres_dir),
                             ("3d_lowres", lowres_dir)):
        # check=True: the following steps need this command's output, so stop
        # here instead of failing later with a less obvious error.
        subprocess.run(["nnUNet_predict",
                        "-i", input_dir,
                        "-o", model_dir,
                        "-tr", "nnUNetTrainerV2",
                        "-ctr", "nnUNetTrainerV2CascadeFullRes",
                        "-m", model,
                        "-p", "nnUNetPlansv2.1",
                        "-t", "Task222_ICTS2022",
                        "--save_npz",
                        ], check=True)

    ensemble_cmd = ["nnUNet_ensemble",
                    "-f", fullres_dir, lowres_dir,
                    "-o", output_dir,
                    "-pp", "/workspace/nnUNet_trained_models/nnUNet/ensembles/Task222_ICTS2022/ensemble_3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1--3d_lowres__nnUNetTrainerV2__nnUNetPlansv2.1/postprocessing.json",
                    ]
    print (' '.join(ensemble_cmd))
    subprocess.run(ensemble_cmd, check=True)


def _write_rtstruct(dcm_ct, label_file, rtss_file):
    """Save the connected components of ``label_file`` as ROIs of a new RTSTRUCT.

    The label volume is resampled onto the grid of the CT DICOM series in
    ``dcm_ct`` (identity transform, nearest neighbour), split into connected
    components, and every component is added as one ROI to an RTSTRUCT that
    references that series. The result is written to ``rtss_file``.
    """
    reader = sitk.ImageSeriesReader()
    reader.SetFileNames(reader.GetGDCMSeriesFileNames(dcm_ct))
    reader.MetaDataDictionaryArrayUpdateOn()
    reader.LoadPrivateTagsOn()
    ct_image = reader.Execute()

    labels = sitk.ReadImage(label_file)
    labels = sitk.Resample(labels, ct_image, sitk.Transform(),
                           sitk.sitkNearestNeighbor)
    cc_filter = sitk.ConnectedComponentImageFilter()
    components = cc_filter.Execute(labels)
    component_count = cc_filter.GetObjectCount()

    rtstruct = RTStructBuilder.create_new(dicom_series_path=dcm_ct)
    for index in range(component_count):
        # Component labels are addressed as 1..component_count here.
        component = sitk.BinaryThreshold(components, index + 1, index + 1)
        mask = sitk.GetArrayFromImage(component).astype(bool)
        # NOTE(review): moves the first array axis to the end, presumably
        # (z, y, x) -> (y, x, z) so slices come last for add_roi -- confirm.
        mask = np.transpose(mask, (1, 2, 0))
        if mask.any():
            print(index)
            rtstruct.add_roi(mask=mask)
    print(rtss_file)
    rtstruct.save(rtss_file)


def inference(DCM_CT, DCM_MR):
    """Segment the MR series and store the result as an RTSTRUCT next to the CT.

    Steps:
      1. Convert both DICOM series to NIfTI in ``<root>/nii``.
      2. Copy the MR volume to ``<root>/input`` as ``<name>_0000.nii.gz`` and
         run the nnU-Net fullres/lowres ensemble into ``<root>/output``.
      3. Call ``reg_transform`` with the CT volume, the MR volume, the
         ensemble output and a ``.label.nii.gz`` path in ``<root>/nii``
         (presumably it registers MR to CT and writes the transformed labels
         to that path -- confirm against ``registration.best_reg``).
      4. Turn the label file into an RTSTRUCT referencing the CT series.

    ``<root>`` is the deepest directory containing both series. Its ``nii``,
    ``input``, ``3d_fullres``, ``3d_lowres`` and ``output`` sub-directories
    are deleted at the start of every call.

    Args:
        DCM_CT: directory holding the CT DICOM series.
        DCM_MR: directory holding the MR DICOM series.

    Returns:
        Path of the written RTSTRUCT file: ``DCM_CT`` (without trailing
        separator) followed by ``-rtss.dcm``.

    Raises:
        FileNotFoundError: if a converted volume cannot be found.
        subprocess.CalledProcessError: if an nnU-Net command fails.
        ValueError: from ``os.path.commonpath`` when the two paths cannot be
            compared (e.g. one absolute and one relative).
    """
    # Drop trailing separators; otherwise basename() would return '' and
    # neither the file matching nor the RTSTRUCT name below would work.
    DCM_CT = os.path.normpath(DCM_CT)
    DCM_MR = os.path.normpath(DCM_MR)

    # Component-wise common ancestor. A character-level longest match is not
    # necessarily a prefix and can end in the middle of a folder name.
    root_dir = os.path.commonpath([DCM_CT, DCM_MR])
    nii_dir = os.path.join(root_dir, 'nii')
    input_dir = os.path.join(root_dir, 'input')
    fullres_dir = os.path.join(root_dir, '3d_fullres')
    lowres_dir = os.path.join(root_dir, '3d_lowres')
    output_dir = os.path.join(root_dir, 'output')
    rtss_file = DCM_CT + '-rtss.dcm'

    # Start from a clean slate; only nii/ and input/ are recreated here.
    for stale_dir in (nii_dir, input_dir, fullres_dir, lowres_dir, output_dir):
        shutil.rmtree(stale_dir, ignore_errors=True)
    os.makedirs(nii_dir)
    os.makedirs(input_dir)

    dcm2nii(DCM_CT, nii_dir)
    dcm2nii(DCM_MR, nii_dir)
    NII_CT, NII_MR = _locate_nifti(nii_dir,
                                   os.path.basename(DCM_CT),
                                   os.path.basename(DCM_MR))

    # Replace the last "_<token>.nii.gz" of each converted name.
    mr_stem = os.path.basename(NII_MR).rpartition('_')[0]
    ct_stem = os.path.basename(NII_CT).rpartition('_')[0]
    input_file = os.path.join(input_dir, mr_stem + '_0000.nii.gz')
    output_file = os.path.join(output_dir, mr_stem + '.nii.gz')
    label_file = os.path.join(nii_dir, ct_stem + '.label.nii.gz')

    shutil.copy(NII_MR, input_file)
    print(NII_CT, NII_MR, input_file)

    _run_nnunet_ensemble(input_dir, fullres_dir, lowres_dir, output_dir)
    print(output_file)

    reg_transform(NII_CT, NII_MR, output_file, label_file)

    _write_rtstruct(DCM_CT, label_file, rtss_file)
    return rtss_file
# incorporate send_c_store
def SendDCM(fp, host="172.16.40.36", port=104, ae_title='N1000_STORAGE'):
    """Send the DICOM file ``fp`` to a storage SCP with a single C-STORE request.

    Only the RT Structure Set Storage presentation context is requested, so
    ``fp`` is expected to be an RTSTRUCT file. Progress and results are
    printed; nothing is returned.

    Args:
        fp: path of the DICOM file to send.
        host: address of the peer application entity.
        port: port of the peer application entity.
        ae_title: AE title of the peer application entity.
    """
    debug_logger()

    # Initialise the Application Entity
    ae = AE()
    ae.ae_title = 'OUR_STORE_SCP'

    # Add a requested presentation context
    ae.add_requested_context(RTStructureSetStorage)

    # Read in the dataset to send
    ds = dcmread(fp)

    # Open exactly one association. An earlier extra association to
    # 127.0.0.1:11112 was overwritten right away and never released.
    assoc = ae.associate(host, port, ae_title=ae_title)
    if not assoc.is_established:
        print('Association rejected, aborted or never connected')
        return

    try:
        # Use the C-STORE service to send the dataset
        # returns the response status as a pydicom Dataset
        status = assoc.send_c_store(ds)

        # Check the status of the storage request
        if status:
            # If the storage request succeeded this will be 0x0000
            print('C-STORE request status: 0x{0:04x}'.format(status.Status))
        else:
            print('Connection timed out, was aborted or received invalid response')
    finally:
        # Release the association even when send_c_store raises
        assoc.release()
def main():
    """Command-line entry point: ``<script> DCM_CT DCM_MR``.

    Runs ``inference`` on the two DICOM directories, sends the resulting
    RTSTRUCT with ``SendDCM`` and prints the elapsed wall-clock time.
    Prints a usage line and exits with status 1 when fewer than two
    directories are given.
    """
    # Both positional arguments are required: argv[1] (CT) and argv[2] (MR).
    if len(sys.argv) < 3:
        print('Usage:', sys.argv[0], 'DCM_CT', 'DCM_MR')
        sys.exit(1)
    print('hello')
    print(sys.argv[0])
    print(sys.argv[1])
    start = time.time()
    rtss_file = inference(sys.argv[1], sys.argv[2])
    SendDCM(rtss_file)
    end = time.time()
    print(end - start, 'seconds')
# Run the pipeline only when executed as a script, not when imported.
if __name__ == '__main__':
    main()