torch.utils.data について
PyTorchのtorch.utils.data
が意味不明だった。
調べたら多少詳しくなったので、記録。
扱うメソッドは、
torch.utils.data.TensorDataset
torch.utils.data.DataLoader
の二つ。
動機
PyTorch公式のVAEコードでは、MNIST画像の再構築を行うことができる。
画像入力部分を解読して、MNIST画像だけでなく任意の画像を入力・再構築できるようにしたい。
公式のコードは以下の様になっている。
train_loader = torch.utils.data.DataLoader( datasets.MNIST('../data', train=True, download=True, transform=transforms.ToTensor()), batch_size=args.batch_size, shuffle=True, **kwargs)
ここで、torch.utils.data.DataLoader
にはdatasets.MNIST
によりロードされた三次元テンソルが渡されている。
datasets.MNIST
のイメージとしてはこんな感じ。
datasets.MNIST = [[[0., 0., ... , 0.], [0., 0., ... , 0.], ... [0., 0., ... , 0.]], [[0., 0., ... , 0.], [0., 0., ... , 0.], ... [0., 0., ... , 0.]], ... [[0., 0., ... , 0.], [0., 0., ... , 0.], ... [0., 0., ... , 0.]]]
画像枚数×28×28×1のデータ形式。
このdatasets.MNIST
の部分を、任意の画像データテンソルに置き換えることで、目的達成できそう。
torch.utils.data.TensorDataset
同じ要素数のテンソルを二つ渡すと、その組みを作ってくれる人。
生成されるのはTensorDataset
オブジェクト。
画像データが格納されている四次元テンソルと、そのラベルベクトルを用意しておく。
画像データ = [[[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], ... [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], ... [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]], ... [[[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]], ... [[0., 0., 0.], [0., 0., 0.], [0., 0., 0.]]]]
RGB画像を想定(チャンネル数3)。
ラベルベクトル = [1, 2, 7, 9, 2, 12, 5, ... ,9] (※適当です)
この二つを、一度torch.tensor()
に通しておく。(いらないかも?)
画像データ = torch.tensor(画像データ) ラベルベクトル = torch.tensor(ラベルベクトル)
最後に、torch.utils.data.TensorDataset
を使って一つのデータセットに統合する。
データセット = torch.utils.data.TensorDataset(画像データ, ラベルベクトル)
このデータセットは、TensorDataset
オブジェクト。
torch.utils.data.DataLoader
TensorDataset
オブジェクトを渡してバッチサイズを指定すると、iterableなtorch.utils.data.DataLoader
オブジェクトを返してくれる人。
PyTorch公式のVAEクラスは、nn.Module
を継承しているみたいだが、この型のデータであればモデルに渡すことができる。
使い方は簡単。
torch.utils.data.TensorDataset
オブジェクトを第一引数に渡して、バッチサイズを入れるだけ。
data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, **kwargs)
shuffle=True
とすることでデータセットの並びをシャッフルすることができる!
まとめ
model = VAE().to(device) とした後、 for batch_idx, (data, _) in enumerate(data_loader): (中略) recon_batch, mu, logvar = model(data) (中略) みたいになる(ガバガバ)
こんな感じにすることで、VAEに好きな画像データを入れることができる〜