import torch from core.cylinder import generate_cylinder_n_torch, generate_cylinder_tip_torch from config.constant import OVERLAP_THRESH def cl_score_torch_xfr( cortical_tensor: torch.Tensor, spine_tensor: torch.Tensor, cylinder_torch: torch.Tensor, cylinder_o_torch: torch.Tensor, intersections: int, diameter: float = None, length: float = None, cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask ) -> float: """ 漸進式評分:優先確保找到骨頭,再改善細節 """ cyl_total = cylinder_torch.sum().item() overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item() null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item() null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item() in_bone= ((spine_tensor == 1) & (cylinder_torch == 1)).sum().item() not_in_bone= ((spine_tensor == 0) & (cylinder_torch == 1)).sum().item() # if cyl_total == 0: # return float(1000*1000) # return float(1e9) # 極差的情況 overlap_ratio = overlap / cyl_total # if cyl_total < 1000: # return float((1000 - cyl_total)*10000) score = cyl_total # if in_bone == 0: # return float(not_in_bone*200) score += overlap*30 score += in_bone*10 score -= not_in_bone*1000 score -= null_vox2*1000 return float(-score) # === 階段 1:首要目標是找到骨頭(overlap > 0) === if overlap == 0: # 完全沒有 overlap 是最糟糕的情況 score -= 500000 # 超大懲罰 # 如果連 spine 都沒穿過,更糟 if intersections == 0: score -= 500000 return float(-score) # === 階段 2:有找到骨頭了,開始改善品質 === # 1. Overlap 獎勵(非線性,鼓勵快速提升) if overlap_ratio < 0.1: # 0-10%:每增加 1% 給大量獎勵(鼓勵探索) score += overlap * 5000 # 很高的單位獎勵 elif overlap_ratio < 0.3: # 10-30%:中等獎勵 score += overlap * 3000 elif overlap_ratio < 0.5: # 30-50%:正常獎勵 score += overlap * 2000 else: # 50%+:獎勵 + 額外比例獎勵 score += overlap * 2000 score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎 # 2. Intersection 控制(稍微放寬) if intersections == 1: score += 20000 # 完美 elif intersections == 0: score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好) elif intersections == 2: score -= 10000 # 可接受但不理想 else: score -= intersections * 15000 # 3. Null voxel 懲罰(漸進式) null_ratio = null_vox / cyl_total if overlap_ratio < 0.2: # 如果 overlap 還很少,對 null voxel 寬容一點 score -= null_vox * 300 elif overlap_ratio < 0.4: score -= null_vox * 600 else: # overlap 夠高了,開始嚴格要求 if null_ratio > 0.5: score -= null_vox * 1500 else: score -= null_vox * 800 # 4. 反向圓柱懲罰 score -= null_vox2 * 1000 # 5. 尺寸合理性(放寬) if diameter is not None and length is not None: if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0) score -= 3000 if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60) score -= 3000 # 6. 尖端 breach 懲罰 if cylinder_tip_torch is not None: tip_total = cylinder_tip_torch.sum().item() if tip_total > 0: tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item() tip_breach_ratio = tip_breach / tip_total if tip_breach_ratio > 0: score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多 return float(-score) def cl_score_torch( cortical_tensor: torch.Tensor, spine_tensor: torch.Tensor, cylinder_torch: torch.Tensor, cylinder_o_torch: torch.Tensor, intersections: int, diameter: float = None, length: float = None, cylinder_tip_torch: torch.Tensor = None # 新增:尖端 mask ) -> float: """ 漸進式評分:優先確保找到骨頭,再改善細節 """ cyl_total = cylinder_torch.sum().item() overlap = ((cortical_tensor == 1) & (cylinder_torch == 1)).sum().item() null_vox = ((cortical_tensor == 0) & (cylinder_torch == 1)).sum().item() null_vox2 = ((spine_tensor == 1) & (cylinder_o_torch == 1)).sum().item() if cyl_total == 0: return float(1e9) # 極差的情況 overlap_ratio = overlap / cyl_total score = 0 # === 階段 1:首要目標是找到骨頭(overlap > 0) === if overlap == 0: # 完全沒有 overlap 是最糟糕的情況 score -= 500000 # 超大懲罰 # 如果連 spine 都沒穿過,更糟 if intersections == 0: score -= 500000 return float(-score) # === 階段 2:有找到骨頭了,開始改善品質 === # 1. Overlap 獎勵(非線性,鼓勵快速提升) if overlap_ratio < 0.1: # 0-10%:每增加 1% 給大量獎勵(鼓勵探索) score += overlap * 5000 # 很高的單位獎勵 elif overlap_ratio < 0.3: # 10-30%:中等獎勵 score += overlap * 3000 elif overlap_ratio < 0.5: # 30-50%:正常獎勵 score += overlap * 2000 else: # 50%+:獎勵 + 額外比例獎勵 score += overlap * 2000 score += (overlap_ratio - 0.5) * 100000 # 超過 50% 額外大獎 # 2. Intersection 控制(稍微放寬) if intersections == 1: score += 20000 # 完美 elif intersections == 0: score -= 200000 # 嚴重錯誤(但比完全沒 overlap 好) elif intersections == 2: score -= 10000 # 可接受但不理想 else: score -= intersections * 15000 # 3. Null voxel 懲罰(漸進式) null_ratio = null_vox / cyl_total if overlap_ratio < 0.2: # 如果 overlap 還很少,對 null voxel 寬容一點 score -= null_vox * 300 elif overlap_ratio < 0.4: score -= null_vox * 600 else: # overlap 夠高了,開始嚴格要求 if null_ratio > 0.5: score -= null_vox * 1500 else: score -= null_vox * 800 # 4. 反向圓柱懲罰 score -= null_vox2 * 1000 # 5. 尺寸合理性(放寬) if diameter is not None and length is not None: if diameter < 2.5 or diameter > 6.0: # 放寬從 (3.0, 5.5) 到 (2.5, 6.0) score -= 3000 if length < 25 or length > 60: # 放寬從 (30, 55) 到 (25, 60) score -= 3000 # 6. 尖端 breach 懲罰 if cylinder_tip_torch is not None: tip_total = cylinder_tip_torch.sum().item() if tip_total > 0: tip_breach = ((cortical_tensor == 0) & (cylinder_tip_torch == 1)).sum().item() tip_breach_ratio = tip_breach / tip_total if tip_breach_ratio > 0: score -= tip_breach * 5000 # 尖端出界懲罰要比一般 null_vox 重很多 return float(-score) def get_overlap_ratio( position_params: list, diameter: float, length: float, cortical_tensor: torch.Tensor, image_shape: tuple, spacing: list, device: torch.device, grid=None ) -> float: """ 計算 Cylinder 與 Cortical Bone 的重疊比例 (%) """ # 生成 Cylinder Mask cyl_mask = generate_cylinder_n_torch( diameter, length, position_params[0], position_params[1], position_params[2], position_params[3], position_params[4], image_shape, spacing, device, grid ) # 計算體積 (Voxel count) cyl_vol = torch.sum(cyl_mask).item() if cyl_vol == 0: return 0.0 # 計算重疊部分 # 注意:這裡使用 cortical_tensor (與 cl_score_torch 邏輯一致) overlap_count = ((cortical_tensor == 1) & (cyl_mask == 1)).sum().item() return (overlap_count / cyl_vol) * 100.0 def compute_overlap_ratio_from_cylinder_mask(cyl_mask: torch.Tensor, spine_mask: torch.Tensor, eps: float = 1e-6) -> float: """ 一個常見定義: overlap = intersection / cylinder_volume 你也可以改成 intersection / spine_volume 或 Dice,依你論文/需求一致即可。 cyl_mask, spine_mask: uint8/bool tensor, same shape """ cyl = cyl_mask.bool() sp = spine_mask.bool() inter = (cyl & sp).sum().item() denom = cyl.sum().item() return float(inter) / float(denom + eps) def is_solution_ok(loss: float, overlap: float, overlap_thresh: float = OVERLAP_THRESH) -> bool: return (loss <= 0) and (overlap >= overlap_thresh)