개발/딥러닝
Pytorch RNN(LSTM, GRU) Multi gpu 사용하기
ComEng
2019. 4. 5. 16:02
model = nn.DataParallel(model, dim=1).cuda()
이유는 모르겠지만, dim=1로 해야 잘 된다.
default로 두면 hidden size에서 에러가 난다.