# CBT_project/core/optimizer_ori.py
#
# 604 lines
# 27 KiB
# Python
# Raw Permalink Normal View History
#
# 2026-04-10 05:25:27 +00:00
import time
from datetime import datetime
import SimpleITK as sitk
import torch
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
from core.objective import objective_function
from pyswarm import pso
import core.objective # <--- 加入這行,讓我們可以直接操作 objective 模組
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
from config.constant import OVERLAP_THRESH
from visualization.res_plot_3d import res_plt_2_torch
def run_pso_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """Optimize the left and right cylinder placements with PSO and plot them.

    Each side is optimized independently with ``pyswarm.pso`` over
    ``objective_function``. The search vector is
    ``[z, y, x, azimuth, altitude]`` and, when ``optimize_size`` is true,
    additionally ``[diameter, length]`` (continuous during the search, snapped
    with ``snap_to_discrete_values`` afterwards). If the first run of a side
    has ``loss > 0`` or an overlap ratio below ``OVERLAP_THRESH``, PSO is
    re-run up to 10 more times and the best candidate is kept.

    Args:
        label_str: Forwarded unchanged to ``res_plt_2_torch``.
        image1_path: Image loaded into the global ``cortical_tensor``.
        image2_path: Image loaded into the global ``spine_tensor``; also used
            to estimate the azimuth and tilt offsets applied to the angle
            bounds, and its array shape defines the position bounds.
        image3_path: Image loaded into the global ``spine_roi_tensor``.
        folder: Not read by this function.
            NOTE(review): the literal ``'Output'`` is what gets passed to
            ``res_plt_2_torch`` -- confirm whether ``folder`` was intended.
        swarm_size: PSO ``swarmsize``.
        max_iter: PSO ``maxiter``.
        spacing: Forwarded to ``generate_cylinder_n_torch`` and
            ``res_plt_2_torch``.
        CBT: Selects which azimuth/altitude windows are searched.
        device: Torch device the image tensors are moved to.
        optimize_size: When false, diameter and length are fixed to 4.5 and 45
            and only the five pose parameters are searched.
        grid: Forwarded to ``generate_cylinder_n_torch`` and
            ``res_plt_2_torch``.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r,
        total_time)``. With ``optimize_size`` the positions are 7-element lists
        whose last two entries are the snapped diameter and length; otherwise
        they are the 5-element arrays returned by ``pso``.
    """
    start_time = time.time()
    max_retries = 10

    # These names are (re)bound at module level on every call.
    # NOTE(review): presumably other code reads them as shared state -- confirm
    # before turning them into locals.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned and read in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    # Load the three volumes and move them to the requested device.
    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Angle offsets estimated from the second image; subtracted from the
    # nominal angle windows below.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    # Position bounds do not depend on CBT; only the angle windows do.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
    x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
    x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    if optimize_size:
        # Mode 1: diameter and length are searched as continuous variables.
        print("=== 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
    else:
        # Mode 2: fixed diameter and length (backward compatible).
        print("=== 固定尺寸模式:最佳化位置和角度 ===")
        diameter = 4.5
        length = 45
        size_bounds = []

    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + size_bounds
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + size_bounds
    # pyswarm takes separate lower/upper bound lists.
    lb_l = [b[0] for b in bounds_l]
    ub_l = [b[1] for b in bounds_l]
    lb_r = [b[0] for b in bounds_r]
    ub_r = [b[1] for b in bounds_r]

    def _evaluate(pos):
        """Return ``(overlap, diameter, length)`` for an optimizer position."""
        if optimize_size:
            d, cyl_len = snap_to_discrete_values(pos[5], pos[6])
        else:
            d, cyl_len = diameter, length
        cyl_mask = generate_cylinder_n_torch(
            d, cyl_len,
            pos[0], pos[1], pos[2],
            pos[3], pos[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, cyl_len

    def _run_once(lower, upper):
        """Run PSO once; return raw position, stored candidate, loss, overlap, size."""
        position, loss = pso(objective_function, lower, upper, swarmsize=swarm_size, maxiter=max_iter)
        overlap, d, cyl_len = _evaluate(position)
        if optimize_size:
            # Store the snapped sizes instead of the raw continuous values.
            candidate = list(position[:5]) + [d, cyl_len]
        else:
            candidate = position
        return position, candidate, loss, overlap, d, cyl_len

    def _optimize_side(tag, lower, upper):
        """Optimize one side (with retries) and return ``(best_position, best_loss)``."""
        position, best_position, best_loss, best_overlap, d, cyl_len = _run_once(lower, upper)
        print(f"[{tag}] overlap: {best_overlap*100:.1f}%")
        if optimize_size:
            print(f"[{tag}] Position: {position[:5]}")
            print(f"[{tag}] Diameter: {d} mm (raw: {position[5]:.2f})")
            print(f"[{tag}] Length: {cyl_len} mm (raw: {position[6]:.2f})")
        else:
            print(f"[{tag}] Position: {position}")
        print(f"[{tag}] Loss: {best_loss}\n")

        # Retry while the incumbent has a positive loss or too little overlap.
        attempt = 0
        while (best_loss > 0 or best_overlap < OVERLAP_THRESH) and attempt < max_retries:
            attempt += 1
            _, candidate, loss, overlap, _, _ = _run_once(lower, upper)
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            if candidate_ok and (not best_ok or loss < best_loss):
                # Prefer acceptable solutions; among those, the lower loss.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"[{tag}][retry {attempt}] ✅ ok | loss={loss:.4f}, overlap={overlap*100:.1f}%")
            elif (not best_ok) and (loss < best_loss):
                # Nothing acceptable yet: keep the lowest loss as a fallback.
                best_position, best_loss, best_overlap = candidate, loss, overlap
                print(f"[{tag}][retry {attempt}] ⚠️ not ok | loss improved={loss:.4f}, overlap={overlap*100:.1f}%")
            else:
                print(f"[{tag}][retry {attempt}] ❌ no improve | loss={loss:.4f}, overlap={overlap*100:.1f}%")
        return best_position, best_loss

    # Left side optimization
    print("\n=== 左側 ===")
    best_position_l, best_loss_l = _optimize_side("LEFT", lb_l, ub_l)

    # Right side optimization
    print("\n=== 右側 ===")
    best_position_r, best_loss_r = _optimize_side("RIGHT", lb_r, ub_r)

    total_time = time.time() - start_time

    # Extract the final diameter and length for plotting.
    if optimize_size:
        final_diameter_l = best_position_l[5]
        final_length_l = best_position_l[6]
        final_diameter_r = best_position_r[5]
        final_length_r = best_position_r[6]
        print("\n=== 最終結果 ===")
        print(f"Left - Diameter: {final_diameter_l} mm, Length: {final_length_l} mm")
        print(f"Right - Diameter: {final_diameter_r} mm, Length: {final_length_r} mm")
    else:
        final_diameter_l = diameter
        final_length_l = length
        final_diameter_r = diameter
        final_length_r = length

    res_plt_2_torch(
        spine_tensor,
        cortical_tensor,
        image_shape,
        image2_path,
        'Output',
        label_str,
        final_diameter_l,
        final_length_l,
        final_diameter_r,
        final_length_r,
        best_position_l,
        best_position_r,
        swarm_size,
        max_iter,
        total_time,
        spacing,
        CBT,
        device,
        grid)
    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
import time
import numpy as np
import SimpleITK as sitk
import torch
from scipy.optimize import differential_evolution
from scipy.optimize import minimize
from imaging.orientation import azimuth_rotation, analyze_vertebral_tilt_contour
from config.constant import ALLOWED_DIAMETERS, ALLOWED_LENGTHS
from core.objective import objective_function
from core.cylinder import generate_cylinder_n_torch, snap_to_discrete_values, create_coordinate_grid
from core.scoring import compute_overlap_ratio_from_cylinder_mask, is_solution_ok
from config.constant import OVERLAP_THRESH
from visualization.res_plot_3d import res_plt_2_torch
def run_de_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """Optimize the left and right cylinder placements with Differential Evolution.

    Each side is optimized independently with
    ``scipy.optimize.differential_evolution`` over ``objective_function``. The
    search vector is ``[z, y, x, azimuth, altitude]`` and, when
    ``optimize_size`` is true, additionally ``[diameter, length]`` (snapped
    with ``snap_to_discrete_values`` after each run). If the first run of a
    side has ``loss > 0`` or an overlap ratio below ``OVERLAP_THRESH``, the
    optimizer is re-run up to 10 more times and the best candidate is kept.

    Args:
        label_str: Forwarded unchanged to ``res_plt_2_torch``.
        image1_path: Image loaded into the global ``cortical_tensor``.
        image2_path: Image loaded into the global ``spine_tensor``; also used
            to estimate the azimuth and tilt offsets applied to the angle
            bounds, and its array shape defines the position bounds.
        image3_path: Image loaded into the global ``spine_roi_tensor``.
        folder: Not read by this function.
            NOTE(review): the literal ``'Output'`` is what gets passed to
            ``res_plt_2_torch`` -- confirm whether ``folder`` was intended.
        swarm_size: Target population size; converted to DE's ``popsize``
            multiplier by dividing by the number of search dimensions.
        max_iter: DE ``maxiter``.
        spacing: Forwarded to ``generate_cylinder_n_torch`` and
            ``res_plt_2_torch``.
        CBT: Selects which azimuth/altitude windows are searched.
        device: Torch device the image tensors are moved to.
        optimize_size: When false, diameter and length are fixed to 4.5 and 45
            and only the five pose parameters are searched.
        grid: Forwarded to ``generate_cylinder_n_torch`` and
            ``res_plt_2_torch``.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r,
        total_time)``. Positions are lists: 7 elements (pose plus snapped
        diameter and length) with ``optimize_size``, otherwise 5 elements.
    """
    start_time = time.time()
    max_retries = 10

    # These names are (re)bound at module level on every call.
    # NOTE(review): presumably other code reads them as shared state -- confirm
    # before turning them into locals.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned and read in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Angle offsets estimated from the second image; subtracted from the
    # nominal angle windows below.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    # Position bounds do not depend on CBT; only the angle windows do.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
    x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
    x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    if optimize_size:
        print("=== DE 最佳化模式:最佳化位置、角度、直徑和長度 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
    else:
        print("=== DE 固定尺寸模式:最佳化位置和角度 ===")
        diameter = 4.5
        length = 45
        size_bounds = []

    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + size_bounds
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + size_bounds

    # DE's actual population is popsize * len(bounds); divide so the total is
    # comparable to the PSO swarm size.
    de_popsize = max(1, swarm_size // len(bounds_l))

    def _evaluate(pos):
        """Return ``(overlap, diameter, length)`` for an optimizer position."""
        if optimize_size:
            d, cyl_len = snap_to_discrete_values(pos[5], pos[6])
        else:
            d, cyl_len = diameter, length
        cyl_mask = generate_cylinder_n_torch(
            d, cyl_len, pos[0], pos[1], pos[2], pos[3], pos[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, cyl_len

    def _run_once(bounds):
        """Run DE once and return ``(candidate_position, loss, overlap)``."""
        result = differential_evolution(objective_function, bounds, popsize=de_popsize, maxiter=max_iter)
        position, loss = result.x, result.fun
        overlap, d, cyl_len = _evaluate(position)
        if optimize_size:
            # Store the snapped sizes instead of the raw continuous values.
            candidate = list(position[:5]) + [d, cyl_len]
        else:
            candidate = list(position)
        return candidate, loss, overlap

    def _optimize_side(bounds):
        """Optimize one side (with retries) and return ``(best_position, best_loss)``."""
        best_position, best_loss, best_overlap = _run_once(bounds)
        attempt = 0
        # Retry while the incumbent has a positive loss or too little overlap.
        while (best_loss > 0 or best_overlap < OVERLAP_THRESH) and attempt < max_retries:
            attempt += 1
            candidate, loss, overlap = _run_once(bounds)
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            # Accept an acceptable candidate that beats the incumbent, or, while
            # nothing acceptable exists yet, any candidate with a lower loss.
            if (candidate_ok and (not best_ok or loss < best_loss)) or \
                    ((not best_ok) and loss < best_loss):
                best_position, best_loss, best_overlap = candidate, loss, overlap
        return best_position, best_loss

    # --- Left side ---
    print("\n=== 左側 (DE) ===")
    best_position_l, best_loss_l = _optimize_side(bounds_l)

    # --- Right side ---
    print("\n=== 右側 (DE) ===")
    best_position_r, best_loss_r = _optimize_side(bounds_r)

    total_time = time.time() - start_time

    if optimize_size:
        final_diameter_l, final_length_l = best_position_l[5], best_position_l[6]
        final_diameter_r, final_length_r = best_position_r[5], best_position_r[6]
    else:
        final_diameter_l, final_length_l = diameter, length
        final_diameter_r, final_length_r = diameter, length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )
    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time
def run_nm_torch(
    label_str: str,
    image1_path: str,
    image2_path: str,
    image3_path: str,
    folder: str,
    swarm_size: int,  # not used by Nelder-Mead itself; kept for a uniform interface
    max_iter: int,
    spacing: list,
    CBT: bool,
    device: torch.device,
    optimize_size: bool = True,
    grid=None
):
    """Optimize the left and right cylinder placements with Nelder-Mead.

    Each side is optimized independently with
    ``scipy.optimize.minimize(method='Nelder-Mead')`` over
    ``objective_function``, starting from a point drawn uniformly inside the
    bounds. The search vector is ``[z, y, x, azimuth, altitude]`` and, when
    ``optimize_size`` is true, additionally ``[diameter, length]`` (snapped
    with ``snap_to_discrete_values`` after each run). If the first run of a
    side has ``loss > 0`` or an overlap ratio below ``OVERLAP_THRESH``, the
    optimizer is restarted from a new random point up to 10 more times and the
    best candidate is kept.

    Args:
        label_str: Forwarded unchanged to ``res_plt_2_torch``.
        image1_path: Image loaded into the global ``cortical_tensor``.
        image2_path: Image loaded into the global ``spine_tensor``; also used
            to estimate the azimuth and tilt offsets applied to the angle
            bounds, and its array shape defines the position bounds.
        image3_path: Image loaded into the global ``spine_roi_tensor``.
        folder: Not read by this function.
            NOTE(review): the literal ``'Output'`` is what gets passed to
            ``res_plt_2_torch`` -- confirm whether ``folder`` was intended.
        swarm_size: Not passed to the optimizer; only forwarded to
            ``res_plt_2_torch``.
        max_iter: Nelder-Mead ``maxiter`` option.
        spacing: Forwarded to ``generate_cylinder_n_torch`` and
            ``res_plt_2_torch``.
        CBT: Selects which azimuth/altitude windows are searched.
        device: Torch device the image tensors are moved to.
        optimize_size: When false, diameter and length are fixed to 4.5 and 45
            and only the five pose parameters are searched.
        grid: Forwarded to ``generate_cylinder_n_torch`` and
            ``res_plt_2_torch``.

    Returns:
        ``(best_position_l, best_loss_l, best_position_r, best_loss_r,
        total_time)``. Positions are lists: 7 elements (pose plus snapped
        diameter and length) with ``optimize_size``, otherwise 5 elements.
    """
    start_time = time.time()
    max_retries = 10

    # These names are (re)bound at module level on every call.
    # NOTE(review): presumably other code reads them as shared state -- confirm
    # before turning them into locals.
    global image1_array, image2_array, image2_shape, image3_array
    global diameter, length  # only assigned and read in fixed-size mode
    global spine_tensor, cortical_tensor, spine_roi_tensor

    image1_array = sitk.GetArrayFromImage(sitk.ReadImage(image1_path))
    image2_array = sitk.GetArrayFromImage(sitk.ReadImage(image2_path))
    image3_array = sitk.GetArrayFromImage(sitk.ReadImage(image3_path))
    image2_shape = image2_array.shape
    image_shape = image2_shape

    cortical_tensor = torch.from_numpy(image1_array).to(device=device, dtype=torch.uint8)
    spine_tensor = torch.from_numpy(image2_array).to(device=device, dtype=torch.uint8)
    spine_roi_tensor = torch.from_numpy(image3_array).to(device=device, dtype=torch.uint8)

    # Angle offsets estimated from the second image; subtracted from the
    # nominal angle windows below.
    azi = azimuth_rotation(image2_path)
    res = analyze_vertebral_tilt_contour(image2_path, edge_type='superior', show_plot=False, debug=False)
    alt = res['superior']['tilt_angle_deg']

    # Position bounds do not depend on CBT; only the angle windows do.
    z_bounds = (0, image_shape[0] - 1)
    y_bounds = (image_shape[1]/5, image_shape[1]/2 - 1)
    x_bounds_left = (0, image_shape[2]/2 - image_shape[2]/10 - 1)
    x_bounds_right = (image_shape[2]/2 + image_shape[2]/10, image_shape[2] - 1)
    if CBT:
        azimuth_bounds_l = (95 - azi, 145 - azi)
        azimuth_bounds_r = (50 - azi, 85 - azi)
        altitude_bounds = (60 - alt, 75 - alt)
    else:
        azimuth_bounds_l = (60 - azi, 90 - azi)
        azimuth_bounds_r = (90 - azi, 120 - azi)
        altitude_bounds = (65 - alt, 80 - alt)

    if optimize_size:
        print("=== NM 最佳化模式 ===")
        size_bounds = [
            (min(ALLOWED_DIAMETERS), max(ALLOWED_DIAMETERS)),
            (min(ALLOWED_LENGTHS), max(ALLOWED_LENGTHS)),
        ]
    else:
        print("=== NM 固定尺寸模式 ===")
        diameter, length = 4.5, 45
        size_bounds = []

    bounds_l = [z_bounds, y_bounds, x_bounds_left, azimuth_bounds_l, altitude_bounds] + size_bounds
    bounds_r = [z_bounds, y_bounds, x_bounds_right, azimuth_bounds_r, altitude_bounds] + size_bounds

    def _evaluate(pos):
        """Return ``(overlap, diameter, length)`` for an optimizer position."""
        if optimize_size:
            d, cyl_len = snap_to_discrete_values(pos[5], pos[6])
        else:
            d, cyl_len = diameter, length
        cyl_mask = generate_cylinder_n_torch(
            d, cyl_len, pos[0], pos[1], pos[2], pos[3], pos[4],
            image_shape, spacing, device, grid
        )
        overlap = compute_overlap_ratio_from_cylinder_mask(cyl_mask, spine_tensor)
        return overlap, d, cyl_len

    def _run_once(bounds):
        """Run Nelder-Mead once from a random start; return ``(candidate, loss, overlap)``."""
        # Draw a fresh uniformly random starting point inside the bounds.
        x0 = [np.random.uniform(b[0], b[1]) for b in bounds]
        result = minimize(objective_function, x0, method='Nelder-Mead', bounds=bounds, options={'maxiter': max_iter})
        position, loss = result.x, result.fun
        overlap, d, cyl_len = _evaluate(position)
        if optimize_size:
            # Store the snapped sizes instead of the raw continuous values.
            candidate = list(position[:5]) + [d, cyl_len]
        else:
            candidate = list(position)
        return candidate, loss, overlap

    def _optimize_side(bounds):
        """Optimize one side (with restarts) and return ``(best_position, best_loss)``."""
        best_position, best_loss, best_overlap = _run_once(bounds)
        attempt = 0
        # Restart while the incumbent has a positive loss or too little overlap.
        while (best_loss > 0 or best_overlap < OVERLAP_THRESH) and attempt < max_retries:
            attempt += 1
            candidate, loss, overlap = _run_once(bounds)
            candidate_ok = is_solution_ok(loss, overlap, OVERLAP_THRESH)
            best_ok = is_solution_ok(best_loss, best_overlap, OVERLAP_THRESH)
            # Accept an acceptable candidate that beats the incumbent, or, while
            # nothing acceptable exists yet, any candidate with a lower loss.
            if (candidate_ok and (not best_ok or loss < best_loss)) or \
                    ((not best_ok) and loss < best_loss):
                best_position, best_loss, best_overlap = candidate, loss, overlap
        return best_position, best_loss

    # --- Left side ---
    print("\n=== 左側 (Nelder-Mead) ===")
    best_position_l, best_loss_l = _optimize_side(bounds_l)

    # --- Right side ---
    print("\n=== 右側 (Nelder-Mead) ===")
    best_position_r, best_loss_r = _optimize_side(bounds_r)

    total_time = time.time() - start_time

    if optimize_size:
        final_diameter_l, final_length_l = best_position_l[5], best_position_l[6]
        final_diameter_r, final_length_r = best_position_r[5], best_position_r[6]
    else:
        final_diameter_l, final_length_l = diameter, length
        final_diameter_r, final_length_r = diameter, length

    res_plt_2_torch(
        spine_tensor, cortical_tensor, image_shape, image2_path, 'Output', label_str,
        final_diameter_l, final_length_l, final_diameter_r, final_length_r,
        best_position_l, best_position_r, swarm_size, max_iter, total_time, spacing, CBT, device, grid
    )
    return best_position_l, best_loss_l, best_position_r, best_loss_r, total_time