# Source code for deepctr_torch.models.xdeepfm

# -*- coding:utf-8 -*-
"""
Author:
    Wutong Zhang
Reference:
    [1] Lian J, Zhou X, Zhang F, et al. xDeepFM: Combining Explicit and Implicit Feature Interactions for Recommender Systems[J]. arXiv preprint arXiv:1803.05170, 2018.(https://arxiv.org/abs/1803.05170)
"""

import torch
import torch.nn as nn

from .basemodel import BaseModel
from ..inputs import combined_dnn_input
from ..layers import DNN, CIN


class xDeepFM(BaseModel):
    """Instantiates the xDeepFM architecture.

    The prediction is ``self.out`` applied to the sum of up to three logits:
    a linear part (always present), a Compressed Interaction Network (CIN)
    part and a plain deep network (DNN) part. The DNN part is built only when
    both ``dnn_feature_columns`` and ``dnn_hidden_units`` are non-empty; the
    CIN part only when both ``dnn_feature_columns`` and ``cin_layer_size`` are
    non-empty.

    :param linear_feature_columns: An iterable containing all the features used by linear part of the model.
    :param dnn_feature_columns: An iterable containing all the features used by deep part of the model.
    :param dnn_hidden_units: list,list of positive integer or empty list, the layer number and units in each layer of deep net
    :param cin_layer_size: list,list of positive integer or empty list, the feature maps in each hidden layer of Compressed Interaction Network
    :param cin_split_half: bool.if set to True, half of the feature maps in each hidden layer will connect to output unit
    :param cin_activation: activation function used on feature maps
    :param l2_reg_linear: float. L2 regularizer strength applied to linear part
    :param l2_reg_embedding: L2 regularizer strength applied to embedding vector
    :param l2_reg_dnn: L2 regularizer strength applied to deep net
    :param l2_reg_cin: L2 regularizer strength applied to CIN.
    :param init_std: float,to use as the initialize std of embedding vector
    :param seed: integer ,to use as random seed.
    :param dnn_dropout: float in [0,1), the probability we will drop out a given DNN coordinate.
    :param dnn_activation: Activation function to use in DNN
    :param dnn_use_bn: bool. Whether use BatchNormalization before activation or not in DNN
    :param task: str, ``"binary"`` for binary logloss or ``"regression"`` for regression loss
    :param device: str, ``"cpu"`` or ``"cuda:0"``
    :param gpus: list of int or torch.device for multiple gpus. If None, run on `device`. `gpus[0]` should be the same gpu with `device`.
    :return: A PyTorch model instance.
    """

    def __init__(self, linear_feature_columns, dnn_feature_columns, dnn_hidden_units=(256, 256),
                 cin_layer_size=(256, 128,), cin_split_half=True, cin_activation='relu', l2_reg_linear=0.00001,
                 l2_reg_embedding=0.00001, l2_reg_dnn=0, l2_reg_cin=0, init_std=0.0001, seed=1024, dnn_dropout=0,
                 dnn_activation='relu', dnn_use_bn=False, task='binary', device='cpu', gpus=None):

        super(xDeepFM, self).__init__(linear_feature_columns, dnn_feature_columns, l2_reg_linear=l2_reg_linear,
                                      l2_reg_embedding=l2_reg_embedding, init_std=init_std, seed=seed, task=task,
                                      device=device, gpus=gpus)

        # --- Deep (DNN) part -------------------------------------------------
        self.dnn_hidden_units = dnn_hidden_units
        self.use_dnn = len(dnn_feature_columns) > 0 and len(dnn_hidden_units) > 0
        if self.use_dnn:
            self.dnn = DNN(self.compute_input_dim(dnn_feature_columns), dnn_hidden_units,
                           activation=dnn_activation, l2_reg=l2_reg_dnn, dropout_rate=dnn_dropout,
                           use_bn=dnn_use_bn, init_std=init_std, device=device)
            # Projects the last hidden layer down to a single logit.
            self.dnn_linear = nn.Linear(dnn_hidden_units[-1], 1, bias=False).to(device)
            # L2 is applied to DNN weight tensors only; parameters whose name
            # contains 'bn' (presumably BatchNorm layers) are excluded.
            self.add_regularization_weight(
                filter(lambda x: 'weight' in x[0] and 'bn' not in x[0], self.dnn.named_parameters()),
                l2=l2_reg_dnn)
            self.add_regularization_weight(self.dnn_linear.weight, l2=l2_reg_dnn)

        # --- Compressed Interaction Network (CIN) part -------------------------
        self.cin_layer_size = cin_layer_size
        self.use_cin = len(self.cin_layer_size) > 0 and len(dnn_feature_columns) > 0
        if self.use_cin:
            # NOTE(review): assumes one CIN field per entry of BaseModel's
            # embedding_dict — confirm against BaseModel/CIN.
            field_num = len(self.embedding_dict)
            # Width of the CIN output vector. With split-half, every layer but
            # the last routes only half of its feature maps to the output.
            if cin_split_half:
                self.featuremap_num = sum(cin_layer_size[:-1]) // 2 + cin_layer_size[-1]
            else:
                self.featuremap_num = sum(cin_layer_size)
            self.cin = CIN(field_num, cin_layer_size, cin_activation, cin_split_half, l2_reg_cin, seed,
                           device=device)
            self.cin_linear = nn.Linear(self.featuremap_num, 1, bias=False).to(device)
            self.add_regularization_weight(filter(lambda x: 'weight' in x[0], self.cin.named_parameters()),
                                           l2=l2_reg_cin)

        self.to(device)

    def forward(self, X):
        """Compute predictions for a batch.

        :param X: input batch, as consumed by ``input_from_feature_columns``
            and ``linear_model`` (inherited from ``BaseModel``).
        :return: ``self.out`` applied to the summed logits of the enabled parts.
        """
        sparse_embedding_list, dense_value_list = self.input_from_feature_columns(X, self.dnn_feature_columns,
                                                                                  self.embedding_dict)

        # The linear part is always present.
        linear_logit = self.linear_model(X)

        if self.use_cin:
            # NOTE(review): assumes each sparse embedding is (batch, 1, emb_dim),
            # so the concat is (batch, field_num, emb_dim) — TODO confirm.
            cin_input = torch.cat(sparse_embedding_list, dim=1)
            cin_output = self.cin(cin_input)
            cin_logit = self.cin_linear(cin_output)
        if self.use_dnn:
            dnn_input = combined_dnn_input(sparse_embedding_list, dense_value_list)
            dnn_output = self.dnn(dnn_input)
            dnn_logit = self.dnn_linear(dnn_output)

        # Combine using the same flags that decided which sub-networks exist.
        # Branching on len(dnn_hidden_units)/len(cin_layer_size) instead would
        # reference an unbound logit whenever dnn_feature_columns is empty.
        # Summation order (linear + dnn + cin) matches the original.
        final_logit = linear_logit
        if self.use_dnn:
            final_logit = final_logit + dnn_logit
        if self.use_cin:
            final_logit = final_logit + cin_logit

        y_pred = self.out(final_logit)
        return y_pred