在 PyTorch 中,repeat
方法用于沿指定的维度重复张量。这对于需要扩展张量以匹配特定形状或进行广播操作时非常有用。
repeat
方法的用法
tensor.repeat(*sizes)
方法接受一个或多个整数参数,表示每个维度的重复次数。其返回一个新的张量,其中原始张量的每个维度都根据提供的重复次数进行扩展。
示例
以下是一些示例,展示如何使用 repeat
方法:
1. 基本用法
import torch
# 创建一个张量
tensor = torch.tensor([1, 2, 3])
print(f"原始张量:\n{tensor}")
# 重复张量,沿第一个维度重复 2 次,沿第二个维度重复 3 次
repeated_tensor = tensor.repeat(2, 3)
print(f"重复后的张量:\n{repeated_tensor}")
输出:
原始张量:
tensor([1, 2, 3])
重复后的张量:
tensor([[1, 2, 3, 1, 2, 3, 1, 2, 3],
[1, 2, 3, 1, 2, 3, 1, 2, 3]])
在这个示例中,原始张量 [1, 2, 3]
沿第一个维度(行)重复 2 次,沿第二个维度(列)重复 3 次。
2. 重复多维张量
import torch
# 创建一个 2x2 的张量
tensor = torch.tensor([[1, 2], [3, 4]])
print(f"原始张量:\n{tensor}")
# 沿每个维度重复张量
repeated_tensor = tensor.repeat(2, 3)
print(f"重复后的张量:\n{repeated_tensor}")
输出:
原始张量:
tensor([[1, 2],
[3, 4]])
重复后的张量:
tensor([[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4],
[1, 2, 1, 2, 1, 2],
[3, 4, 3, 4, 3, 4]])
在这个示例中,原始 2x2 张量沿第一个维度重复 2 次,沿第二个维度重复 3 次,生成一个 4x6 的张量。
3. 只重复一个维度
你也可以只沿一个维度重复张量,方法是将另一个维度的重复次数设为 1:
import torch
# 创建一个 2x2 的张量
tensor = torch.tensor([[1, 2], [3, 4]])
print(f"原始张量:\n{tensor}")
# 沿第一个维度重复 2 次,第二个维度不变
repeated_tensor_1 = tensor.repeat(2, 1)
print(f"沿第一个维度重复后的张量:\n{repeated_tensor_1}")
# 沿第二个维度重复 2 次,第一个维度不变
repeated_tensor_2 = tensor.repeat(1, 2)
print(f"沿第二个维度重复后的张量:\n{repeated_tensor_2}")
输出:
原始张量:
tensor([[1, 2],
[3, 4]])
沿第一个维度重复后的张量:
tensor([[1, 2],
[3, 4],
[1, 2],
[3, 4]])
沿第二个维度重复后的张量:
tensor([[1, 2, 1, 2],
[3, 4, 3, 4]])
在这些示例中,原始 2x2 张量沿一个维度重复 2 次,另一个维度保持不变。
总结
repeat
方法用于沿指定维度重复张量。- 它接受一个或多个整数参数,表示每个维度的重复次数。
- 通过重复操作,可以扩展张量以匹配特定形状或进行广播操作。
这些示例展示了如何使用 repeat
方法来重复和扩展张量,以满足不同的需求。