2017-12-21 18 views
1

torch.utils.data.DataLoader 여러 개의 변형을 적용한 데이터 집합을 만들 때 여러 개의 torch.utils.data.DataLoader을 사용하려고합니다. 현재 코드는 대략PyTorch DataLoader

입니다.
d_transforms = [ 
    transforms.RandomHorizontalFlip(), 
    # Some other transforms... 
] 
loaders = [] 
for i in range(len(d_transforms)): 
    dataset = datasets.MNIST('./data', 
      train=train, 
      download=True, 
      transform=d_transforms[i] 
    loaders.append(
     DataLoader(dataset, 
      shuffle=True, 
      pin_memory=True, 
      num_workers=1) 
     ) 

이 코드는 작동하지만 속도는 매우 느립니다. 거의 모든 내 코드에서 시간이

x, y = next(iter(train_loaders[i])) 

같은 라인에 소요되는 kernprof 쇼 나는 이것이 내가, DataLoader의 여러 인스턴스를 사용하고 있다는 사실 때문이라고 생각하려고 자신의 노동자, 각 동일한 데이터 파일을 읽습니다.

내 질문은 무엇이 더 좋은 방법입니까? 이상적으로, torch.utils.data.DataSet을 서브 클래스 화하고을 샘플링 할 때 을 적용하고 싶은 변환을 지정합니다. 그러나 이것은 __getitem__이 인수를 취할 수 없기 때문에 가능하지 않습니다.

+0

네가 훌륭한 사람을 제안 할 수 있다면 가능하다. 앞서 말했듯이, 더 좋은 방법은 내가 찾고있는 것입니다. – Coolness

+0

하나에서 여러 개의 파생 데이터 세트를 만들고 일반화하려고합니다. – Coolness

+0

필자는 내 작업의 구체적인 내용에 대해 논하지 않고 주석 상자를 완전히 설명하지도 않습니다. 나는 이것이 그 질문과 어떻게 관련이 있는지 보지 못한다. – Coolness

답변

0

__getitem__은로드하려는 콘텐츠의 색인 인 인수를 취합니다. 예를 들면.

transform = transforms.Compose(
    [transforms.ToTensor(), 
    normalize]) 

class CountDataset(Dataset): 

def __init__(self, file,transform=None): 

    self.transform = transform 
    #self.vocab = vocab 
    with open(file,'rb') as f: 
     self.data = pickle.load(f) 
    self.y = self.data['answers'] 
    self.I = self.data['images'] 


def __len__(self): 
    return len(self.y) 

def __getitem__(self, idx): 
    img_name = self.I[idx] 
    label = self.y[Idx] 
    fname = '/'.join(img_name.split("/")[-2:]) #/train2014/xx.jpg 
    DIR = '/hdd/manoj/VQA/Images/mscoco/' 
    img_full_path = os.path.join(DIR,fname) 
    img = Image.open(img_full_path).convert("RGB") 
    img_tensor = self.transform(img.resize((224,224))) 
    return img_tensor,label 


testset = CountDataset(file = 'testdat.pkl', 
         transform = transform) 


testloader = DataLoader(testset, batch_size=32, 
         shuffle=False, num_workers=4) 

루프에서 데이터 로더를 호출하지 마십시오.