torch.utils.data について

PyTorchのtorch.utils.dataが意味不明だった。

調べたら多少詳しくなったので、記録。

扱うメソッドは、

  • torch.utils.data.TensorDataset
  • torch.utils.data.DataLoader

の二つ。

動機

PyTorch公式のVAEコードでは、MNIST画像の再構築を行うことができる。

画像入力部分を解読して、MNIST画像だけでなく任意の画像を入力・再構築できるようにしたい。

公式のコードは以下の様になっている。

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=args.batch_size, shuffle=True, **kwargs)

ここで、torch.utils.data.DataLoaderにはdatasets.MNISTによりロードされた三次元テンソルが渡されている。

datasets.MNISTのイメージとしてはこんな感じ。

datasets.MNIST =
[[[0., 0., ... , 0.],
  [0., 0., ... , 0.],
  ...
  [0., 0., ... , 0.]],

 [[0., 0., ... , 0.],
  [0., 0., ... , 0.],
  ...
  [0., 0., ... , 0.]],

 ...

 [[0., 0., ... , 0.],
  [0., 0., ... , 0.],
  ...
  [0., 0., ... , 0.]]]

画像枚数×28×28×1のデータ形式

このdatasets.MNISTの部分を、任意の画像データテンソルに置き換えることで、目的達成できそう。

torch.utils.data.TensorDataset

同じ要素数テンソルを二つ渡すと、その組みを作ってくれる人。

生成されるのはTensorDatasetオブジェクト。

画像データが格納されている四次元テンソルと、そのラベルベクトルを用意しておく。

画像データ =
[[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
  [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
  ...
  [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],

 [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
   [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
   ...
   [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]],

 ...

 [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
  [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]],
   ...
  [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]

RGB画像を想定(チャンネル数3)。

ラベルベクトル =
[1, 2, 7, 9, 2, 12, 5, ... ,9] (※適当です)

この二つを、一度torch.tensor()に通しておく。(いらないかも?)

画像データ = torch.tensor(画像データ)
ラベルベクトル = torch.tensor(ラベルベクトル)

最後に、torch.utils.data.TensorDatasetを使って一つのデータセットに統合する。

データセット = torch.utils.data.TensorDataset(画像データ, ラベルベクトル)

このデータセットは、TensorDatasetオブジェクト。

torch.utils.data.DataLoader

TensorDatasetオブジェクトを渡してバッチサイズを指定すると、iterableなtorch.utils.data.DataLoaderオブジェクトを返してくれる人。

PyTorch公式のVAEクラスは、nn.Moduleを継承しているみたいだが、この型のデータであればモデルに渡すことができる。

使い方は簡単。

torch.utils.data.TensorDatasetオブジェクトを第一引数に渡して、バッチサイズを入れるだけ。

data_loader = torch.utils.data.DataLoader(dataset,
                       batch_size=batch_size, shuffle=True, **kwargs)

shuffle=Trueとすることでデータセットの並びをシャッフルすることができる!

まとめ

model = VAE().to(device)

とした後、

for batch_idx, (data, _) in enumerate(data_loader):
 (中略)
   recon_batch, mu, logvar = model(data)
 (中略)

みたいになる(ガバガバ)

こんな感じにすることで、VAEに好きな画像データを入れることができる〜