pytorch之repeate()用法

  • 当repeate参数个数和tensor的形状个数一样时,每个参数分别表示对应维度复制的次数
  • 当参数不一样时,首先在第0维扩展一个维度,维数为1,然后按照参数指定的次数进行复制
1
2
3
4
import torch
a = torch.tensor([[1, 2, 3],
[1, 2, 3]])
a.shape
torch.Size([2, 3])
1
2
3
# repeat参数比维度多,在扩展前先讲a的形状扩展为(1,2,3)然后复制
b = a.repeat(1, 2, 1)
print(b.shape) # 得到结果torch.Size([1, 4, 3])
torch.Size([1, 4, 3])
1
2
c = a.unsqueeze(1)
print(c.shape)
torch.Size([2, 1, 3])
1
2
d = c.repeat(1,2,1)
print(d.shape)
torch.Size([2, 2, 3])

Reference

  1. pytorch中repeat()函数理解