常见函数
torch.topk()
- 作用
取一个 tensor 的 topk 元素,返回值为降序后的前 k 个大小的元素值及索引 - 使用方法
- dim=0 表示按照列求 topn
- dim=1 表示按照行求 topn
- 默认情况下,dim=1
示例
1
2
3
4
5>>> x = torch.arange(1., 6.)
>>> x
tensor([ 1., 2., 3., 4., 5.])
>>> torch.topk(x, 3)
torch.return_types.topk(values=tensor([5., 4., 3.]), indices=tensor([4, 3, 2]))
torch.unique()
torch.unique()的功能类似于数学中的集合,就是挑出 tensor 中的独立不重复元素。
这个方法的参数在官方解释文档中有这么几个:torch.unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None)
input: 待处理的 tensor
sorted:是否对返回的无重复张量按照数值进行排列,默认是生序排列的
return_inverse: 是否返回原始 tensor 中的每个元素在这个无重复张量中的索引
return_counts: 统计原始张量中每个独立元素的个数
dim: 值沿着哪个维度进行 unique 的处理,这个我试验后没有搞懂怎样的机理。如果处理的张量都是一维的,那么这个不需要理会。
下面分别对这些不同的参数进行实验讲解分析。
1 | |
torch.nonzero
torch.nonzero() 是 PyTorch 中的一个函数,用于获取张量中非零元素的索引。这个函数返回一个二维张量,其中每一行都包含输入张量中一个非零元素的索引。
这个函数的语法如下:
1 | |
其中,input 是输入张量。
让我们来看一个例子,假设我们有一个形状为 (3, 3) 的张量 x:
1 | |
现在,我们可以使用 torch.nonzero() 来获取 x 中非零元素的索引:
1 | |
输出结果如下:
1 | |
可以看到,indices 是一个形状为 (4, 2) 的张量,其中每一行都是 x 中一个非零元素的索引。
需要注意的是,torch.nonzero() 返回的索引是按照行优先顺序排列的,也就是说,它首先返回第一行的非零元素的索引,然后返回第二行的,依此类推。
在 PyTorch 中,scatter() 和 scatter_() 函数通常用于在特定维度上根据索引更新张量(tensor)的值。这两个函数的主要区别在于它们是否原地(in-place)修改输入张量。
torch.where 是 PyTorch 中的一个条件选择函数,常用于根据给定条件在两个张量之间进行元素级选择。
torch.where
基本语法
1 | |
condition:布尔张量,元素值为True的位置选择x,为False的位置选择y。x:当condition为True时使用的值或张量。y:当condition为False时使用的值或张量。
使用示例
选择性替换元素
1 | |
a > 3生成布尔张量[False, False, False, True, True]- 只有
a中大于 3 的元素被保留,其他地方使用b的元素。
应用于多维张量
1 | |
- 负数用
B中对应元素替换,其他保持A。
仅提供 condition(索引操作)
如果只提供 condition,torch.where 会返回满足条件的索引。
1 | |
indices代表行索引和列索引,可以用于索引x[indices],取出满足条件的元素。
1 | |
scatter()
scatter() 函数根据提供的索引将源张量的值分散到目标张量中。它不会修改源张量或目标张量本身(即原地操作)。
函数签名:
1 | |
input(Tensor): 目标张量。dim(int): 沿其分散的维度。index(LongTensor): 索引张量,其形状必须与src的形状在dim维度之外的其他所有维度上都匹配。src(Tensor): 源张量,其形状必须与input在dim维度之外的其他所有维度上都匹配。out(Tensor, optional): 输出张量。
官方示例
三维示例
1 | |
二维示例
1 | |
1 | |
那么这个函数有什么作用呢?其实可以利用这个功能将 pytorch 中 mini batch 中的返回的 label(特指[ 1,0,4,9 ],即 size 为[4]这样的 label)转为 one-hot 类型的 label,举例子如下:
1 | |
上述的这个例子假设是一个分类问题,我设置 out_planes=6,是假设总共有 6 类,mini_batch 是我们送入的网络的每个 mini_batch 的样本数量,这里我们不设置网络,直接假设网络的输出为一个随机的张量 ,通常我们要对这个输出进行 softmax 归一化,此时就代表着其属于每个类别的概率了。说到这里都不是重点,就是为了方便理解如何使用 scatter,将 size 为[mini_batch]的张量,转为 size 为[mini_batch, out_palnes]的张量,并且这个生成的张量的每个行向量都是 one-hot 类型的了。通过看下面的输出结果就完全能够理解了。
1 | |
scatter_()
scatter_() 函数与 scatter() 类似,但它会原地修改目标张量(即它会修改 input 张量本身)。
函数签名:
1 | |
dim(int): 沿其分散的维度。index(LongTensor): 索引张量。src(Tensor): 源张量。



