BiLSTM模型实现

# 本段代码构建类BiLSTM, 完成初始化和网络结构的搭建
# 总共3层: 词嵌入层, 双向LSTM层, 全连接线性层

# 本段代码构建类BiLSTM, 完成初始化和网络结构的搭建
# 总共3层: 词嵌入层, 双向LSTM层, 全连接线性层
import torch
import torch.nn as nn

# 本函数实现将中文文本映射为数字化张量
def sentence_map(sentence_list, char_to_id, max_length):
    """
    将句子中的每一个字符映射到码表中
    :param sentence_list: 待映射的句子,类型为字符串或列表
    :param char_to_id: 码表,类型为字典,格式为格式为{"字1": 1, "字2": 2},例如:
            码表与id对照:char_to_id = {"双": 0, "肺": 1, "见": 2, "多": 3, "发": 4, "斑": 5, "片": 6,
                                      "状": 7, "稍": 8, "高": 9, "密": 10, "度": 11, "影": 12, "。": 13}
    :param max_length:
    :return: 每一个字对应的编码,类型为tensor
    """
    # 字符串按照逆序进行排序,不是必须操作
    sentence_list.sort(key=lambda c:len(c), reverse = True)
    # 定义句子映射列表
    sentence_map_list = []
    for sentence in sentence_list:
        # 生成句子中每个字对应的id列表
        sentence_id_list =[char_to_id[c] for c in sentence]
        # 计算所要填充0的长度
        padding = [0] * (max_length-len(sentence))
        # 组合
        sentence_map_list.append(sentence_id_list)
    # 返回句子映射集合,转为标量
    return torch.tensor(sentence_map_list, dtype= torch.long)

class BiLSTM(nn.Module):
    """BiLSTM模型定义"""

    def __init__(self, vocab_size, tag_to_id, input_feature_size, hidden_size,
                 batch_size, sentence_length, num_layers=1, batch_first=True):
        """
        description: 模型初始化
        :param vocab_size:          所有句子包含字符大小
        :param tag_to_id:           标签与 id 对照
        :param input_feature_size:  字嵌入维度( 即LSTM输入层维度 input_size )
        :param hidden_size:         隐藏层向量维度
        :param batch_size:          批训练大小
        :param sentence_length      句子长度
        :param num_layers:          堆叠 LSTM 层数
        :param batch_first:         是否将batch_size放置到矩阵的第一维度
        """

        # 类继承初始化函数
        super(BiLSTM, self).__init__()
        # 设置标签与id对照
        self.tag_to_id = tag_to_id
        # 设置标签大小, 对应BiLSTM最终输出分数矩阵宽度
        self.tag_size = len(tag_to_id)
        # 设定LSTM输入特征大小, 对应词嵌入的维度大小
        self.embedding_size = input_feature_size
        # 设置隐藏层维度, 若为双向时想要得到同样大小的向量, 需要除以2
        self.hidden_size = hidden_size // 2
        # 设置批次大小, 对应每个批次的样本条数, 可以理解为输入张量的第一个维度
        self.batch_size = batch_size
        # 设定句子长度
        self.sentence_length = sentence_length
        # 设定是否将batch_size放置到矩阵的第一维度, 取值True, 或False
        self.batch_first = batch_first
        # 设置网络的LSTM层数
        self.num_layers = num_layers
        """
         构建词嵌入层: 字向量, 维度为总单词数量与词嵌入维度
         参数: 总体字库的单词数量, 每个字被嵌入的维度
        """
        self.embedding = nn.Embedding(vocab_size, self.embedding_size)
        self.bilstm = nn.LSTM(input_size=input_feature_size,
                              hidden_size=self.hidden_size,
                              num_layers=num_layers,
                              bidirectional=True,
                              batch_first=batch_first)

        # 构建全连接线性层: 将BiLSTM的输出层进行线性变换
        self.linear = nn.Linear(hidden_size, self.tag_size)


print("=" * 100)
# 参数1:码表与id对照
char_to_id = {"双": 0, "肺": 1, "见": 2, "多": 3, "发": 4, "斑": 5, "片": 6,
              "状": 7, "稍": 8, "高": 9, "密": 10, "度": 11, "影": 12, "。": 13}

# 参数2:标签码表对照
tag_to_id = {"O": 0, "B-dis": 1, "I-dis": 2, "B-sym": 3, "I-sym": 4}
# 参数3:字向量维度
EMBEDDING_DIM = 200
# 参数4:隐层维度
HIDDEN_DIM = 100
# 参数5:批次大小
BATCH_SIZE = 8
# 参数6:句子长度
SENTENCE_LENGTH = 20
# 参数7:堆叠 LSTM 层数
NUM_LAYERS = 1

# 初始化模型
"""
model = BiLSTM(vocab_size=len(char_to_id),
               tag_to_id=tag_to_id,
               input_feature_size=EMBEDDING_DIM,
               hidden_size=HIDDEN_DIM,
               batch_size= BATCH_SIZE,
               sentence_length= SENTENCE_LENGTH,
               num_layers=NUM_LAYERS)

print(model)
"""



最近更新

  1. B树(B-Tree)详解

    2024-07-10 22:48:06       0 阅读
  2. IPython与Pandas:数据分析的动态组

    2024-07-10 22:48:06       0 阅读
  3. SSR和SPA渲染模式

    2024-07-10 22:48:06       0 阅读
  4. 《流程引擎原理与实践》开源电子书

    2024-07-10 22:48:06       0 阅读
  5. 2742. 给墙壁刷油漆

    2024-07-10 22:48:06       0 阅读
  6. longjmp和多线程:读写线程实例

    2024-07-10 22:48:06       0 阅读
  7. 【CF】1216F-WiFi 题解

    2024-07-10 22:48:06       0 阅读
  8. 牛客周赛 Round 52VP(附D的详细证明)

    2024-07-10 22:48:06       0 阅读
  9. Android13 应用代码中修改热点默认密码

    2024-07-10 22:48:06       0 阅读
  10. 【React】事件绑定、React组件、useState、基础样式

    2024-07-10 22:48:06       0 阅读
  11. postman接口测试工具详解

    2024-07-10 22:48:06       0 阅读

热门阅读

  1. Vue2.0和Vue3.0的区别?

    2024-07-10 22:48:06       8 阅读
  2. 网络安全应急处理流程

    2024-07-10 22:48:06       6 阅读
  3. 算法·高精度

    2024-07-10 22:48:06       5 阅读
  4. 闲聊C++与面向对象思想

    2024-07-10 22:48:06       6 阅读
  5. 路由器中 RIB 与 FIB 的区别

    2024-07-10 22:48:06       11 阅读
  6. 生成日志系统和监控

    2024-07-10 22:48:06       9 阅读
  7. Apache Spark详解

    2024-07-10 22:48:06       7 阅读
  8. qt opencv 应用举例

    2024-07-10 22:48:06       7 阅读
  9. Pytorch中分类回归常用的损失和优化器

    2024-07-10 22:48:06       6 阅读
  10. 【Rust】Cargo介绍

    2024-07-10 22:48:06       6 阅读
  11. 搭建Spring Cloud项目思路

    2024-07-10 22:48:06       6 阅读
  12. C语言从头学32——字符串数组

    2024-07-10 22:48:06       5 阅读
  13. 7. 有奖猜谜

    2024-07-10 22:48:06       5 阅读