Pytorch容器


1.torch.nn.Parameter()

  torch.nn.Parameter() 是 PyTorch 中的一个类,用于将张量包装成可训练的参数。
  在神经网络中,我们需要定义可训练的参数,例如模型的权重和偏置。torch.nn.Parameter() 允许我们将张量包装成一个特殊的参数对象,该对象会被注册为模型的一部分,并且可以自动进行梯度计算和更新。

import torch
import torch.nn as nn

# 定义一个线性模型类
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # 创建一个可训练的参数
        self.weights = nn.Parameter(torch.randn(3, 2))
        self.bias = nn.Parameter(torch.zeros(3))

    def forward(self, x):
        # 使用参数进行线性变换
        return torch.matmul(x, self.weights) + self.bias

# 创建一个线性模型对象
model = LinearModel()

# 获取模型的参数
parameters = list(model.parameters())

print(parameters)
print("==============")
print(model.parameters)
[Parameter containing:
tensor([[-1.0570, -0.7752],
        [-1.1010,  0.0697],
        [ 1.3795,  0.5130]], requires_grad=True), Parameter containing:
tensor([0., 0., 0.], requires_grad=True)]
==============
<bound method Module.parameters of LinearModel()>

2.torch.nn.Module()

  torch.nn.Module() 是 PyTorch 中的一个基类,用于定义神经网络模型的基本结构。torch.nn.Module 是一个可扩展的类,用于构建神经网络模型。当我们定义自己的神经网络模型时,通常会继承 torch.nn.Module 类,并重写其中的方法,以定义模型的结构和前向传播逻辑。
代码如下(示例):

import torch
import torch.nn as nn

# 定义一个简单的全连接神经网络模型
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        # 定义模型的层
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 10)

    def forward(self, x):
        # 定义模型的前向传播逻辑
        x = self.fc1(x)
        x = torch.relu(x)
        x = self.fc2(x)
        return x

# 创建一个模型实例
model = SimpleNet()

# 打印模型结构
print(model)
SimpleNet(
  (fc1): Linear(in_features=10, out_features=20, bias=True)
  (fc2): Linear(in_features=20, out_features=10, bias=True)
)

3.torch.nn.Sequential()

  torch.nn.Sequential() 是 PyTorch 中的一个类,用于按顺序组合多个层或模块,构建神经网络模型。torch.nn.Sequential 允许我们将多个层或模块按照顺序连接起来,形成一个串联的神经网络模型。我们可以通过传入层或模块的列表来定义网络的结构,或者在创建 Sequential 实例后使用 .add() 方法逐个添加层或模块。

import torch
import torch.nn as nn

# 使用 Sequential 定义一个简单的全连接神经网络模型
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 10)
)

# 打印模型结构
print(model)
Sequential(
  (0): Linear(in_features=10, out_features=20, bias=True)
  (1): ReLU()
  (2): Linear(in_features=20, out_features=10, bias=True)
)

4.torch.nn.ModuleList()

torch.nn.ModuleList() 是 PyTorch 中的一个类,用于将多个模块组合成一个模块列表。torch.nn.ModuleList 提供了一种容器,可以用于存储和管理多个模块,类似于 Python 中的列表。与普通的 Python 列表不同,ModuleList 将被视为模型的一部分,并且在模型中注册为子模块,以便在模型的其他部分中使用。

import torch
import torch.nn as nn

# 定义一个模型类,其中包含多个线性层
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleList([
            nn.Linear(10, 20),
            nn.ReLU(),
            nn.Linear(20, 10)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

# 创建一个模型实例
model = MyModel()

# 打印模型结构
print(model)
MyModel(
  (layers): ModuleList(
    (0): Linear(in_features=10, out_features=20, bias=True)
    (1): ReLU()
    (2): Linear(in_features=20, out_features=10, bias=True)
  )
)

5.torch.nn.ModuleDict()

torch.nn.ModuleDict() 是 PyTorch 中的一个类,用于将多个模块组合成一个模块字典。torch.nn.ModuleDict 提供了一种容器,可以用于存储和管理多个模块,并按照键值对的形式进行访问。类似于 Python 中的字典,ModuleDict 允许通过键来访问存储的模块。

import torch
import torch.nn as nn

# 定义一个模型类,其中包含多个线性层,并以字典形式存储
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.layers = nn.ModuleDict({
            'linear1': nn.Linear(10, 20),
            'relu': nn.ReLU(),
            'linear2': nn.Linear(20, 10)
        })

    def forward(self, x):
        x = self.layers['linear1'](x)
        x = self.layers['relu'](x)
        x = self.layers['linear2'](x)
        return x

# 创建一个模型实例
model = MyModel()

# 打印模型结构
print(model)
MyModel(
  (layers): ModuleDict(
    (linear1): Linear(in_features=10, out_features=20, bias=True)
    (relu): ReLU()
    (linear2): Linear(in_features=20, out_features=10, bias=True)
  )
)

6.torch.nn.ParameterList()

torch.nn.ParameterList() 是 PyTorch 中的一个类,用于将多个参数(torch.nn.Parameter)组合成一个参数列表。torch.nn.ParameterList 提供了一种容器,可以用于存储和管理多个参数,并按照列表的形式进行访问。它通常用于将一组可学习的参数组织在一起,以便在模型中使用或进行优化。

import torch
import torch.nn as nn

# 定义一个模型类,其中包含多个可学习的参数并以列表形式存储
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.parameters = nn.ParameterList([
            nn.Parameter(torch.randn(10, 20)),
            nn.Parameter(torch.randn(20, 10))
        ])

    def forward(self, x):
        # 在前向传播中使用参数
        return torch.matmul(x, self.parameters[0]) + self.parameters[1]

# 创建一个模型实例
model = MyModel()

# 打印模型结构
print(model)
MyModel(
  (parameters): ParameterList(
      (0): Parameter containing: [torch.float32 of size 10x20]
      (1): Parameter containing: [torch.float32 of size 20x10]
  )
)

7.torch.nn.ParameterDict()

torch.nn.ParameterDict() 是 PyTorch 中的一个类,用于将多个参数(torch.nn.Parameter)组合成一个参数字典。

torch.nn.ParameterDict 提供了一种容器,可以用于存储和管理多个参数,并按照键值对的形式进行访问。类似于 Python 中的字典,ParameterDict 允许通过键来访问存储的参数。

import torch
import torch.nn as nn

# 定义一个模型类,其中包含多个可学习的参数,并以字典形式存储
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.parameters = nn.ParameterDict({
            'weight1': nn.Parameter(torch.randn(10, 20)),
            'weight2': nn.Parameter(torch.randn(20, 10))
        })

    def forward(self, x):
        # 在前向传播中使用参数
        return torch.matmul(x, self.parameters['weight1']) + self.parameters['weight2']

# 创建一个模型实例
model = MyModel()

# 打印模型结构
print(model)
MyModel(
  (parameters): ParameterDict(
      (weight1): Parameter containing: [torch.FloatTensor of size 10x20]
      (weight2): Parameter containing: [torch.FloatTensor of size 20x10]
  )
)

相关推荐

  1. Pytorch容器

    2024-06-11 19:24:03       16 阅读
  2. PyTorchPyTorch之包装容器

    2024-06-11 19:24:03       28 阅读
  3. <span style='color:red;'>Pytorch</span>

    Pytorch

    2024-06-11 19:24:03      36 阅读
  4. PyTorch

    2024-06-11 19:24:03       35 阅读
  5. PytorchPytorch入门基础

    2024-06-11 19:24:03       20 阅读
  6. python -- 容器

    2024-06-11 19:24:03       46 阅读
  7. 【Spring】容器

    2024-06-11 19:24:03       45 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-06-11 19:24:03       5 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-06-11 19:24:03       5 阅读
  3. 在Django里面运行非项目文件

    2024-06-11 19:24:03       4 阅读
  4. Python语言-面向对象

    2024-06-11 19:24:03       7 阅读

热门阅读

  1. Unity 数据存储

    2024-06-11 19:24:03       17 阅读
  2. Data Management Controls

    2024-06-11 19:24:03       15 阅读
  3. 【AI应用探讨】— Gemini模型应用场景

    2024-06-11 19:24:03       17 阅读
  4. 设计模式---工厂模式

    2024-06-11 19:24:03       15 阅读
  5. C++经典150题

    2024-06-11 19:24:03       17 阅读
  6. k8s 小技巧: 查看 Pod 上运行的容器

    2024-06-11 19:24:03       16 阅读
  7. Elasticsearch 认证模拟题 - 9

    2024-06-11 19:24:03       16 阅读
  8. 深度解读 ChatGPT基本原理

    2024-06-11 19:24:03       12 阅读
  9. 接口interface

    2024-06-11 19:24:03       17 阅读
  10. 使用redis构建简单的社交网站

    2024-06-11 19:24:03       14 阅读