Source code for deepctr_torch.layers.sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import PackedSequence

from ..layers.core import LocalActivationUnit


class SequencePoolingLayer(nn.Module):
    """The SequencePoolingLayer is used to apply pooling operation(sum,mean,max) on variable-length sequence feature/multi-value feature.

      Input shape
        - A list of two tensor [seq_value,seq_len]

        - seq_value is a 3D tensor with shape: ``(batch_size, T, embedding_size)``

        - seq_len is a 2D tensor with shape : ``(batch_size, 1)``,indicate valid length of each sequence.
          When ``supports_masking=True`` the second element is instead a ``(batch_size, T)`` mask
          (nonzero/True at valid positions).

      Output shape
        - 3D tensor with shape: ``(batch_size, 1, embedding_size)``.

      Arguments
        - **mode**:str.Pooling operation to be used,can be sum,mean or max.

        - **supports_masking**: bool. If True, the second input is a per-position mask instead of lengths.

        - **device**: str. Device the layer and its epsilon constant are moved to.
    """

    def __init__(self, mode='mean', supports_masking=False, device='cpu'):
        super(SequencePoolingLayer, self).__init__()
        if mode not in ['sum', 'mean', 'max']:
            raise ValueError('parameter mode should in [sum, mean, max]')
        self.supports_masking = supports_masking
        self.mode = mode
        self.device = device
        # Added to the length in 'mean' mode so an empty sequence does not divide by zero.
        self.eps = torch.FloatTensor([1e-8]).to(device)
        self.to(device)

    def _sequence_mask(self, lengths, maxlen=None, dtype=torch.bool):
        """Return a mask marking the first ``lengths`` positions of each row.

        For ``lengths`` of shape ``(B, 1)`` the result has shape ``(B, 1, maxlen)``
        and dtype ``dtype``.
        """
        if maxlen is None:
            maxlen = lengths.max()
        row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
        matrix = torch.unsqueeze(lengths, dim=-1)
        mask = row_vector < matrix
        # Tensor.type() returns a new tensor; it must be assigned. Otherwise a bool
        # mask leaks out and `1 - mask` in 'max' mode raises a RuntimeError.
        mask = mask.type(dtype)
        return mask

    def forward(self, seq_value_len_list):
        """Pool ``seq_value`` over its valid positions; returns ``(batch_size, 1, embedding_size)``."""
        if self.supports_masking:
            uiseq_embed_list, mask = seq_value_len_list  # [B, T, E], [B, T]
            mask = mask.float()
            user_behavior_length = torch.sum(mask, dim=-1, keepdim=True)  # [B, 1]
            mask = mask.unsqueeze(2)  # [B, T, 1]
        else:
            uiseq_embed_list, user_behavior_length = seq_value_len_list  # [B, T, E], [B, 1]
            mask = self._sequence_mask(user_behavior_length, maxlen=uiseq_embed_list.shape[1],
                                       dtype=torch.float32)  # [B, 1, T]
            mask = torch.transpose(mask, 1, 2)  # [B, T, 1]

        embedding_size = uiseq_embed_list.shape[-1]
        mask = torch.repeat_interleave(mask, embedding_size, dim=2)  # [B, T, E]

        if self.mode == 'max':
            # Shift padded positions far below any real value so they never win the max.
            hist = uiseq_embed_list - (1 - mask) * 1e9
            hist = torch.max(hist, dim=1, keepdim=True)[0]
            return hist

        hist = uiseq_embed_list * mask.float()
        hist = torch.sum(hist, dim=1, keepdim=False)  # [B, E]

        if self.mode == 'mean':
            self.eps = self.eps.to(user_behavior_length.device)
            hist = torch.div(hist, user_behavior_length.type(torch.float32) + self.eps)

        hist = torch.unsqueeze(hist, dim=1)  # [B, 1, E]
        return hist
class AttentionSequencePoolingLayer(nn.Module):
    """The Attentional sequence pooling operation used in DIN & DIEN.

        Arguments
          - **att_hidden_units**:list of positive integer, the attention net layer number and units in each layer.

          - **att_activation**: Activation function to use in attention net.

          - **weight_normalization**: bool.Whether normalize the attention score of local activation unit.

          - **return_score**: bool. If True, ``forward`` returns the masked (and, with
            ``weight_normalization``, softmax-normalized) attention scores of shape
            ``(batch_size, 1, T)`` instead of the weighted sum of ``keys``.

          - **supports_masking**:If True,the input need to support masking.

          - **embedding_dim**: int. Embedding size passed to the local activation unit.

        References
          - [Zhou G, Zhu X, Song C, et al. Deep interest network for click-through rate prediction[C]//Proceedings of the 24th ACM SIGKDD International Conference on Knowledge Discovery & Data Mining. ACM, 2018: 1059-1068.](https://arxiv.org/pdf/1706.06978.pdf)
    """

    def __init__(self, att_hidden_units=(80, 40), att_activation='sigmoid', weight_normalization=False,
                 return_score=False, supports_masking=False, embedding_dim=4, **kwargs):
        super(AttentionSequencePoolingLayer, self).__init__()
        self.return_score = return_score
        self.weight_normalization = weight_normalization
        self.supports_masking = supports_masking
        self.local_att = LocalActivationUnit(hidden_units=att_hidden_units, embedding_dim=embedding_dim,
                                             activation=att_activation,
                                             dropout_rate=0, use_bn=False)

    def forward(self, query, keys, keys_length, mask=None):
        """
        Input shape
          - A list of three tensor: [query,keys,keys_length]

          - query is a 3D tensor with shape:  ``(batch_size, 1, embedding_size)``

          - keys is a 3D tensor with shape:   ``(batch_size, T, embedding_size)``

          - keys_length is a 2D tensor with shape: ``(batch_size, 1)``

          - mask (used only when supports_masking=True): expected ``(batch_size, T)``,
            nonzero/True at valid positions; any numeric or bool dtype is accepted.

        Output shape
          - 3D tensor with shape: ``(batch_size, 1, embedding_size)``, or the
            ``(batch_size, 1, T)`` scores when ``return_score=True``.

        Raises
          - ValueError if ``supports_masking=True`` and ``mask`` is None.
        """
        batch_size, max_length, _ = keys.size()

        # Mask
        if self.supports_masking:
            if mask is None:
                raise ValueError("When supports_masking=True,input must support masking")
            # torch.where requires a bool condition, so coerce 0/1 float or int masks.
            keys_masks = mask.unsqueeze(1).bool()  # [B, 1, T]
        else:
            keys_masks = torch.arange(max_length, device=keys_length.device,
                                      dtype=keys_length.dtype).repeat(batch_size, 1)  # [B, T]
            keys_masks = keys_masks < keys_length.view(-1, 1)  # 0, 1 mask
            keys_masks = keys_masks.unsqueeze(1)  # [B, 1, T]

        attention_score = self.local_att(query, keys)  # [B, T, 1]

        outputs = torch.transpose(attention_score, 1, 2)  # [B, 1, T]

        if self.weight_normalization:
            # Large negative padding gives padded positions ~0 weight after softmax.
            paddings = torch.ones_like(outputs) * (-2 ** 32 + 1)
        else:
            paddings = torch.zeros_like(outputs)

        outputs = torch.where(keys_masks, outputs, paddings)  # [B, 1, T]

        if self.weight_normalization:
            outputs = F.softmax(outputs, dim=-1)  # [B, 1, T]

        if not self.return_score:
            # Weighted sum
            outputs = torch.matmul(outputs, keys)  # [B, 1, E]

        return outputs
class KMaxPooling(nn.Module):
    """Keep only the ``k`` largest entries along one axis of the input.

      Input shape
        - nD tensor with shape: ``(batch_size, ..., input_dim)``.

      Output shape
        - nD tensor with shape: ``(batch_size, ..., output_dim)``, where the
          selected axis now has size ``k`` and values are in descending order.

      Arguments
        - **k**: positive integer, number of top elements to look for along the ``axis`` dimension.

        - **axis**: positive integer, the dimension to look for elements.
    """

    def __init__(self, k, axis, device='cpu'):
        super(KMaxPooling, self).__init__()
        self.k = k
        self.axis = axis
        self.to(device)

    def forward(self, inputs):
        rank = len(inputs.shape)
        if not 0 <= self.axis < rank:
            raise ValueError("axis must be 0~%d,now is %d" % (rank - 1, self.axis))

        axis_size = inputs.shape[self.axis]
        if not 1 <= self.k <= axis_size:
            raise ValueError("k must be in 1 ~ %d,now k is %d" % (axis_size, self.k))

        # topk returns (values, indices); only the descending-sorted values are kept.
        top_values, _ = torch.topk(inputs, k=self.k, dim=self.axis, sorted=True)
        return top_values
class AGRUCell(nn.Module):
    """ Attention based GRU (AGRU)

        The attention score takes the place of the GRU update gate:
        ``h_t = (1 - a_t) * h_{t-1} + a_t * h~_t``.

        Reference:
        -  Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(AGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Assigning an nn.Parameter attribute registers it; no register_parameter needed.
        # (W_ir|W_iz|W_ih)
        self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size))
        # (W_hr|W_hz|W_hh)
        self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size))
        if bias:
            # (b_ir|b_iz|b_ih)
            self.bias_ih = nn.Parameter(torch.zeros(3 * hidden_size))
            # (b_hr|b_hz|b_hh)
            self.bias_hh = nn.Parameter(torch.zeros(3 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        # The weight storage is uninitialized memory, so give it a defined start
        # (same uniform range nn.GRUCell uses). Biases stay zero.
        stdv = 1.0 / (hidden_size ** 0.5) if hidden_size > 0 else 0
        for weight in (self.weight_ih, self.weight_hh):
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs, hx, att_score):
        """Advance the hidden state by one step.

        inputs: ``(batch, input_size)``; hx: ``(batch, hidden_size)``;
        att_score: one score per row (reshaped to ``(batch, 1)``).
        Returns the new hidden state ``(batch, hidden_size)``.
        """
        gi = F.linear(inputs, self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        # The update-gate chunks are unused: AGRU uses att_score in their place.
        i_r, _, i_n = gi.chunk(3, 1)
        h_r, _, h_n = gh.chunk(3, 1)

        reset_gate = torch.sigmoid(i_r + h_r)
        new_state = torch.tanh(i_n + reset_gate * h_n)

        att_score = att_score.view(-1, 1)
        hy = (1. - att_score) * hx + att_score * new_state
        return hy
class AUGRUCell(nn.Module):
    """ Effect of GRU with attentional update gate (AUGRU)

        The GRU update gate is scaled by the attention score:
        ``u'_t = a_t * u_t``, ``h_t = (1 - u'_t) * h_{t-1} + u'_t * h~_t``.

        Reference:
        -  Deep Interest Evolution Network for Click-Through Rate Prediction[J]. arXiv preprint arXiv:1809.03672, 2018.
    """

    def __init__(self, input_size, hidden_size, bias=True):
        super(AUGRUCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bias = bias
        # Assigning an nn.Parameter attribute registers it under that same name.
        # (The previous explicit register_parameter('bias_ih', self.bias_hh) aliased
        # bias_ih to bias_hh and orphaned the real input bias.)
        # (W_ir|W_iz|W_ih)
        self.weight_ih = nn.Parameter(torch.empty(3 * hidden_size, input_size))
        # (W_hr|W_hz|W_hh)
        self.weight_hh = nn.Parameter(torch.empty(3 * hidden_size, hidden_size))
        if bias:
            # (b_ir|b_iz|b_ih)
            self.bias_ih = nn.Parameter(torch.zeros(3 * hidden_size))
            # (b_hr|b_hz|b_hh)
            self.bias_hh = nn.Parameter(torch.zeros(3 * hidden_size))
        else:
            self.register_parameter('bias_ih', None)
            self.register_parameter('bias_hh', None)
        # The weight storage is uninitialized memory, so give it a defined start
        # (same uniform range nn.GRUCell uses). Biases stay zero.
        stdv = 1.0 / (hidden_size ** 0.5) if hidden_size > 0 else 0
        for weight in (self.weight_ih, self.weight_hh):
            nn.init.uniform_(weight, -stdv, stdv)

    def forward(self, inputs, hx, att_score):
        """Advance the hidden state by one step.

        inputs: ``(batch, input_size)``; hx: ``(batch, hidden_size)``;
        att_score: one score per row (reshaped to ``(batch, 1)``).
        Returns the new hidden state ``(batch, hidden_size)``.
        """
        gi = F.linear(inputs, self.weight_ih, self.bias_ih)
        gh = F.linear(hx, self.weight_hh, self.bias_hh)
        i_r, i_z, i_n = gi.chunk(3, 1)
        h_r, h_z, h_n = gh.chunk(3, 1)

        reset_gate = torch.sigmoid(i_r + h_r)
        update_gate = torch.sigmoid(i_z + h_z)
        new_state = torch.tanh(i_n + reset_gate * h_n)

        att_score = att_score.view(-1, 1)
        update_gate = att_score * update_gate
        hy = (1. - update_gate) * hx + update_gate * new_state
        return hy
class DynamicGRU(nn.Module):
    """Run an AGRU or AUGRU cell step by step over a packed sequence.

      Arguments
        - **input_size**: int. Feature size of each input step.

        - **hidden_size**: int. Size of the hidden state.

        - **bias**: bool. Whether the cell uses bias terms.

        - **gru_type**: str. ``'AGRU'`` or ``'AUGRU'``.

      Raises
        - ValueError if ``gru_type`` is not one of the supported values.
    """

    def __init__(self, input_size, hidden_size, bias=True, gru_type='AGRU'):
        super(DynamicGRU, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        if gru_type == 'AGRU':
            self.rnn = AGRUCell(input_size, hidden_size, bias)
        elif gru_type == 'AUGRU':
            self.rnn = AUGRUCell(input_size, hidden_size, bias)
        else:
            # Fail fast: otherwise self.rnn is never set and forward() dies with an
            # unrelated AttributeError.
            raise ValueError("gru_type must be 'AGRU' or 'AUGRU', got %r" % (gru_type,))

    def forward(self, inputs, att_scores=None, hx=None):
        """Apply the cell across every time step of ``inputs``.

        Both ``inputs`` and ``att_scores`` must be ``PackedSequence`` objects with the
        same packing. Returns a ``PackedSequence`` of hidden states that reuses the
        input's ``batch_sizes``/``sorted_indices``/``unsorted_indices``.

        Raises NotImplementedError if either argument is not a ``PackedSequence``.
        """
        if not isinstance(inputs, PackedSequence) or not isinstance(att_scores, PackedSequence):
            raise NotImplementedError("DynamicGRU only supports packed input and att_scores")

        inputs, batch_sizes, sorted_indices, unsorted_indices = inputs
        att_scores, _, _, _ = att_scores

        max_batch_size = int(batch_sizes[0])
        # NOTE(review): a caller-supplied hx is used in packed (sorted) row order as-is;
        # it is not permuted by sorted_indices. Confirm callers account for that.
        if hx is None:
            hx = torch.zeros(max_batch_size, self.hidden_size,
                             dtype=inputs.dtype, device=inputs.device)

        outputs = torch.zeros(inputs.size(0), self.hidden_size,
                              dtype=inputs.dtype, device=inputs.device)

        # Packed data is stored time step by time step. Step t holds batch_sizes[t]
        # rows, and that count never grows, so the first `batch` rows of hx are
        # exactly the sequences still active at this step.
        begin = 0
        for batch in batch_sizes.tolist():
            end = begin + batch
            new_hx = self.rnn(inputs[begin:end], hx[0:batch], att_scores[begin:end])
            outputs[begin:end] = new_hx
            hx = new_hx
            begin = end

        return PackedSequence(outputs, batch_sizes, sorted_indices, unsorted_indices)