目录
- 什么是TextCNN
- 定义TextCNN类
- 初始化一个model实例
- 输出model
什么是TextCNN
- TextCNN(Text Convolutional Neural Network)是一种用于处理文本数据的卷积神经网(CNN)。通过在文本数据上应用卷积操作来提取局部特征,这些特征可以捕捉到文本中的局部模式,如n-gram(连续的n个单词或字符)。
定义TextCNN类
import torch.nn as nn
class TextCNN(nn.Module):def __init__(self, vocab_size, embed_dim, num_classes, num_filters, kernel_sizes):super(TextCNN, self).__init__()self.embedding = nn.Embedding(vocab_size, embed_dim)self.convs = nn.ModuleList([nn.Conv1d(in_channels=embed_dim, out_channels=num_filters, kernel_size=k) for k in kernel_sizes])self.dropout = nn.Dropout(0.5)self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)def forward(self, x):x = self.embedding(x) x = x.transpose(1, 2) convs = [torch.relu(conv(x)) for conv in self.convs] pooled = [torch.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in convs] cat = torch.cat(pooled, 1) return self.fc(self.dropout(cat))
初始化一个model实例
vocab_size = 1000
embed_dim = 128
num_classes = 2
num_filters = 100
kernel_sizes = [3, 4, 5]model = TextCNN(vocab_size, embed_dim, num_classes, num_filters, kernel_sizes)
输出model
TextCNN((embedding): Embedding(8, 128)(convs): ModuleList((0): Conv1d(128, 100, kernel_size=(3,), stride=(1,))(1): Conv1d(128, 100, kernel_size=(4,), stride=(1,))(2): Conv1d(128, 100, kernel_size=(5,), stride=(1,)))(dropout): Dropout(p=0.5, inplace=False)(fc): Linear(in_features=300, out_features=2, bias=True)
)
Embedding(8, 128)
:这是一个嵌入层,它将词汇表中的每个单词映射到一个128
维的向量空间。这里的8
表示词汇表的大小(即输入序列中可能的最大单词索引),128
表示每个单词将被映射到的向量维度。convs: ModuleList[...]
:这是一个包含多个一维卷积层(Conv1d
)的模块列表。每个卷积层都用于提取文本数据的不同局部特征。Conv1d(128, 100, kernel_size=(3,), stride=(1,))
:每个卷积层有128
个输入通道(与嵌入层的输出维度相同)和100
个输出通道(即100
个滤波器)。kernel_size=3
表示每个滤波器的窗口大小为3
个词。stride=1
表示滤波器在文本序列上滑动的步长为1
。Dropout(p=0.5, inplace=False)
:这是一个Dropout
层,它在训练过程中随机丢弃50%
的节点,以减少过拟合。inplace=False
表示Dropout
操作不会在原地修改输入张量。fc: Linear(in_features=300, out_features=2, bias=True)
:这是一个全连接层,它将卷积层和Dropout
层的输出转换为最终的分类结果。in_features=300
表示全连接层的输入特征数量(这是由卷积层的数量和每个卷积层的输出特征数量决定的,即3
个卷积层各100
个特征)。out_features=2
表示输出特征的数量,这通常与分类任务的类别数相对应(在这个例子中,可能是二分类问题)。bias=True
表示全连接层的权重将包含偏置项。