150 lines
5.9 KiB
Python
150 lines
5.9 KiB
Python
|
|
import os
|
|||
|
|
import SimpleITK as sitk
|
|||
|
|
from imaging.resample import resample_img
|
|||
|
|
from imaging.affine import standardize_affine
|
|||
|
|
from imaging.segmentation import seg_bone
|
|||
|
|
import json
|
|||
|
|
import glob
|
|||
|
|
from config.constant import LABEL_MAP
|
|||
|
|
from imaging.nifti_io import sitk_to_nibabel, nibabel_to_sitk
|
|||
|
|
|
|||
|
|
# JSON file used by load_progress/save_progress to remember which cases have
# already been processed, so an interrupted run can resume where it stopped.
# NOTE(review): this is a relative path, so it resolves against the current
# working directory (not the output directory); launching the script from a
# different directory starts with a fresh progress record.
PROGRESS_FILE = "progress.json"
|
|||
|
|
|
|||
|
|
def load_progress(path=None):
    """Load the resume-progress mapping written by ``save_progress``.

    Args:
        path: JSON file to read. Defaults to the module-level ``PROGRESS_FILE``
            (the only behaviour available before this parameter existed).

    Returns:
        The decoded JSON object (in this script: case name -> progress record),
        or an empty dict when the file does not exist yet (first run).
    """
    if path is None:
        path = PROGRESS_FILE
    try:
        # EAFP instead of os.path.exists() + open(): one filesystem call and no
        # window in which the file can disappear between the check and the open.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)
    except FileNotFoundError:
        return {}
|
|||
|
|
|
|||
|
|
def save_progress(progress, path=None):
    """Persist the resume-progress mapping as indented JSON.

    The data is first written to a sibling ``<path>.tmp`` file and then moved
    over the target with ``os.replace``. An interruption mid-write (Ctrl+C,
    crash, serialisation error) therefore leaves the previous progress file
    intact instead of a truncated one that ``json.load`` could not parse.

    Args:
        progress: JSON-serialisable mapping to store.
        path: Destination file. Defaults to the module-level ``PROGRESS_FILE``.
    """
    if path is None:
        path = PROGRESS_FILE
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(progress, f, indent=2)
    # Atomic swap on the same filesystem: readers see either the old or the new file.
    os.replace(tmp_path, path)
|
|||
|
|
|
|||
|
|
def process_single_image(image_path, label_path, output_dir_base=None, out_spacing=(0.5, 0.5, 0.5)):
    """Resample one image/segmentation pair and run per-label bone segmentation.

    Steps:
      1. Read the image and label volumes with SimpleITK.
      2. Resample both to ``out_spacing`` (the label with ``is_label=True``).
      3. Write ``<name>_labels.txt`` listing every label present in the
         segmentation with its ``LABEL_MAP`` name ('Unknown' if absent).
      4. For each present label call ``seg_bone`` and pass each of the four
         paths it returns through ``standardize_affine``.

    Args:
        image_path: Path to the image volume; its basename with ``.nii.gz``
            removed is used as the case name.
        label_path: Path to the matching segmentation volume.
        output_dir_base: Parent directory of the per-case output folder
            ``<output_dir_base>/<name>``. Defaults to the current working
            directory (a ``None`` value used to crash in ``os.path.join``).
        out_spacing: Target voxel spacing passed to ``resample_img`` for both
            volumes. Defaults to the previously hard-coded 0.5 isotropic.

    Returns:
        dict with
          - ``"processed_labels"``: labels whose segmentation succeeded,
          - ``"failed_labels"``: labels skipped because ``seg_bone`` or
            ``standardize_affine`` raised ``RuntimeError``,
          - ``"missing_labels"``: always empty here; kept for the caller's
            summary format.
    """
    if output_dir_base is None:
        output_dir_base = os.curdir

    image = sitk.ReadImage(image_path)
    label = sitk.ReadImage(label_path)

    file_name = os.path.basename(image_path)
    name = file_name.replace(".nii.gz", "")

    # The statistics computed here are never read. The call is kept because
    # SimpleITK's Execute() raises when image and label do not occupy the same
    # physical space, so a mismatched pair fails before the expensive steps.
    # NOTE(review): confirm this validation is the intent; otherwise drop it.
    sitk.LabelStatisticsImageFilter().Execute(image, label)

    # Separate list() copies per call, matching the original fresh list literals.
    resampled_sitk_img = resample_img(image, out_spacing=list(out_spacing), is_label=False)
    resampled_sitk_lbl = resample_img(label, out_spacing=list(out_spacing), is_label=True)

    # Labels actually present in the original-resolution segmentation,
    # e.g. [1, 2, 3, 20, 21].
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(label)
    existing_labels = shape_stats.GetLabels()
    print(f"Existing labels in {os.path.basename(label_path)}: {existing_labels}")

    # One output folder per case.
    output_dir = os.path.join(output_dir_base, name)
    os.makedirs(output_dir, exist_ok=True)

    # Record the labels present in this case. UTF-8 explicitly because
    # LABEL_MAP names may be non-ASCII and the platform default encoding
    # cannot always represent them.
    txt_path = os.path.join(output_dir, f"{name}_labels.txt")
    with open(txt_path, "w", encoding="utf-8") as f:
        f.writelines(f"{lab}\t{LABEL_MAP.get(lab, 'Unknown')}\n" for lab in existing_labels)

    processed_labels = []
    failed_labels = []
    for n in existing_labels:
        try:
            # Unpacking into four names keeps the original's check that
            # seg_bone returns exactly four paths.
            roi_path, binary_path, roi2_path, cortical_path = seg_bone(
                n, name, resampled_sitk_img, resampled_sitk_lbl, output_dir, label_map=LABEL_MAP
            )
            for path in (roi_path, binary_path, roi2_path, cortical_path):
                standardize_affine(path, output_dir)
        except RuntimeError as e:
            print(f"Label {n} could not be processed, skipping. Error: {e}")
            failed_labels.append(n)
        else:
            processed_labels.append(n)

    return {
        "processed_labels": processed_labels,
        "failed_labels": failed_labels,
        "missing_labels": [],
    }
|
|||
|
|
|
|||
|
|
def _write_label_summary(summary_path, total_files, all_file_summary):
    """Write the human-readable per-file label summary to *summary_path*.

    The file is written as UTF-8 because it contains emoji markers, which the
    platform default encoding (e.g. cp1252 on Windows) cannot encode.
    """
    with open(summary_path, "w", encoding="utf-8") as f:
        f.write("--- CTSpine1K Dataset Label Summary ---\n")
        # Counts every image found, including skipped and failed ones.
        f.write(f"Total files found: {total_files}\n\n")

        for summary in all_file_summary:
            f.write("================================================\n")
            f.write(f"File: {summary['file_name']}\n")

            processed_labels_str = ", ".join(str(lab) for lab in summary["current_labels"])
            f.write(f"Labels processed: {processed_labels_str}\n")

            if summary["missing_labels"]:
                missing_str = ", ".join(
                    f"{lab} ({LABEL_MAP.get(lab, 'Unknown')})" for lab in summary["missing_labels"]
                )
                f.write(f"🚨 Missing Labels: {missing_str}\n")
            else:
                f.write("✅ Missing Labels: None\n")

            failed = summary.get("failed_labels")
            if failed:
                failed_str = ", ".join(f"{lab} ({LABEL_MAP.get(lab, 'Unknown')})" for lab in failed)
                f.write(f"⚠️ Failed Labels: {failed_str}\n")

            if "note" in summary:
                f.write(f"Note: {summary['note']}\n")
            f.write("================================================\n\n")


def process_dataset(image_dir, label_dir, output_dir, labels_to_process=None):
    """Process every ``*.nii.gz`` image in *image_dir* and write a label summary.

    For an image ``<name>.nii.gz`` the segmentation is expected at
    ``<label_dir>/<name>_seg.nii.gz``. Cases already marked ``finished`` in the
    progress file are skipped and their stored labels are reused for the
    summary, so an interrupted run can simply be restarted. Progress is saved
    after each successfully processed case; cases with a missing label file or
    an exception are not marked finished and are retried on the next run.
    Finally ``all_files_label_summary.txt`` is written to *output_dir*.

    Args:
        image_dir: Directory containing the image volumes.
        label_dir: Directory containing the ``*_seg.nii.gz`` segmentations.
        output_dir: Root output directory (one sub-folder per case plus the
            summary file).
        labels_to_process: Accepted for interface compatibility but currently
            unused. NOTE(review): every label present in a segmentation is
            processed regardless of this argument.
    """
    image_files = sorted(glob.glob(os.path.join(image_dir, "*.nii.gz")))
    total_files = len(image_files)
    print(f"Total files: {total_files}")

    progress = load_progress()
    all_file_summary = []

    for idx, image_path in enumerate(image_files, 1):
        file_name = os.path.basename(image_path)
        name = file_name.replace(".nii.gz", "")
        label_path = os.path.join(label_dir, file_name.replace(".nii.gz", "_seg.nii.gz"))
        tag = f"[{idx}/{total_files}]"

        # Appended now and filled in below, so summary order follows file order
        # regardless of which branch is taken.
        file_summary = {"file_name": file_name, "current_labels": [], "missing_labels": []}
        all_file_summary.append(file_summary)

        record = progress.get(name, {})
        if record.get("finished", False):
            print(f"{tag} Already finished: {file_name}")
            file_summary["current_labels"] = record.get("processed_labels", [])
            file_summary["missing_labels"] = record.get("missing_labels", [])
            # Older progress files have no failed_labels entry.
            file_summary["failed_labels"] = record.get("failed_labels", [])
            continue

        if not os.path.exists(label_path):
            print(f"{tag} Warning: label not found for {file_name}")
            file_summary["note"] = "Label file not found"
            continue

        try:
            result = process_single_image(image_path, label_path, output_dir_base=output_dir)
        except Exception as e:  # top-level boundary: one bad case must not abort the dataset
            print(f"{tag} Error processing {file_name}: {e}")
            file_summary["note"] = f"Error: {e}"
            continue

        # .get: process_single_image may not report failed labels.
        failed_labels = result.get("failed_labels", [])
        file_summary["current_labels"] = result["processed_labels"]
        file_summary["missing_labels"] = result["missing_labels"]
        file_summary["failed_labels"] = failed_labels

        progress[name] = {
            "finished": True,
            "processed_labels": result["processed_labels"],
            "missing_labels": result["missing_labels"],
            "failed_labels": failed_labels,
        }
        save_progress(progress)

        failed_note = f" | Failed labels: {failed_labels}" if failed_labels else ""
        print(f"{tag} Finished: {file_name} | Missing labels: {result['missing_labels'] or 'None'}{failed_note}")

    # --- Summary ---
    summary_path = os.path.join(output_dir, "all_files_label_summary.txt")
    if output_dir:
        # os.makedirs("") raises, so only create a directory when one was given.
        os.makedirs(output_dir, exist_ok=True)
    print(f"\nWriting summary to {summary_path}...")
    _write_label_summary(summary_path, total_files, all_file_summary)
    print("All done! Summary file created.")
|