200 lines
6.4 KiB
Python
200 lines
6.4 KiB
Python
|
|
import os
|
||
|
|
import random
|
||
|
|
|
||
|
|
from PIL import Image, ImageFilter, ImageMath
|
||
|
|
from scipy import ndimage
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
|
||
|
|
# Side length, in pixels, of the square patches used by getpatch() and
# IMPAXDataset.__getitem__.
PATCH_SIZE = 256

# Dataset roots previously passed to IMPAXDataset (kept for reference):
#   JPGDIR = '/media/nfs/SRS/IMPAX/'
#   JPGDIR = '/shares/Public/IMPAX/'
|
||
|
|
|
||
|
|
def img_frombytes(data):
    """Build a 1-bit PIL image from a 2-D array of zero / nonzero values.

    Every row is bit-packed independently (nonzero -> 1, most significant bit
    first), which yields the byte-padded-per-row layout that Pillow's raw
    decoder reads for mode '1'.

    Args:
        data: 2-D array indexed ``[row, column]``.

    Returns:
        A mode '1' ``PIL.Image.Image`` whose size is (columns, rows).
    """
    packed_rows = np.packbits(data, axis=1)
    image_size = tuple(reversed(data.shape))
    return Image.frombytes(mode='1', size=image_size, data=packed_rows.tobytes())
|
||
|
|
|
||
|
|
|
||
|
|
def getpatch(width, height, patch_size=None):
    """Pick the top-left corner of a random square patch inside an image.

    On each axis a uniform pixel position is drawn and snapped down to a grid
    of step ``patch_size``.  A corner whose patch would run past the right or
    bottom edge is pulled back so the patch ends exactly at that edge.  When
    the image is smaller than ``patch_size`` on an axis, the corner on that
    axis is 0, so a ``[start:start + patch_size]`` slice returns the whole
    (shorter) extent instead of indexing from the far end of the array.

    Args:
        width: image width in pixels (must be >= 1).
        height: image height in pixels (must be >= 1).
        patch_size: patch side length; ``None`` (default) means the
            module-level ``PATCH_SIZE``.

    Returns:
        ``(w, h)``: column and row offset of the patch's top-left corner.

    Raises:
        ValueError: if ``width`` or ``height`` is < 1 (raised by
            ``random.randint``).
    """
    if patch_size is None:
        patch_size = PATCH_SIZE

    def _origin(extent):
        # Snap a uniform draw to the patch grid, then clamp so the patch fits.
        # The lower bound of 0 matters when extent < patch_size: a negative
        # start would make numpy slice from the end of the axis.
        snapped = random.randint(0, extent - 1) // patch_size * patch_size
        return max(0, min(snapped, extent - patch_size))

    # Draw order (width first, then height) matches the original so seeded
    # runs consume the random stream identically.
    w = _origin(width)
    h = _origin(height)
    return w, h
|
||
|
|
|
||
|
|
class IMPAXDataset(object):
    """Random-patch dataset over a directory tree of JPEG studies.

    For every image found in a ``<pid>/<name>_100/`` directory the dataset
    keeps four parallel path lists (ST_100, ST_90, ST_AN, ST_TXT).
    ``__getitem__`` cuts one random patch from the ``_90`` image and the patch
    at the same location from the ``_AN`` image and the ``_TXT`` mask.
    """

    def __init__(self, JPGDIR):
        """Index every ``*_100`` study under ``JPGDIR`` and build missing masks.

        Directory layout implied by the path handling below::

            JPGDIR/<pid>/<name>_100/<file>       listed and indexed
            JPGDIR/<pid>/<name>_90/<file>        derived path, not checked here
            JPGDIR/<pid>/<name>_AN/<file>        derived path, not checked here
            JPGDIR/<pid>/<name>_TXT/<stem>.png   mask, generated if missing

        Sibling directories are derived from the study name only, so a
        ``_100`` occurring elsewhere in the path (root, patient id, file name)
        is left untouched.  The ``_TXT`` directory is created if needed.

        Args:
            JPGDIR: root directory containing one sub-directory per patient.
        """
        self.ST_90 = []
        self.ST_100 = []
        self.ST_AN = []
        self.ST_TXT = []

        # Extremes of (width, height) over the images whose mask was built in
        # this call.  Images with an already-cached mask are not opened, so on
        # a run where every mask exists these stay None and nothing is printed.
        self.MAXSHAPE = None
        self.MAXSIZE = 0
        self.MINSHAPE = None
        self.MINSIZE = 9999 * 9999

        # Debug counter; not updated anywhere in this class.
        self.gets = 0

        suffix = '_100'
        for pid in sorted(os.listdir(JPGDIR)):
            PATDIR = os.path.join(JPGDIR, pid)
            for study in sorted(os.listdir(PATDIR)):
                if not study.endswith(suffix):
                    continue

                prefix = study[:-len(suffix)]
                ST100_DIR = os.path.join(PATDIR, study)
                ST90_DIR = os.path.join(PATDIR, prefix + '_90')
                AN_DIR = os.path.join(PATDIR, prefix + '_AN')
                TXT_DIR = os.path.join(PATDIR, prefix + '_TXT')
                os.makedirs(TXT_DIR, exist_ok=True)

                for jpg in sorted(os.listdir(ST100_DIR)):
                    jpg_path = os.path.join(ST100_DIR, jpg)
                    # Always give the mask a .png name so it is saved losslessly
                    # regardless of the source file's extension.
                    txt_path = os.path.join(
                        TXT_DIR, os.path.splitext(jpg)[0] + '.png')

                    self.ST_100.append(jpg_path)
                    self.ST_90.append(os.path.join(ST90_DIR, jpg))
                    self.ST_AN.append(os.path.join(AN_DIR, jpg))
                    self.ST_TXT.append(txt_path)

                    if os.path.isfile(txt_path):
                        continue  # mask cached by an earlier run

                    with Image.open(jpg_path) as src:
                        img = src.convert('L')

                    width, height = img.size
                    size = width * height
                    if self.MAXSIZE < size:
                        self.MAXSIZE = size
                        self.MAXSHAPE = width, height
                    if self.MINSIZE > size:
                        self.MINSIZE = size
                        self.MINSHAPE = width, height

                    mask = self._text_mask(np.array(img))
                    img_frombytes(mask).save(txt_path)

        if self.MINSHAPE:
            print(self.MINSHAPE)
        if self.MAXSHAPE:
            print(self.MAXSHAPE)

    @staticmethod
    def _text_mask(gray):
        """Return a 0/1 uint8 mask of candidate overlay pixels in ``gray``.

        A pixel (i, j) is marked when its value is in 0xCB..0xCD and pixel
        (i + 1, j + 1) is either in that band or <= 0x01.  This presumably
        targets light-grey burned-in text with a dark drop shadow; confirm
        against the data.  ``np.roll`` wraps around, so the last row/column
        are compared with the first row/column.

        Isolated marks are then removed.  ``rank_filter(rank=-2, size=3)``
        gives the second-largest value in each 3x3 window, which for 0/1 data
        is 1 iff the window holds at least two set pixels.  Taking the
        minimum with the mask keeps a pixel only if it has a set neighbour.
        The filter's default 'reflect' border mode can let isolated pixels on
        the image border survive.

        Args:
            gray: 2-D uint8 array of an 8-bit greyscale image.
        """
        in_band = np.logical_and(gray >= 0xCB, gray <= 0xCD)
        near_black = gray <= 0x01
        # Shift so that index (i, j) holds the value of (i + 1, j + 1).
        neighbour = np.roll(np.logical_or(in_band, near_black), (-1, -1), axis=(0, 1))
        marked = np.logical_and(in_band, neighbour).astype('uint8')
        second_max = ndimage.rank_filter(marked, rank=-2, size=3)
        return np.minimum(marked, second_max)

    def __getitem__(self, idx):
        """Return one random, grid-aligned patch from sample ``idx``.

        The patch location is chosen by ``getpatch`` from the ``_90`` image's
        size and applied to all three images.  Each side is at most
        PATCH_SIZE (shorter if the image is smaller).

        Returns:
            ``(inp, target)`` float32 tensors with raw 0-255 pixel values:
            ``inp`` has shape (1, rows, cols) from the ``_90`` image;
            ``target`` has shape (2, rows, cols) stacking the ``_AN`` patch and
            the ``_TXT`` mask patch (mask pixels are 0 or 255 after the
            mode '1' -> 'L' conversion).
        """
        with Image.open(self.ST_90[idx]) as f:
            st_90 = f.convert('L')
        with Image.open(self.ST_AN[idx]) as f:
            st_AN = f.convert('L')
        with Image.open(self.ST_TXT[idx]) as f:
            st_TX = f.convert('L')

        width, height = st_90.size
        w, h = getpatch(width, height)
        rows = slice(h, h + PATCH_SIZE)
        cols = slice(w, w + PATCH_SIZE)

        s2_90 = np.array(st_90)[np.newaxis, rows, cols]
        s2_AN_TX = np.stack((np.array(st_AN)[rows, cols],
                             np.array(st_TX)[rows, cols]))
        return torch.from_numpy(s2_90).float(), torch.from_numpy(s2_AN_TX).float()

    def __len__(self):
        """Number of indexed ``_100`` images."""
        return len(self.ST_100)
|
||
|
|
|