# Source code for dfa_recommender.vat

'''
Virtual adversarial training
'''
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


def _l2_normalize(d):
    '''
    Scale every sample of d to unit L2 norm.

    The norm is taken per sample, i.e. over every axis except the first
    (batch) axis. The input tensor is left untouched: a new tensor is
    returned instead of dividing the caller's storage in place.

    Parameters
    ----------
    d: torch.Tensor
        perturbation in the input space, with at least two dimensions
        (batch axis first)

    Returns
    ----------
    dn: torch.Tensor
        normalized perturbation with the same shape as d. It is detached
        and lives on the CPU; callers move it to their own device.

    Raises
    ----------
    ValueError
        if d has fewer than two dimensions
    '''
    if d.dim() < 2:
        raise ValueError("Dimension is not encoded yet.")
    # detach() so that tensors which require grad are accepted as well.
    d = d.detach().cpu()
    sample_axes = tuple(range(1, d.dim()))
    norm = torch.sqrt(torch.sum(d ** 2, dim=sample_axes, keepdim=True))
    # The small constant keeps all-zero samples at zero instead of NaN.
    dn = d / (norm + 1e-16)
    return dn


def _entropy(logits):
    '''
    Mean entropy of the softmax distributions defined by logits.

    The softmax is taken along dim 1, the entropy -sum(p * log p) is
    computed along that same dim, and the result is averaged over all
    remaining entries.

    Parameters
    ----------
    logits: torch.Tensor
        unnormalized scores with the classes along dim 1

    Returns
    ----------
    ent: torch.Tensor
        scalar tensor holding the averaged entropy
    '''
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    neg_entropy = (probs * log_probs).sum(dim=1)
    return -neg_entropy.mean()


def _entropy_array(logits):
    '''
    Per-sample entropy of the softmax distributions defined by logits.

    Same quantity as the averaged entropy, but without the final mean:
    sum(p * log p) is taken along dim 1 and its absolute value is
    returned as a NumPy array.

    Parameters
    ----------
    logits: torch.Tensor
        unnormalized scores with the classes along dim 1

    Returns
    ----------
    ent: numpy.ndarray
        entropy of each sample, detached from the graph and on the CPU
    '''
    probs = F.softmax(logits, dim=1)
    log_probs = F.log_softmax(logits, dim=1)
    summed = torch.sum(probs * log_probs, dim=1)
    values = summed.detach().cpu().numpy()
    return np.abs(values)


class VAT(object):
    '''
    Implementation of virtual adversarial training.
    See https://arxiv.org/abs/1704.03976 for more details.

    Parameters
    ----------
    device: torch.device or str
        device the KL-divergence module and the perturbation are moved to
    eps: float
        scale of the final adversarial perturbation added to the input
    xi: float
        scale of the probing perturbation used in the power iteration
    alpha: float
        weight multiplied to the KL-divergence term
    k: int, default as 1
        number of power iterations
    use_entmin: bool, default as False
        whether the entropy of the perturbed prediction is added to the
        returned loss (it is added unweighted)
    '''

    def __init__(self, device, eps, xi, alpha, k=1, use_entmin=False):
        self.device = device
        self.xi = xi
        self.eps = eps
        self.alpha = alpha
        self.k = k
        self.kl_div = nn.KLDivLoss(reduction='none').to(device)
        self.use_entmin = use_entmin

    def __call__(self, model, X):
        '''
        Compute the virtual adversarial loss of model at X.

        Parameters
        ----------
        model: callable
            called as ``model(X, update_batch_stats=False)`` and expected
            to return logits with the classes along dim 1
        X: torch.Tensor
            batch of inputs

        Returns
        ----------
        LDS: torch.Tensor
            scalar loss, differentiable with respect to the model
            parameters
        LDS_array: numpy.ndarray
            per-sample version of the same loss (detached); its mean
            matches LDS up to rounding
        '''
        logits = model(X, update_batch_stats=False)
        # The clean prediction is a fixed target: no gradient flows into it.
        prob_logits = F.softmax(logits.detach(), dim=1)
        d = _l2_normalize(torch.randn(X.size())).to(self.device)
        for _ in range(self.k):
            # Fresh leaf tensor so the gradient w.r.t. the perturbed input
            # is well defined even when X itself is part of a graph.
            X_hat = (X + d * self.xi).detach().requires_grad_(True)
            logits_hat = model(X_hat, update_batch_stats=False)
            adv_distance = torch.mean(self.kl_div(
                F.log_softmax(logits_hat, dim=1), prob_logits).sum(dim=1))
            # autograd.grad only differentiates w.r.t. X_hat. Calling
            # backward() here would also accumulate this auxiliary gradient
            # into the .grad of every model parameter.
            grad, = torch.autograd.grad(adv_distance, X_hat)
            d = _l2_normalize(grad).to(self.device)

        logits_hat = model(X + self.eps * d, update_batch_stats=False)
        # Per-sample KL divergence, summed over the class axis. Both the
        # scalar loss and the per-sample array are derived from it so that
        # they stay consistent with each other.
        kl_per_sample = self.kl_div(
            F.log_softmax(logits_hat, dim=1), prob_logits).sum(dim=1)
        LDS = self.alpha * torch.mean(kl_per_sample)
        # abs() only guards against tiny negative values from rounding.
        LDS_array = self.alpha * np.abs(kl_per_sample.detach().cpu().numpy())
        if self.use_entmin:
            LDS = LDS + _entropy(logits_hat)
            LDS_array = LDS_array + _entropy_array(logits_hat)
        return LDS, LDS_array
def df_l2_normalize(d, l_x, cut=True):
    '''
    Normalize d with a zero masking.

    With ``cut=True`` every entry of d whose counterpart in l_x is zero is
    set to zero, and the whole last slice along the final axis is set to
    zero as well. Afterwards each sample (first axis) is scaled to unit L2
    norm over its remaining two axes. The input tensors are not modified.

    Parameters
    ----------
    d: torch.Tensor
        random perturbation in the input space, three-dimensional
    l_x: torch.Tensor
        a tensor based on which the mask is created; assumed to have the
        same shape as d (TODO confirm against callers)
    cut: bool, default as True
        whether applying the mask or not

    Returns
    ----------
    dn: torch.Tensor
        normalized random perturbation in the input space, detached and
        on the CPU

    Raises
    ----------
    ValueError
        if d is not three-dimensional
    '''
    if d.dim() != 3:
        raise ValueError("Dimension is not encoded yet.")
    # Work on a detached CPU view; all operations below create new tensors,
    # so neither d nor l_x is changed and inputs on any device are accepted.
    dn = d.detach().cpu()
    if cut:
        keep = (l_x.detach() != 0).cpu()
        # The last slice along the final axis is always zeroed out.
        keep[:, :, -1] = False
        dn = torch.where(keep, dn, torch.zeros_like(dn))
    norm = torch.sqrt(torch.sum(dn ** 2, dim=(1, 2), keepdim=True))
    # The small constant keeps fully masked samples at zero instead of NaN.
    dn = dn / (norm + 1e-16)
    return dn
class regVAT(object):
    '''
    Implementation of virtual adversarial training in a regression task.

    The only difference compared to VAT is that the distance between the
    original and the perturbed prediction is measured with an element-wise
    L1 loss instead of the KL divergence.
    NOTE(review): the earlier description of this class said MSE, but the
    metric used by the code is ``torch.nn.L1Loss`` - confirm which one is
    intended.

    Parameters
    ----------
    device: torch.device or str
        device the metric module and the perturbation are moved to
    eps: float
        scale of the final adversarial perturbation added to the input
    xi: float
        scale of the probing perturbation used in the power iteration
    alpha: float
        weight multiplied to the returned loss
    k: int, default as 1
        number of power iterations
    cut: bool, default as True
        forwarded to df_l2_normalize to switch the zero masking on or off
    '''

    def __init__(self, device, eps, xi, alpha, k=1, cut=True):
        self.device = device
        self.xi = xi
        self.eps = eps
        self.alpha = alpha
        self.k = k
        self.cut = cut
        self.metric = torch.nn.L1Loss(reduction='none').to(device)

    def __call__(self, model, X, return_adv=False):
        '''
        Compute the virtual adversarial loss (or direction) of model at X.

        The model is put in eval mode for the duration of the call and is
        switched back to train mode afterwards only if it was in train
        mode on entry.

        Parameters
        ----------
        model: torch.nn.Module
            called as ``model(X, update_batch_stats=False)``
        X: torch.Tensor
            batch of inputs
        return_adv: bool, default as False
            if True, return the normalized adversarial direction instead
            of the loss

        Returns
        ----------
        torch.Tensor
            the adversarial direction d when ``return_adv`` is True,
            otherwise the scalar loss LDS
        '''
        # Remember the mode so that a model used for evaluation is not
        # silently flipped into train mode by this call.
        was_training = getattr(model, "training", True)
        model.eval()
        try:
            logits = model(X, update_batch_stats=False)
            # The clean prediction is a fixed target: no gradient flows
            # into it.
            prob_logits = logits.detach()
            d = df_l2_normalize(
                torch.randn(X.size()), X, self.cut).to(self.device)
            for _ in range(self.k):
                # Fresh leaf tensor so the gradient w.r.t. the perturbed
                # input is well defined even when X is part of a graph.
                X_hat = (X + d * self.xi).detach().requires_grad_(True)
                logits_hat = model(X_hat, update_batch_stats=False)
                adv_distance = torch.mean(
                    self.metric(logits_hat, prob_logits))
                # autograd.grad only differentiates w.r.t. X_hat. Calling
                # backward() here would also accumulate this auxiliary
                # gradient into the .grad of every model parameter.
                grad, = torch.autograd.grad(adv_distance, X_hat)
                d = df_l2_normalize(grad, X, self.cut).to(self.device)
            if return_adv:
                return d
            logits_hat = model(X + self.eps * d, update_batch_stats=False)
            LDS = self.alpha * torch.mean(
                self.metric(logits_hat, prob_logits))
            return LDS
        finally:
            if was_training:
                model.train()