pytorch之repeate()用法
pytorch之repeate()用法
- 当repeate参数个数和tensor的形状个数一样时,每个参数分别表示对应维度复制的次数
- 当参数不一样时,首先在第0维扩展一个维度,维数为1,然后按照参数指定的次数进行复制
1 | import torch |
torch.Size([2, 3])
1 | # repeat参数比维度多,在扩展前先讲a的形状扩展为(1,2,3)然后复制 |
torch.Size([1, 4, 3])
1 | c = a.unsqueeze(1) |
torch.Size([2, 1, 3])
1 | d = c.repeat(1,2,1) |
torch.Size([2, 2, 3])
Reference
All articles in this blog are licensed under CC BY-NC-SA 4.0 unless stating additionally.
Comment