개발/딥러닝

Pytorch RNN(LSTM, GRU) Multi gpu 사용하기

ComEng 2019. 4. 5. 16:02
model = nn.DataParallel(model, dim=1).cuda()

이유는 모르겠지만, dim=1로 해야 잘 된다.
default로 두면 hidden size에서 에러가 난다.