PyTorch: как использовать DataLoaders для пользовательских наборов данных

как использовать torch.utils.data.Dataset и torch.utils.data.DataLoader по вашим собственным данным (не только torchvision.datasets)?

есть ли способ использовать встроенный DataLoaders, который они используют на TorchVisionDatasets для использования в любом наборе данных?

2 ответов


Да, это возможно. Просто создайте объекты самостоятельно, например

import torch.utils.data as data_utils

train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)

здесь features и targets являются тензоры. features должно быть 2-D, т. е. Матрица, где каждая строка представляет собой один обучающий образец, и targets может быть 1-D или 2-D, в зависимости от того, пытаетесь ли вы предсказать скаляр или вектор.

надеюсь, что это поможет!


редактировать: ответ на вопрос @sarthak

в основном да. Если вы создайте объект типа TensorData, затем конструктор исследует, являются ли первые измерения тензора объектов (который фактически называется data_tensor) и целевой тензор (называемый target_tensor) имеют одинаковую длину:

assert data_tensor.size(0) == target_tensor.size(0)

однако, если вы хотите передать эти данные в нейронную сеть впоследствии, вам нужно быть осторожным. В то время как слои свертки работают с данными, подобными вашим, (Я думаю), все другие типы слоев ожидают, что данные будут даны в матричной форме. Итак, если вы сталкиваетесь с такой проблемой, тогда простым решением было бы преобразовать ваш 4D-dataset (заданный как какой-то тензор, например FloatTensor) в матрицу с помощью метода view. Для вашего набора данных 5000xnxnx3, это будет выглядеть так:

2d_dataset = 4d_dataset.view(5000, -1)

(стоимость -1 говорит PyTorch, чтобы выяснить длину второго измерения автоматически.)


вы можете легко сделать это, расширяя data.Dataset класса. Согласно API, все, что вам нужно сделать, это реализовать две функции: __getitem__ и __len__.

затем вы можете обернуть набор данных с помощью DataLoader, как показано в API и в ответе @pho7.

Я думаю ImageFolder класс-это ссылка. См. код здесь.