import torch


def get_device(gpu_id: int = 0, *, verbose: bool = True) -> torch.device:
    """Return the torch device to run on, preferring the given CUDA GPU.

    If CUDA is available, returns ``torch.device(f"cuda:{gpu_id}")``;
    otherwise falls back to ``torch.device("cpu")``.

    Args:
        gpu_id: Index of the CUDA device to use when CUDA is available.
        verbose: If True (the default), print which device was selected.

    Returns:
        The selected ``torch.device``.

    Raises:
        ValueError: If CUDA is available but ``gpu_id`` is not in
            ``[0, torch.cuda.device_count())``.
    """
    if not torch.cuda.is_available():
        if verbose:
            print("CUDA not available, using CPU")
        return torch.device("cpu")

    # Validate up front instead of relying on get_device_name() to fail
    # incidentally; that check would otherwise be skipped when verbose=False
    # and an invalid device would be handed back silently.
    device_count = torch.cuda.device_count()
    if not 0 <= gpu_id < device_count:
        raise ValueError(
            f"gpu_id {gpu_id} is out of range; "
            f"{device_count} CUDA device(s) available"
        )

    device = torch.device(f"cuda:{gpu_id}")
    if verbose:
        print(f"Using GPU {gpu_id}: {torch.cuda.get_device_name(gpu_id)}")
    return device