Skip to content

Losses

pytorch-widedeep accepts a number of losses and objectives that can be passed to the Trainer class via the parameter objective (see pytorch-widedeep.training.Trainer). For most cases the loss function that pytorch-widedeep will use internally is already implemented in Pytorch.

In addition, pytorch-widedeep implements a series of "custom" loss functions. These are described below for completeness since, as mentioned before, they are used internally by the Trainer. Of course, one could always use them on their own; they can be imported as:

from pytorch_widedeep.losses import FocalLoss


ℹ️ NOTE: Losses in this module expect the predictions and ground truth to have the same dimensions for regression and binary classification problems \((N_{samples}, 1)\). In the case of multiclass classification problems the ground truth is expected to be a 1D tensor with the corresponding classes. See Examples below


MSELoss

Bases: Module

Mean square error loss

Source code in pytorch_widedeep/losses.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class MSELoss(nn.Module):
    r"""Mean square error loss"""

    # Legacy note: this class used to accept FDS-LDS related params. At this
    # point `torch.nn.MSELoss` is most likely what you want instead.
    def __init__(self):
        super().__init__()

    def forward(
        self,
        input: Tensor,
        target: Tensor,
    ) -> Tensor:
        r"""Return the mean of the squared differences between predictions
        and targets.

        Parameters
        ----------
        input: Tensor
            Input tensor with predictions
        target: Tensor
            Target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import MSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = MSELoss()(input, target)
        """
        squared_errors = (input - target) ** 2
        return squared_errors.mean()

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import MSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = MSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def forward(
    self,
    input: Tensor,
    target: Tensor,
) -> Tensor:
    r"""Compute the mean square error between `input` and `target`.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import MSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = MSELoss()(input, target)
    """
    # element-wise squared error, reduced with a mean over all elements
    loss = (input - target) ** 2
    return torch.mean(loss)

MSLELoss

Bases: Module

Mean square log error loss

Source code in pytorch_widedeep/losses.py
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class MSLELoss(nn.Module):
    r"""Mean square log error loss"""

    def __init__(self):
        super().__init__()

    def forward(
        self,
        input: Tensor,
        target: Tensor,
    ) -> Tensor:
        r"""Compute the mean squared difference of `log(1 + x)` between
        predictions and targets. Both must be non-negative.

        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (must be non-negative)
        target: Tensor
            Target tensor with the actual values (must be non-negative)

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import MSLELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = MSLELoss()(input, target)
        """
        # NOTE: `assert`-based validation is stripped when running python -O
        assert (
            input.min() >= 0
        ), """All input values must be >=0, if your model is predicting
            values <0 try to enforce positive values by activation function
            on last layer with `trainer.enforce_positive_output=True`"""
        assert target.min() >= 0, "All target values must be >=0"

        # squared difference of log(1 + x); x >= 0 checked above
        loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
        return torch.mean(loss)

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual classes

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import MSLELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = MSLELoss()(input, target)
Source code in pytorch_widedeep/losses.py
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def forward(
    self,
    input: Tensor,
    target: Tensor,
) -> Tensor:
    r"""Compute the mean square log error between `input` and `target`.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (must be non-negative)
    target: Tensor
        Target tensor with the actual values (must be non-negative)

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import MSLELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = MSLELoss()(input, target)
    """
    # NOTE: `assert`-based validation is stripped when running python -O
    assert (
        input.min() >= 0
    ), """All input values must be >=0, if your model is predicting
        values <0 try to enforce positive values by activation function
        on last layer with `trainer.enforce_positive_output=True`"""
    assert target.min() >= 0, "All target values must be >=0"

    # squared difference of log(1 + x); x >= 0 checked above
    loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
    return torch.mean(loss)

RMSELoss

Bases: Module

Root mean square error loss

Source code in pytorch_widedeep/losses.py
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
class RMSELoss(nn.Module):
    r"""Root mean square error loss"""

    # Legacy note: this class used to accept FDS-LDS related params. At this
    # point `torch.sqrt(nn.MSELoss(...))` is most likely what you want instead.
    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""Return the square root of the mean squared error between
        predictions and targets.

        Parameters
        ----------
        input: Tensor
            Input tensor with predictions
        target: Tensor
            Target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import RMSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = RMSELoss()(input, target)
        """
        mean_sq_error = ((input - target) ** 2).mean()
        return mean_sq_error.sqrt()

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual classes

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import RMSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = RMSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""Compute the root mean square error between `input` and `target`.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import RMSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = RMSELoss()(input, target)
    """
    # mean of squared errors followed by a square root -> RMSE
    loss = (input - target) ** 2
    return torch.sqrt(torch.mean(loss))

RMSLELoss

Bases: Module

Root mean square log error loss

Source code in pytorch_widedeep/losses.py
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
class RMSLELoss(nn.Module):
    r"""Root mean square log error loss"""

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""Return the square root of the mean squared difference of
        `log(1 + x)` between predictions and targets. Both must be
        non-negative.

        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (must be non-negative)
        target: Tensor
            Target tensor with the actual values (must be non-negative)

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import RMSLELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = RMSLELoss()(input, target)
        """
        # NOTE: `assert`-based validation is stripped when running python -O
        assert (
            input.min() >= 0
        ), """All input values must be >=0, if your model is predicting
            values <0 try to enforce positive values by activation function
            on last layer with `trainer.enforce_positive_output=True`"""
        assert target.min() >= 0, "All target values must be >=0"

        log_diff = torch.log(input + 1) - torch.log(target + 1)
        return (log_diff**2).mean().sqrt()

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual classes

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import RMSLELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = RMSLELoss()(input, target)
Source code in pytorch_widedeep/losses.py
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""Compute the root mean square log error between `input` and `target`.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (must be non-negative)
    target: Tensor
        Target tensor with the actual values (must be non-negative)

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import RMSLELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = RMSLELoss()(input, target)
    """
    # NOTE: `assert`-based validation is stripped when running python -O
    assert (
        input.min() >= 0
    ), """All input values must be >=0, if your model is predicting
        values <0 try to enforce positive values by activation function
        on last layer with `trainer.enforce_positive_output=True`"""
    assert target.min() >= 0, "All target values must be >=0"

    # squared difference of log(1 + x), averaged, then square-rooted
    loss = (torch.log(input + 1) - torch.log(target + 1)) ** 2
    return torch.sqrt(torch.mean(loss))

QuantileLoss

Bases: Module

Quantile loss defined as:

\[ Loss = max(q \times (y-y_{pred}), (1-q) \times (y_{pred}-y)) \]

All credits go to the implementation at pytorch-forecasting.

Parameters:

Name Type Description Default
quantiles List[float]

List of quantiles

[0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
Source code in pytorch_widedeep/losses.py
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
class QuantileLoss(nn.Module):
    r"""Quantile loss defined as:

    $$
    Loss = max(q \times (y-y_{pred}), (1-q) \times (y_{pred}-y))
    $$

    All credits go to the implementation at
    [pytorch-forecasting](https://pytorch-forecasting.readthedocs.io/en/latest/_modules/pytorch_forecasting/metrics.html#QuantileLoss).

    Parameters
    ----------
    quantiles: List, default = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
        List of quantiles
    """

    # NOTE: the mutable default list is safe here because it is only read,
    # never mutated
    def __init__(
        self,
        quantiles: List[float] = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98],
    ):
        super().__init__()
        self.quantiles = quantiles

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions; one column per quantile
        target: Tensor
            Target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import QuantileLoss
        >>>
        >>> # REGRESSION
        >>> target = torch.tensor([[0.6, 1.5]]).view(-1, 1)
        >>> input = torch.tensor([[.1, .2,], [.4, .5]])
        >>> qloss = QuantileLoss([0.25, 0.75])
        >>> loss = qloss(input, target)
        """

        assert input.shape == torch.Size([target.shape[0], len(self.quantiles)]), (
            "The input and target have inconsistent shape. The dimension of the prediction "
            "of the model that is using QuantileLoss must be equal to number of quantiles, "
            f"i.e. {len(self.quantiles)}."
        )
        target = target.view(-1, 1).float()
        losses = []
        # NOTE(review): `target` (N, 1) minus `input[..., i]` (N,) broadcasts
        # to an (N, N) matrix; this mirrors the upstream pytorch-forecasting
        # implementation -- confirm the reduction is intended
        for i, q in enumerate(self.quantiles):
            errors = target - input[..., i]
            # pinball loss: q * error for under-prediction, (1 - q) * |error|
            # for over-prediction
            losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))

        loss = torch.cat(losses, dim=2)

        return torch.mean(loss)

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import QuantileLoss
>>>
>>> # REGRESSION
>>> target = torch.tensor([[0.6, 1.5]]).view(-1, 1)
>>> input = torch.tensor([[.1, .2,], [.4, .5]])
>>> qloss = QuantileLoss([0.25, 0.75])
>>> loss = qloss(input, target)
Source code in pytorch_widedeep/losses.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions; one column per quantile
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import QuantileLoss
    >>>
    >>> # REGRESSION
    >>> target = torch.tensor([[0.6, 1.5]]).view(-1, 1)
    >>> input = torch.tensor([[.1, .2,], [.4, .5]])
    >>> qloss = QuantileLoss([0.25, 0.75])
    >>> loss = qloss(input, target)
    """

    assert input.shape == torch.Size([target.shape[0], len(self.quantiles)]), (
        "The input and target have inconsistent shape. The dimension of the prediction "
        "of the model that is using QuantileLoss must be equal to number of quantiles, "
        f"i.e. {len(self.quantiles)}."
    )
    target = target.view(-1, 1).float()
    losses = []
    # NOTE(review): `target` (N, 1) minus `input[..., i]` (N,) broadcasts
    # to an (N, N) matrix; this mirrors the upstream pytorch-forecasting
    # implementation -- confirm the reduction is intended
    for i, q in enumerate(self.quantiles):
        errors = target - input[..., i]
        # pinball loss: q * error for under-prediction, (1 - q) * |error|
        # for over-prediction
        losses.append(torch.max((q - 1) * errors, q * errors).unsqueeze(-1))

    loss = torch.cat(losses, dim=2)

    return torch.mean(loss)

FocalLoss

Bases: Module

Implementation of the Focal loss for both binary and multiclass classification:

\[ FL(p_t) = \alpha (1 - p_t)^{\gamma} log(p_t) \]

where, for a case of a binary classification problem

\[ \begin{equation} p_t= \begin{cases}p, & \text{if $y=1$}.\\1-p, & \text{otherwise}. \end{cases} \end{equation} \]

Parameters:

Name Type Description Default
alpha float

Focal Loss alpha parameter

0.25
gamma float

Focal Loss gamma parameter

1.0
Source code in pytorch_widedeep/losses.py
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
class FocalLoss(nn.Module):
    r"""Implementation of the [Focal loss](https://arxiv.org/pdf/1708.02002.pdf)
    for both binary and multiclass classification:

    $$
    FL(p_t) = \alpha (1 - p_t)^{\gamma} log(p_t)
    $$

    where, for a case of a binary classification problem

    $$
    \begin{equation} p_t= \begin{cases}p, & \text{if $y=1$}.\\1-p, & \text{otherwise}. \end{cases} \end{equation}
    $$

    Parameters
    ----------
    alpha: float
        Focal Loss `alpha` parameter
    gamma: float
        Focal Loss `gamma` parameter
    """

    def __init__(self, alpha: float = 0.25, gamma: float = 1.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def _get_weight(self, p: Tensor, t: Tensor) -> Tensor:
        # pt = p where t == 1, (1 - p) where t == 0
        pt = p * t + (1 - p) * (1 - t)  # type: ignore
        # alpha-balancing factor: alpha for positives, (1 - alpha) otherwise
        w = self.alpha * t + (1 - self.alpha) * (1 - t)  # type: ignore
        # modulating factor (1 - pt)^gamma; detached so no gradient flows
        # through the weight itself
        return (w * (1 - pt).pow(self.gamma)).detach()  # type: ignore

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (not probabilities)
        target: Tensor
            Target tensor with the actual classes

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import FocalLoss
        >>>
        >>> # BINARY
        >>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
        >>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
        >>> loss = FocalLoss()(input, target)
        >>>
        >>> # MULTICLASS
        >>> target = torch.tensor([1, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
        >>> loss = FocalLoss()(input, target)
        """
        # element-wise sigmoid of the logits; NOTE(review): sigmoid (not
        # softmax) is also used per-class in the multiclass branch -- confirm
        # this multi-label-style formulation is intended
        input_prob = torch.sigmoid(input)
        if input.size(1) == 1:
            # binary case: build the two columns [P(y=0), P(y=1)]
            input_prob = torch.cat([1 - input_prob, input_prob], axis=1)  # type: ignore
            num_class = 2
        else:
            num_class = input_prob.size(1)
        # one-hot encode the class labels by indexing the identity matrix
        binary_target = torch.eye(num_class)[target.squeeze().cpu().long()]
        binary_target = binary_target.to(input.device)
        binary_target = binary_target.contiguous()
        # focal weight is detached in _get_weight -> treated as a constant
        weight = self._get_weight(input_prob, binary_target)

        return F.binary_cross_entropy(
            input_prob, binary_target, weight, reduction="mean"
        )

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual classes

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalLoss
>>>
>>> # BINARY
>>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
>>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
>>> loss = FocalLoss()(input, target)
>>>
>>> # MULTICLASS
>>> target = torch.tensor([1, 0, 2]).view(-1, 1)
>>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
>>> loss = FocalLoss()(input, target)
Source code in pytorch_widedeep/losses.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual classes

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalLoss
    >>>
    >>> # BINARY
    >>> target = torch.tensor([0, 1, 0, 1]).view(-1, 1)
    >>> input = torch.tensor([[0.6, 0.7, 0.3, 0.8]]).t()
    >>> loss = FocalLoss()(input, target)
    >>>
    >>> # MULTICLASS
    >>> target = torch.tensor([1, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1], [0.7, 0.2, 0.1]])
    >>> loss = FocalLoss()(input, target)
    """
    # element-wise sigmoid of the logits; NOTE(review): sigmoid (not
    # softmax) is also used per-class in the multiclass branch -- confirm
    # this multi-label-style formulation is intended
    input_prob = torch.sigmoid(input)
    if input.size(1) == 1:
        # binary case: build the two columns [P(y=0), P(y=1)]
        input_prob = torch.cat([1 - input_prob, input_prob], axis=1)  # type: ignore
        num_class = 2
    else:
        num_class = input_prob.size(1)
    # one-hot encode the class labels by indexing the identity matrix
    binary_target = torch.eye(num_class)[target.squeeze().cpu().long()]
    binary_target = binary_target.to(input.device)
    binary_target = binary_target.contiguous()
    # focal weight is detached in _get_weight -> treated as a constant
    weight = self._get_weight(input_prob, binary_target)

    return F.binary_cross_entropy(
        input_prob, binary_target, weight, reduction="mean"
    )

BayesianSELoss

Bases: Module

Squared Loss (log Gaussian) for the case of a regression as specified in the original publication Weight Uncertainty in Neural Networks.

Source code in pytorch_widedeep/losses.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
class BayesianSELoss(nn.Module):
    r"""Squared Loss (log Gaussian) for the case of a regression as specified in
    the original publication
    [Weight Uncertainty in Neural Networks](https://arxiv.org/abs/1505.05424).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""Return the sum of half the squared residuals.

        Parameters
        ----------
        input: Tensor
            Input tensor with predictions
        target: Tensor
            Target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import BayesianSELoss
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = BayesianSELoss()(input, target)
        """
        squared_residuals = (input - target) ** 2
        # 0.5 factor from the log of a Gaussian likelihood
        return 0.5 * squared_residuals.sum()

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual classes

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import BayesianSELoss
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = BayesianSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""Return the sum of half the squared residuals.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual classes

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import BayesianSELoss
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = BayesianSELoss()(input, target)
    """
    # 0.5 factor from the log of a Gaussian likelihood; summed, not averaged
    return (0.5 * (input - target) ** 2).sum()

TweedieLoss

Bases: Module

Tweedie loss for extremely unbalanced zero-inflated data

All credits go to Wenbo Shi. See this post and the original publication for details.

Source code in pytorch_widedeep/losses.py
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
class TweedieLoss(nn.Module):
    """
    Tweedie loss for extremely unbalanced zero-inflated data

    All credits go to Wenbo Shi. See
    [this post](https://towardsdatascience.com/tweedie-loss-function-for-right-skewed-data-2c5ca470678f)
    and the [original publication](https://arxiv.org/abs/1811.10192) for details.
    """

    def __init__(self):
        super().__init__()

    def forward(
        self,
        input: Tensor,
        target: Tensor,
        p: float = 1.5,
    ) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (must be strictly positive)
        target: Tensor
            Target tensor with the actual values (must be non-negative)
        p: float, default = 1.5
            the power to be used to compute the loss. See the original
            publication for details

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import TweedieLoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = TweedieLoss()(input, target)
        """

        # the check is strictly > 0 (not >= 0): `input` is raised to a
        # negative power below, so 0 would produce inf/nan. The message now
        # matches the check (it previously claimed ">=0").
        assert (
            input.min() > 0
        ), """All input values must be > 0, if your model is predicting
            values <=0 try to enforce positive values by activation function
            on last layer with `trainer.enforce_positive_output=True`"""
        assert target.min() >= 0, "All target values must be >=0"
        # Tweedie deviance (up to target-only constant terms) for 1 < p < 2
        loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(input, 2 - p) / (
            2 - p
        )
        return torch.mean(loss)

forward

forward(input, target, p=1.5)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions

required
target Tensor

Target tensor with the actual values

required
p float

the power to be used to compute the loss. See the original publication for details

1.5

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import TweedieLoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = TweedieLoss()(input, target)
Source code in pytorch_widedeep/losses.py
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
def forward(
    self,
    input: Tensor,
    target: Tensor,
    p: float = 1.5,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (must be strictly positive)
    target: Tensor
        Target tensor with the actual values (must be non-negative)
    p: float, default = 1.5
        the power to be used to compute the loss. See the original
        publication for details

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import TweedieLoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = TweedieLoss()(input, target)
    """

    # strictly > 0: `input` is raised to a negative power below.
    # NOTE(review): the message says ">=0" while the check is strict (> 0)
    assert (
        input.min() > 0
    ), """All input values must be >=0, if your model is predicting
        values <0 try to enforce positive values by activation function
        on last layer with `trainer.enforce_positive_output=True`"""
    assert target.min() >= 0, "All target values must be >=0"
    # Tweedie deviance (up to target-only constant terms) for 1 < p < 2
    loss = -target * torch.pow(input, 1 - p) / (1 - p) + torch.pow(input, 2 - p) / (
        2 - p
    )
    return torch.mean(loss)

ZILNLoss

Bases: Module

Adjusted implementation of the Zero Inflated LogNormal Loss

See A Deep Probabilistic Model for Customer Lifetime Value Prediction and the corresponding code.

Source code in pytorch_widedeep/losses.py
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
class ZILNLoss(nn.Module):
    r"""Adjusted implementation of the Zero Inflated LogNormal Loss

    See [A Deep Probabilistic Model for Customer Lifetime Value Prediction](https://arxiv.org/pdf/1912.07753.pdf)
    and the corresponding
    [code](https://github.com/google/lifetime_value/blob/master/lifetime_value/zero_inflated_lognormal.py).
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions with shape (N,3), where N is the batch size
        target: Tensor
            Target tensor with the actual target values

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import ZILNLoss
        >>>
        >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
        >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
        >>> loss = ZILNLoss()(input, target)
        """
        # 1.0 where the target is strictly positive, 0.0 elsewhere
        positive = target > 0
        positive = positive.float()

        assert input.shape == torch.Size([target.shape[0], 3]), (
            "Wrong shape of the 'input' tensor. The pred_dim of the "
            "model that is using ZILNLoss must be equal to 3."
        )

        # column 0: logit of P(target > 0)
        positive_input = input[..., :1]

        classification_loss = F.binary_cross_entropy_with_logits(
            positive_input, positive, reduction="none"
        ).flatten()

        # column 1: location parameter of the LogNormal
        loc = input[..., 1:2]

        # when using max the two input tensors (input and other) have to be of
        # the same type
        # column 2: scale; softplus keeps it positive, then bounded below by
        # sqrt(eps) via get_eps to avoid a degenerate LogNormal
        max_input = F.softplus(input[..., 2:])
        max_other = self.get_eps(max_input)
        scale = torch.max(max_input, max_other)
        # replace non-positive targets with 1 so log_prob is evaluated on a
        # valid value; those rows are zeroed out by `positive` below anyway
        safe_labels = positive * target + (1 - positive) * torch.ones_like(target)

        regression_loss = -torch.mean(
            positive
            * torch.distributions.log_normal.LogNormal(loc=loc, scale=scale).log_prob(
                safe_labels
            ),
            dim=-1,
        )

        return torch.mean(classification_loss + regression_loss)

    @staticmethod
    def get_eps(max_input: Tensor) -> Tensor:
        # Return sqrt(machine eps) as a tensor on the same device/dtype as
        # `max_input`; used as a lower bound for the LogNormal scale.
        if max_input.device.type == "mps":
            # For MPS, use float32 and then convert to the input type
            eps = torch.finfo(torch.float32).eps
            max_other = (
                torch.sqrt(torch.tensor([eps], device="cpu"))
                .to(max_input.device)
                .to(max_input.dtype)
            )
        else:
            # For other devices, use the original approach
            eps = torch.finfo(torch.double).eps
            max_other = (
                torch.sqrt(torch.tensor([eps])).to(max_input.device).to(max_input.dtype)
            )

        return max_other

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions with shape (N,3), where N is the batch size

required
target Tensor

Target tensor with the actual target values

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import ZILNLoss
>>>
>>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
>>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
>>> loss = ZILNLoss()(input, target)
Source code in pytorch_widedeep/losses.py
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""Compute the Zero-Inflated LogNormal loss.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions with shape (N,3), where N is the batch
        size. The columns are the positive-class logit, the LogNormal `loc`
        and the (pre-softplus) LogNormal `scale`
    target: Tensor
        Target tensor with the actual target values

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import ZILNLoss
    >>>
    >>> target = torch.tensor([[0., 1.5]]).view(-1, 1)
    >>> input = torch.tensor([[.1, .2, .3], [.4, .5, .6]])
    >>> loss = ZILNLoss()(input, target)
    """
    # 1/0 mask: which targets are strictly positive
    positive = (target > 0).float()

    assert input.shape == torch.Size([target.shape[0], 3]), (
        "Wrong shape of the 'input' tensor. The pred_dim of the "
        "model that is using ZILNLoss must be equal to 3."
    )

    # first column is the logit for "target is positive"
    classification_loss = F.binary_cross_entropy_with_logits(
        input[..., :1], positive, reduction="none"
    ).flatten()

    loc = input[..., 1:2]

    # clamp the scale from below so LogNormal never receives zero; when
    # using max the two tensors (input and other) must share a dtype
    raw_scale = F.softplus(input[..., 2:])
    scale = torch.max(raw_scale, self.get_eps(raw_scale))

    # replace non-positive targets with 1. so log_prob is defined; those
    # rows are zeroed out by `positive` below anyway
    safe_labels = positive * target + (1 - positive) * torch.ones_like(target)

    log_normal = torch.distributions.log_normal.LogNormal(loc=loc, scale=scale)
    regression_loss = -torch.mean(
        positive * log_normal.log_prob(safe_labels), dim=-1
    )

    return torch.mean(classification_loss + regression_loss)

L1Loss

Bases: Module

L1 loss

Source code in pytorch_widedeep/losses.py
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
class L1Loss(nn.Module):
    r"""Mean absolute error loss.

    Kept around for backwards compatibility: it used to take FDS-LDS related
    parameters. At this stage you probably want `torch.nn.L1Loss` instead.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""Return the mean absolute error between `input` and `target`.

        Parameters
        ----------
        input: Tensor
            Input tensor with predictions
        target: Tensor
            Target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import L1Loss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = L1Loss()(input, target)
        """
        per_sample = F.l1_loss(input, target, reduction="none")
        return per_sample.mean()

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import L1Loss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = L1Loss()(input, target)
Source code in pytorch_widedeep/losses.py
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
def forward(self, input: Tensor, target: Tensor) -> Tensor:
    r"""Return the mean absolute error between `input` and `target`.

    Parameters
    ----------
    input: Tensor
        Input tensor with predictions
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import L1Loss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = L1Loss()(input, target)
    """
    per_sample = F.l1_loss(input, target, reduction="none")
    return per_sample.mean()

FocalR_L1Loss

Bases: Module

Focal-R L1 loss

Based on Delving into Deep Imbalanced Regression.

Parameters:

Name Type Description Default
beta float

Focal Loss beta parameter in their implementation

0.2
gamma float

Focal Loss gamma parameter

1.0
activation_fn Literal[sigmoid, tanh]

Activation function to be used during the computation of the loss. Possible values are 'sigmoid' and 'tanh'. See the original publication for details.

'sigmoid'
Source code in pytorch_widedeep/losses.py
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
class FocalR_L1Loss(nn.Module):
    r"""Focal-R L1 loss

    Based on [Delving into Deep Imbalanced Regression](https://arxiv.org/abs/2102.09554).

    Parameters
    ----------
    beta: float
        Focal Loss `beta` parameter in their implementation
    gamma: float
        Focal Loss `gamma` parameter
    activation_fn: str, default = "sigmoid"
        Activation function to be used during the computation of the loss.
        Possible values are _'sigmoid'_ and _'tanh'_. See the original
        publication for details.
    """

    def __init__(
        self,
        beta: float = 0.2,
        gamma: float = 1.0,
        activation_fn: Literal["sigmoid", "tanh"] = "sigmoid",
    ):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.activation_fn = activation_fn

    def forward(
        self,
        input: Tensor,
        target: Tensor,
    ) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (not probabilities)
        target: Tensor
            Target tensor with the actual values

        Raises
        ------
        ValueError
            If `activation_fn` is not one of _'sigmoid'_ or _'tanh'_

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import FocalR_L1Loss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = FocalR_L1Loss()(input, target)
        """
        loss = F.l1_loss(input, target, reduction="none")
        if self.activation_fn == "tanh":
            loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
        elif self.activation_fn == "sigmoid":
            loss *= (
                2 * torch.sigmoid(self.beta * torch.abs(input - target)) - 1
            ) ** self.gamma
        else:
            # BUG FIX: the ValueError was previously instantiated but never
            # raised, so an invalid activation_fn silently returned a plain
            # (unweighted) L1 loss
            raise ValueError(
                "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
            )
        return torch.mean(loss)

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalR_L1Loss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = FocalR_L1Loss()(input, target)
Source code in pytorch_widedeep/losses.py
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
def forward(
    self,
    input: Tensor,
    target: Tensor,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual values

    Raises
    ------
    ValueError
        If `self.activation_fn` is not one of _'sigmoid'_ or _'tanh'_

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalR_L1Loss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = FocalR_L1Loss()(input, target)
    """
    loss = F.l1_loss(input, target, reduction="none")
    if self.activation_fn == "tanh":
        loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
    elif self.activation_fn == "sigmoid":
        loss *= (
            2 * torch.sigmoid(self.beta * torch.abs(input - target)) - 1
        ) ** self.gamma
    else:
        # BUG FIX: previously the ValueError was created but never raised
        raise ValueError(
            "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
        )
    return torch.mean(loss)

FocalR_MSELoss

Bases: Module

Focal-R MSE loss

Based on Delving into Deep Imbalanced Regression.

Parameters:

Name Type Description Default
beta float

Focal Loss beta parameter in their implementation

0.2
gamma float

Focal Loss gamma parameter

1.0
activation_fn Literal[sigmoid, tanh]

Activation function to be used during the computation of the loss. Possible values are 'sigmoid' and 'tanh'. See the original publication for details.

'sigmoid'
Source code in pytorch_widedeep/losses.py
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
class FocalR_MSELoss(nn.Module):
    r"""Focal-R MSE loss

    Based on [Delving into Deep Imbalanced Regression](https://arxiv.org/abs/2102.09554).

    Parameters
    ----------
    beta: float
        Focal Loss `beta` parameter in their implementation
    gamma: float
        Focal Loss `gamma` parameter
    activation_fn: str, default = "sigmoid"
        Activation function to be used during the computation of the loss.
        Possible values are _'sigmoid'_ and _'tanh'_. See the original
        publication for details.
    """

    def __init__(
        self,
        beta: float = 0.2,
        gamma: float = 1.0,
        activation_fn: Literal["sigmoid", "tanh"] = "sigmoid",
    ):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.activation_fn = activation_fn

    def forward(
        self,
        input: Tensor,
        target: Tensor,
    ) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (not probabilities)
        target: Tensor
            Target tensor with the actual values

        Raises
        ------
        ValueError
            If `activation_fn` is not one of _'sigmoid'_ or _'tanh'_

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import FocalR_MSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = FocalR_MSELoss()(input, target)
        """
        loss = (input - target) ** 2
        if self.activation_fn == "tanh":
            loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
        elif self.activation_fn == "sigmoid":
            # note: the sigmoid variant weights by the *squared* error
            loss *= (
                2 * torch.sigmoid(self.beta * torch.abs((input - target) ** 2)) - 1
            ) ** self.gamma
        else:
            # BUG FIX: the ValueError was previously instantiated but never
            # raised, so an invalid activation_fn silently returned a plain
            # (unweighted) MSE loss
            raise ValueError(
                "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
            )
        return torch.mean(loss)

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalR_MSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = FocalR_MSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
def forward(
    self,
    input: Tensor,
    target: Tensor,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual values

    Raises
    ------
    ValueError
        If `self.activation_fn` is not one of _'sigmoid'_ or _'tanh'_

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalR_MSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = FocalR_MSELoss()(input, target)
    """
    loss = (input - target) ** 2
    if self.activation_fn == "tanh":
        loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
    elif self.activation_fn == "sigmoid":
        loss *= (
            2 * torch.sigmoid(self.beta * torch.abs((input - target) ** 2)) - 1
        ) ** self.gamma
    else:
        # BUG FIX: previously the ValueError was created but never raised
        raise ValueError(
            "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
        )
    return torch.mean(loss)

FocalR_RMSELoss

Bases: Module

Focal-R RMSE loss

Based on Delving into Deep Imbalanced Regression.

Parameters:

Name Type Description Default
beta float

Focal Loss beta parameter in their implementation

0.2
gamma float

Focal Loss gamma parameter

1.0
activation_fn Literal[sigmoid, tanh]

Activation function to be used during the computation of the loss. Possible values are 'sigmoid' and 'tanh'. See the original publication for details.

'sigmoid'
Source code in pytorch_widedeep/losses.py
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
class FocalR_RMSELoss(nn.Module):
    r"""Focal-R RMSE loss

    Based on [Delving into Deep Imbalanced Regression](https://arxiv.org/abs/2102.09554).

    Parameters
    ----------
    beta: float
        Focal Loss `beta` parameter in their implementation
    gamma: float
        Focal Loss `gamma` parameter
    activation_fn: str, default = "sigmoid"
        Activation function to be used during the computation of the loss.
        Possible values are _'sigmoid'_ and _'tanh'_. See the original
        publication for details.
    """

    def __init__(
        self,
        beta: float = 0.2,
        gamma: float = 1.0,
        activation_fn: Literal["sigmoid", "tanh"] = "sigmoid",
    ):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.activation_fn = activation_fn

    def forward(
        self,
        input: Tensor,
        target: Tensor,
    ) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (not probabilities)
        target: Tensor
            Target tensor with the actual values

        Raises
        ------
        ValueError
            If `activation_fn` is not one of _'sigmoid'_ or _'tanh'_

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import FocalR_RMSELoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = FocalR_RMSELoss()(input, target)
        """
        loss = (input - target) ** 2
        if self.activation_fn == "tanh":
            loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
        elif self.activation_fn == "sigmoid":
            # note: the sigmoid variant weights by the *squared* error
            loss *= (
                2 * torch.sigmoid(self.beta * torch.abs((input - target) ** 2)) - 1
            ) ** self.gamma
        else:
            # BUG FIX: the ValueError was previously instantiated but never
            # raised, so an invalid activation_fn silently returned a plain
            # (unweighted) RMSE loss
            raise ValueError(
                "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
            )
        return torch.sqrt(torch.mean(loss))

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import FocalR_RMSELoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = FocalR_RMSELoss()(input, target)
Source code in pytorch_widedeep/losses.py
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
def forward(
    self,
    input: Tensor,
    target: Tensor,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual values

    Raises
    ------
    ValueError
        If `self.activation_fn` is not one of _'sigmoid'_ or _'tanh'_

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import FocalR_RMSELoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = FocalR_RMSELoss()(input, target)
    """
    loss = (input - target) ** 2
    if self.activation_fn == "tanh":
        loss *= (torch.tanh(self.beta * torch.abs(input - target))) ** self.gamma
    elif self.activation_fn == "sigmoid":
        loss *= (
            2 * torch.sigmoid(self.beta * torch.abs((input - target) ** 2)) - 1
        ) ** self.gamma
    else:
        # BUG FIX: previously the ValueError was created but never raised
        raise ValueError(
            "Incorrect activation function value - must be in ['sigmoid', 'tanh']"
        )
    return torch.sqrt(torch.mean(loss))

HuberLoss

Bases: Module

Huber Loss

Based on Delving into Deep Imbalanced Regression.

Source code in pytorch_widedeep/losses.py
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
class HuberLoss(nn.Module):
    r"""Huber Loss

    Based on [Delving into Deep Imbalanced Regression](https://arxiv.org/abs/2102.09554).

    Parameters
    ----------
    beta: float, default = 0.2
        Threshold at which the loss switches from the quadratic to the
        linear regime
    """

    def __init__(self, beta: float = 0.2):
        super().__init__()
        self.beta = beta

    def forward(
        self,
        input: Tensor,
        target: Tensor,
    ) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with predictions (not probabilities)
        target: Tensor
            Target tensor with the actual values

        Examples
        --------
        >>> import torch
        >>>
        >>> from pytorch_widedeep.losses import HuberLoss
        >>>
        >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
        >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
        >>> loss = HuberLoss()(input, target)
        """
        abs_err = torch.abs(input - target)
        # quadratic below beta, linear above (values match at the joint)
        per_sample = torch.where(
            abs_err < self.beta,
            0.5 * abs_err**2 / self.beta,
            abs_err - 0.5 * self.beta,
        )
        return per_sample.mean()

forward

forward(input, target)

Parameters:

Name Type Description Default
input Tensor

Input tensor with predictions (not probabilities)

required
target Tensor

Target tensor with the actual values

required

Examples:

>>> import torch
>>>
>>> from pytorch_widedeep.losses import HuberLoss
>>>
>>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
>>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
>>> loss = HuberLoss()(input, target)
Source code in pytorch_widedeep/losses.py
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
def forward(
    self,
    input: Tensor,
    target: Tensor,
) -> Tensor:
    r"""
    Parameters
    ----------
    input: Tensor
        Input tensor with predictions (not probabilities)
    target: Tensor
        Target tensor with the actual values

    Examples
    --------
    >>> import torch
    >>>
    >>> from pytorch_widedeep.losses import HuberLoss
    >>>
    >>> target = torch.tensor([1, 1.2, 0, 2]).view(-1, 1)
    >>> input = torch.tensor([0.6, 0.7, 0.3, 0.8]).view(-1, 1)
    >>> loss = HuberLoss()(input, target)
    """
    abs_err = torch.abs(input - target)
    # quadratic below self.beta, linear above (values match at the joint)
    per_sample = torch.where(
        abs_err < self.beta,
        0.5 * abs_err**2 / self.beta,
        abs_err - 0.5 * self.beta,
    )
    return per_sample.mean()

InfoNCELoss

Bases: Module

InfoNCE Loss. Loss applied during the Contrastive Denoising Self Supervised Pre-training routine available in this library

ℹ️ NOTE: This loss is in principle not exposed to the user, as it is used internally in the library, but it is included here for completion.

See SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training and references therein

Partially inspired by the code in this repo

Parameters:

Name Type Description Default
temperature float

The logits are divided by the temperature before computing the loss value

0.1
reduction str

Loss reduction method

'mean'
Source code in pytorch_widedeep/losses.py
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
class InfoNCELoss(nn.Module):
    r"""InfoNCE Loss. Loss applied during the Contrastive Denoising Self
    Supervised Pre-training routine available in this library

    :information_source: **NOTE**: This loss is in principle not exposed to
     the user, as it is used internally in the library, but it is included
     here for completion.

    See [SAINT: Improved Neural Networks for Tabular Data via Row Attention
    and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342) and
    references therein

    Partially inspired by the code in this [repo](https://github.com/RElbers/info-nce-pytorch)

    Parameters
    ----------
    temperature: float, default = 0.1
        The logits are divided by the temperature before computing the loss value
    reduction: str, default = "mean"
        Loss reduction method
    """

    def __init__(self, temperature: float = 0.1, reduction: str = "mean"):
        super().__init__()

        self.temperature = temperature
        self.reduction = reduction

    def forward(self, g_projs: Tuple[Tensor, Tensor]) -> Tensor:
        r"""
        Parameters
        ----------
        g_projs: Tuple
            Tuple with the two tensors corresponding to the output of the two
            projection heads, as described 'SAINT: Improved Neural Networks
            for Tabular Data via Row Attention and Contrastive Pre-Training'.

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import InfoNCELoss
        >>> g_projs = (torch.rand(3, 5, 16), torch.rand(3, 5, 16))
        >>> loss = InfoNCELoss()
        >>> res = loss(g_projs)
        """
        first, second = g_projs

        flat_a = F.normalize(first, dim=-1).flatten(1)
        flat_b = F.normalize(second, dim=-1).flatten(1)

        sim_ab = (flat_a @ flat_b.t()) / self.temperature
        sim_ba = (flat_b @ flat_a.t()) / self.temperature

        # positives sit on the diagonal: row i of one view matches row i of
        # the other
        labels = torch.arange(len(flat_a), device=flat_a.device)

        loss_ab = F.cross_entropy(sim_ab, labels, reduction=self.reduction)
        loss_ba = F.cross_entropy(sim_ba, labels, reduction=self.reduction)

        return (loss_ab + loss_ba) / 2.0

forward

forward(g_projs)

Parameters:

Name Type Description Default
g_projs Tuple[Tensor, Tensor]

Tuple with the two tensors corresponding to the output of the two projection heads, as described 'SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training'.

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import InfoNCELoss
>>> g_projs = (torch.rand(3, 5, 16), torch.rand(3, 5, 16))
>>> loss = InfoNCELoss()
>>> res = loss(g_projs)
Source code in pytorch_widedeep/losses.py
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
def forward(self, g_projs: Tuple[Tensor, Tensor]) -> Tensor:
    r"""
    Parameters
    ----------
    g_projs: Tuple
        Tuple with the two tensors corresponding to the output of the two
        projection heads, as described 'SAINT: Improved Neural Networks
        for Tabular Data via Row Attention and Contrastive Pre-Training'.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import InfoNCELoss
    >>> g_projs = (torch.rand(3, 5, 16), torch.rand(3, 5, 16))
    >>> loss = InfoNCELoss()
    >>> res = loss(g_projs)
    """
    flat_a = F.normalize(g_projs[0], dim=-1).flatten(1)
    flat_b = F.normalize(g_projs[1], dim=-1).flatten(1)

    sim_ab = (flat_a @ flat_b.t()) / self.temperature
    sim_ba = (flat_b @ flat_a.t()) / self.temperature

    # positives sit on the diagonal: row i of one view matches row i of the other
    labels = torch.arange(len(flat_a), device=flat_a.device)

    loss_ab = F.cross_entropy(sim_ab, labels, reduction=self.reduction)
    loss_ba = F.cross_entropy(sim_ba, labels, reduction=self.reduction)

    return (loss_ab + loss_ba) / 2.0

DenoisingLoss

Bases: Module

Denoising Loss. Loss applied during the Contrastive Denoising Self Supervised Pre-training routine available in this library

ℹ️ NOTE: This loss is in principle not exposed to the user, as it is used internally in the library, but it is included here for completion.

See SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training and references therein

Parameters:

Name Type Description Default
lambda_cat float

Multiplicative factor that will be applied to loss associated to the categorical features

1.0
lambda_cont float

Multiplicative factor that will be applied to loss associated to the continuous features

1.0
reduction str

Loss reduction method

'mean'
Source code in pytorch_widedeep/losses.py
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
class DenoisingLoss(nn.Module):
    r"""Denoising Loss. Loss applied during the Contrastive Denoising Self
    Supervised Pre-training routine available in this library

    :information_source: **NOTE**: This loss is in principle not exposed to
     the user, as it is used internally in the library, but it is included
     here for completion.

    See [SAINT: Improved Neural Networks for Tabular Data via Row Attention
    and Contrastive Pre-Training](https://arxiv.org/abs/2106.01342) and
    references therein

    Parameters
    ----------
    lambda_cat: float, default = 1.
        Multiplicative factor applied to the loss associated to the
        categorical features
    lambda_cont: float, default = 1.
        Multiplicative factor applied to the loss associated to the
        continuous features
    reduction: str, default = "mean"
        Loss reduction method
    """

    def __init__(
        self, lambda_cat: float = 1.0, lambda_cont: float = 1.0, reduction: str = "mean"
    ):
        super().__init__()

        self.lambda_cat = lambda_cat
        self.lambda_cont = lambda_cont
        self.reduction = reduction

    def forward(
        self,
        x_cat_and_cat_: Optional[
            Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
        ],
        x_cont_and_cont_: Optional[
            Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
        ],
    ) -> Tensor:
        r"""
        Parameters
        ----------
        x_cat_and_cat_: tuple of Tensors or lists of tuples
            Tuple of tensors containing the raw input features and their
            encodings, referred in the SAINT paper as $x$ and $x''$
            respectively. If one denoising MLP is used per categorical
            feature `x_cat_and_cat_` will be a list of tuples, one per
            categorical feature
        x_cont_and_cont_: tuple of Tensors or lists of tuples
            same as `x_cat_and_cat_` but for continuous columns

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import DenoisingLoss
        >>> x_cat_and_cat_ = (torch.empty(3).random_(3).long(), torch.randn(3, 3))
        >>> x_cont_and_cont_ = (torch.randn(3, 1), torch.randn(3, 1))
        >>> loss = DenoisingLoss()
        >>> res = loss(x_cat_and_cat_, x_cont_and_cont_)
        """
        if x_cat_and_cat_ is None:
            loss_cat = torch.tensor(0.0)
        else:
            loss_cat = self._compute_cat_loss(x_cat_and_cat_)

        if x_cont_and_cont_ is None:
            loss_cont = torch.tensor(0.0)
        else:
            loss_cont = self._compute_cont_loss(x_cont_and_cont_)

        return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont

    def _compute_cat_loss(
        self, x_cat_and_cat_: Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
    ) -> Tensor:
        # cross-entropy between each reconstructed categorical feature (x_)
        # and its original value (x)
        total = torch.tensor(0.0, device=self._get_device(x_cat_and_cat_))
        if isinstance(x_cat_and_cat_, list):
            pairs = x_cat_and_cat_
        elif isinstance(x_cat_and_cat_, tuple):
            pairs = [x_cat_and_cat_]
        else:
            pairs = []
        for x, x_ in pairs:
            total = total + F.cross_entropy(x_, x, reduction=self.reduction)

        return total

    def _compute_cont_loss(self, x_cont_and_cont_) -> Tensor:
        # mean squared error between each reconstructed continuous feature
        # (x_) and its original value (x)
        total = torch.tensor(0.0, device=self._get_device(x_cont_and_cont_))
        if isinstance(x_cont_and_cont_, list):
            pairs = x_cont_and_cont_
        elif isinstance(x_cont_and_cont_, tuple):
            pairs = [x_cont_and_cont_]
        else:
            pairs = []
        for x, x_ in pairs:
            total = total + F.mse_loss(x_, x, reduction=self.reduction)

        return total

    @staticmethod
    def _get_device(
        x_and_x_: Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
    ):
        if isinstance(x_and_x_, tuple):
            device = x_and_x_[0].device
        elif isinstance(x_and_x_, list):
            device = x_and_x_[0][0].device
        return device

forward

forward(x_cat_and_cat_, x_cont_and_cont_)

Parameters:

Name Type Description Default
x_cat_and_cat_ Optional[Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]]

Tuple of tensors containing the raw input features and their encodings, referred in the SAINT paper as \(x\) and \(x''\) respectively. If one denoising MLP is used per categorical feature x_cat_and_cat_ will be a list of tuples, one per categorical feature

required
x_cont_and_cont_ Optional[Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]]

same as x_cat_and_cat_ but for continuous columns

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import DenoisingLoss
>>> x_cat_and_cat_ = (torch.empty(3).random_(3).long(), torch.randn(3, 3))
>>> x_cont_and_cont_ = (torch.randn(3, 1), torch.randn(3, 1))
>>> loss = DenoisingLoss()
>>> res = loss(x_cat_and_cat_, x_cont_and_cont_)
Source code in pytorch_widedeep/losses.py
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
def forward(
    self,
    x_cat_and_cat_: Optional[
        Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
    ],
    x_cont_and_cont_: Optional[
        Union[List[Tuple[Tensor, Tensor]], Tuple[Tensor, Tensor]]
    ],
) -> Tensor:
    r"""Combine the categorical and continuous denoising losses.

    Parameters
    ----------
    x_cat_and_cat_: tuple of Tensors or lists of tuples
        Tuple of tensors with the raw input features and their encodings,
        referred in the SAINT paper as $x$ and $x''$ respectively. When one
        denoising MLP is used per categorical feature, `x_cat_and_cat_` is a
        list with one tuple per categorical feature
    x_cont_and_cont_: tuple of Tensors or lists of tuples
        same as `x_cat_and_cat_` but for continuous columns

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import DenoisingLoss
    >>> x_cat_and_cat_ = (torch.empty(3).random_(3).long(), torch.randn(3, 3))
    >>> x_cont_and_cont_ = (torch.randn(3, 1), torch.randn(3, 1))
    >>> loss = DenoisingLoss()
    >>> res = loss(x_cat_and_cat_, x_cont_and_cont_)
    """
    # Either component may be absent (e.g. a dataset with no categorical
    # columns); a zero scalar keeps the weighted sum well defined.
    if x_cat_and_cat_ is not None:
        loss_cat = self._compute_cat_loss(x_cat_and_cat_)
    else:
        loss_cat = torch.tensor(0.0)

    if x_cont_and_cont_ is not None:
        loss_cont = self._compute_cont_loss(x_cont_and_cont_)
    else:
        loss_cont = torch.tensor(0.0)

    return self.lambda_cat * loss_cat + self.lambda_cont * loss_cont

EncoderDecoderLoss

Bases: Module

'Standard' Encoder-Decoder Loss. Loss applied during the Encoder-Decoder Self-Supervised Pre-Training routine available in this library

ℹ️ NOTE: This loss is in principle not exposed to the user, as it is used internally in the library, but it is included here for completion.

The implementation of this loss is based on that at the tabnet repo, which is in itself an adaptation of that in the original paper TabNet: Attentive Interpretable Tabular Learning.

Parameters:

Name Type Description Default
eps float

Simply a small number to avoid dividing by zero

1e-09
Source code in pytorch_widedeep/losses.py
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
class EncoderDecoderLoss(nn.Module):
    r"""'_Standard_' Encoder-Decoder loss, applied during the encoder-decoder
    self-supervised pre-training routine available in this library.

    :information_source: **NOTE**: in principle this loss is not exposed to
    the user, since it is used internally by the library, but it is
    documented here for completion.

    The implementation of this loss follows that in the
    [tabnet repo](https://github.com/dreamquark-ai/tabnet), which is in itself
    an adaptation of that in the original paper [TabNet: Attentive
    Interpretable Tabular Learning](https://arxiv.org/abs/1908.07442).

    Parameters
    ----------
    eps: float
        Small constant to avoid division by zero
    """

    def __init__(self, eps: float = 1e-9):
        super(EncoderDecoderLoss, self).__init__()
        self.eps = eps

    def forward(self, x_true: Tensor, x_pred: Tensor, mask: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        x_true: Tensor
            Embeddings of the input data
        x_pred: Tensor
            Reconstructed embeddings
        mask: Tensor
            Mask with 1s indicating the features on which the reconstruction,
            and therefore the loss, is based

        Examples
        --------
        >>> import torch
        >>> from pytorch_widedeep.losses import EncoderDecoderLoss
        >>> x_true = torch.rand(3, 3)
        >>> x_pred = torch.rand(3, 3)
        >>> mask = torch.empty(3, 3).random_(2)
        >>> loss = EncoderDecoderLoss()
        >>> res = loss(x_true, x_pred, mask)
        """
        # squared reconstruction error, restricted to the masked features
        masked_sq_errors = torch.mul(x_pred - x_true, mask) ** 2

        # per-feature normalisation constants: the squared std of each
        # column, falling back to the column mean (itself defaulted to 1
        # when zero) for constant columns
        col_means = torch.mean(x_true, dim=0)
        col_means[col_means == 0] = 1
        col_vars = torch.std(x_true, dim=0) ** 2
        col_vars[col_vars == 0] = col_means[col_vars == 0]

        # per-sample loss, normalised by the number of reconstructed features
        per_sample_loss = torch.matmul(masked_sq_errors, 1 / col_vars)
        n_reconstructed = torch.sum(mask, dim=1)
        normalised_loss = per_sample_loss / (n_reconstructed + self.eps)

        return torch.mean(normalised_loss)

forward

forward(x_true, x_pred, mask)

Parameters:

Name Type Description Default
x_true Tensor

Embeddings of the input data

required
x_pred Tensor

Reconstructed embeddings

required
mask Tensor

Mask with 1s indicating that the reconstruction, and therefore the loss, is based on those features.

required

Examples:

>>> import torch
>>> from pytorch_widedeep.losses import EncoderDecoderLoss
>>> x_true = torch.rand(3, 3)
>>> x_pred = torch.rand(3, 3)
>>> mask = torch.empty(3, 3).random_(2)
>>> loss = EncoderDecoderLoss()
>>> res = loss(x_true, x_pred, mask)
Source code in pytorch_widedeep/losses.py
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
def forward(self, x_true: Tensor, x_pred: Tensor, mask: Tensor) -> Tensor:
    r"""Compute the masked, per-feature-normalised reconstruction loss.

    Parameters
    ----------
    x_true: Tensor
        Embeddings of the input data
    x_pred: Tensor
        Reconstructed embeddings
    mask: Tensor
        Mask with 1s indicating the features on which the reconstruction,
        and therefore the loss, is based

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses import EncoderDecoderLoss
    >>> x_true = torch.rand(3, 3)
    >>> x_pred = torch.rand(3, 3)
    >>> mask = torch.empty(3, 3).random_(2)
    >>> loss = EncoderDecoderLoss()
    >>> res = loss(x_true, x_pred, mask)
    """
    # squared reconstruction error, restricted to the masked features
    masked_sq_errors = torch.mul(x_pred - x_true, mask) ** 2

    # per-feature normalisation: squared std of each column, with a
    # fallback to the column mean (defaulted to 1 when zero) for
    # constant columns
    col_means = torch.mean(x_true, dim=0)
    col_means[col_means == 0] = 1
    col_vars = torch.std(x_true, dim=0) ** 2
    col_vars[col_vars == 0] = col_means[col_vars == 0]

    # per-sample loss, normalised by the number of reconstructed features
    per_sample_loss = torch.matmul(masked_sq_errors, 1 / col_vars)
    n_reconstructed = torch.sum(mask, dim=1)

    return torch.mean(per_sample_loss / (n_reconstructed + self.eps))

MultiTargetRegressionLoss

Bases: Module

This class is a wrapper around the Pytorch MSELoss. It allows for multi-target regression problems. The user can provide a list of weights to apply to each target. The loss can be either the sum or the mean of the individual losses

Parameters:

Name Type Description Default
weights Optional[List[float]]

List of weights to apply to the loss associated to each target. The length of the list must match the number of targets. Alias: 'target_weights'

None
reduction Literal[mean, sum]

Specifies the reduction to apply to the loss associated to each target: 'mean' | 'sum'. Note that this is NOT the same as the reduction in the MSELoss. This reduction is applied after the loss for each target has been computed. Alias: 'target_reduction'

'mean'

Examples:

>>> import torch
>>> from pytorch_widedeep.losses_multitarget import MultiTargetRegressionLoss
>>> input = torch.randn(3, 2)
>>> target = torch.randn(3, 2)
>>> loss = MultiTargetRegressionLoss(weights=[0.5, 0.5], reduction="mean")
>>> output = loss(input, target)
Source code in pytorch_widedeep/losses_multitarget.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
class MultiTargetRegressionLoss(nn.Module):
    """
    Wrapper around Pytorch's MSELoss for multi-target regression problems.
    A weight can optionally be applied to the loss of each target, and the
    per-target losses are aggregated via their mean or their sum

    Parameters
    ----------
    weights: Optional[List[float]], default = None
        List of weights to apply to the loss associated to each target. The
        length of the list must match the number of targets.
        Alias: 'target_weights'
    reduction: Literal["mean", "sum"], default = "mean"
        Specifies the reduction to apply to the loss associated to each
        target: 'mean' | 'sum'. Note that this is NOT the same as the
        reduction in the MSELoss. This reduction is applied after the loss
        for each target has been computed. Alias: 'target_reduction'

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses_multitarget import MultiTargetRegressionLoss
    >>> input = torch.randn(3, 2)
    >>> target = torch.randn(3, 2)
    >>> loss = MultiTargetRegressionLoss(weights=[0.5, 0.5], reduction="mean")
    >>> output = loss(input, target)
    """

    @alias("reduction", ["target_reduction"])
    @alias("weights", ["target_weights"])
    def __init__(
        self,
        weights: Optional[List[float]] = None,
        reduction: Literal["mean", "sum"] = "mean",
    ):
        super(MultiTargetRegressionLoss, self).__init__()

        self.weights = weights
        self.reduction = reduction

        if self.reduction not in ["mean", "sum"]:
            raise ValueError("reduction must be either 'mean' or 'sum'")

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        assert input.size() == target.size()

        # per-element squared errors; the reduction is deferred so that the
        # per-target weights can be applied first
        per_elem_loss = F.mse_loss(input, target, reduction="none")

        if self.weights is not None:
            assert len(self.weights) == input.size(1), (
                "The number of weights must match the number of targets. "
                f"Got {len(self.weights)} weights and {input.size(1)} targets"
            )
            per_elem_loss = per_elem_loss * torch.tensor(self.weights).to(
                input.device
            )

        return (
            per_elem_loss.mean() if self.reduction == "mean" else per_elem_loss.sum()
        )

MultiTargetClassificationLoss

Bases: Module

This class is a wrapper around the Pytorch binary_cross_entropy_with_logits and cross_entropy losses. It allows for multi-target classification problems. The user can provide a list of weights to apply to each target. The loss can be either the sum or the mean of the individual losses

Parameters:

Name Type Description Default
binary_config Optional[List[Union[int, Tuple[int, float]]]]

List of integers with the index of the target for binary classification or tuples with two elements: the index of the targets or binary classification and the positive weight for binary classification

None
multiclass_config Optional[List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]]

List of tuples with two or three elements: the index of the target and the number of classes for multiclass classification, or a tuple with the index of the target, the number of classes and a list of weights to apply to each class (i.e. the 'weight' parameter in the cross_entropy loss)

None
weights Optional[List[float]]

List of weights to apply to the loss associated to each target. The length of the list must match the number of targets. Alias: 'target_weights'

None
reduction Literal[mean, sum]

Specifies the reduction to apply to the loss associated to each target: 'mean' | 'sum'. Note that this is NOT the same as the reduction in the cross_entropy loss or the binary_cross_entropy_with_logits. This reduction is applied after the loss for each target has been computed. Alias: 'target_reduction'

'mean'
binary_trick bool

If True, each target will be considered independently and the loss will be computed as binary_cross_entropy_with_logits. This is a faster implementation. Note that the 'weights' parameter is not compatible with binary_trick=True. Also note that if binary_trick=True, the 'binary_config' must be a list of integers and the 'multiclass_config' must be a list of tuples with two integers: the index of the target and the number of classes. Finally, if binary_trick=True, the binary targets must be the first targets in the target tensor.

ℹ️ NOTE: When using the binary_trick, the binary targets are considered as 2 classes. Therefore, the pred_dim parameter of the WideDeep class should be adjusted accordingly (adding 2 per binary target). For example, in a problem with a binary target and a 4 class multiclassification target, the pred_dim should be 6.

False

Examples:

>>> import torch
>>> from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
>>> input = torch.randn(5, 4)
>>> input_binary_trick = torch.randn(5, 5)
>>> target = torch.stack([torch.tensor([0, 1, 0, 1, 1]), torch.tensor([0, 1, 2, 0, 2])], 1)
>>> loss_1 = MultiTargetClassificationLoss(binary_config=[0], multiclass_config=[(1, 3)], reduction="mean")
>>> output_1 = loss_1(input, target)
>>> loss_2 = MultiTargetClassificationLoss(binary_config=[(0, 0.5)], multiclass_config=[(1, 3, [1., 2., 3.])],
... reduction="sum", weights=[0.5, 0.5])
>>> output_2 = loss_2(input, target)
>>> loss_3 = MultiTargetClassificationLoss(binary_config=[0], multiclass_config=[(1, 3)], binary_trick=True)
>>> output_3 = loss_3(input_binary_trick, target)
Source code in pytorch_widedeep/losses_multitarget.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
class MultiTargetClassificationLoss(nn.Module):
    """
    This class is a wrapper around the Pytorch binary_cross_entropy_with_logits and
    cross_entropy losses. It allows for multi-target classification problems. The
    user can provide a list of weights to apply to each target. The loss can be
    either the sum or the mean of the individual losses

    Parameters
    ----------
    binary_config: Optional[List[Union[int, Tuple[int, float]]]], default = None
        List of integers with the index of the target for binary
        classification or tuples with two elements: the index of the target
        for binary classification and the positive weight for binary
        classification
    multiclass_config: Optional[List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]], default = None
        List of tuples with two or three elements: the index of the target and the
        number of classes for multiclass classification, or a tuple with the index of
        the target, the number of classes and a list of weights to apply to each class
        (i.e. the 'weight' parameter in the cross_entropy loss)
    weights: Optional[List[float]], default = None
        List of weights to apply to the loss associated to each target. The
        length of the list must match the number of targets.
        Alias: 'target_weights'
    reduction: Literal["mean", "sum"], default = "mean"
        Specifies the reduction to apply to the loss associated to each
        target: 'mean' | 'sum'. Note that this is NOT the same as the
        reduction in the cross_entropy loss or the
        binary_cross_entropy_with_logits. This reduction is applied after the
        loss for each target has been computed. Alias: 'target_reduction'
    binary_trick: bool, default = False
        If True, each target will be considered independently and the loss
        will be computed as binary_cross_entropy_with_logits. This is a
        faster implementation. Note that the 'weights' parameter is not
        compatible with binary_trick=True. Also note that if
        binary_trick=True, the 'binary_config' must be a list of integers and
        the 'multiclass_config' must be a list of tuples with two integers:
        the index of the target and the number of classes. Finally, if
        binary_trick=True, the binary targets must be the first targets in
        the target tensor.

        :information_source: **NOTE**: When using the binary_trick, the binary targets are
          considered as 2 classes. Therefore, the pred_dim parameter of the
          WideDeep class should be adjusted accordingly (adding 2 per
          binary target). For example, in a problem with a binary target and
          a 4 class multiclassification target, the pred_dim should be 6.


    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses_multitarget import MultiTargetClassificationLoss
    >>> input = torch.randn(5, 4)
    >>> input_binary_trick = torch.randn(5, 5)
    >>> target = torch.stack([torch.tensor([0, 1, 0, 1, 1]), torch.tensor([0, 1, 2, 0, 2])], 1)
    >>> loss_1 = MultiTargetClassificationLoss(binary_config=[0], multiclass_config=[(1, 3)], reduction="mean")
    >>> output_1 = loss_1(input, target)
    >>> loss_2 = MultiTargetClassificationLoss(binary_config=[(0, 0.5)], multiclass_config=[(1, 3, [1., 2., 3.])],
    ... reduction="sum", weights=[0.5, 0.5])
    >>> output_2 = loss_2(input, target)
    >>> loss_3 = MultiTargetClassificationLoss(binary_config=[0], multiclass_config=[(1, 3)], binary_trick=True)
    >>> output_3 = loss_3(input_binary_trick, target)
    """

    @alias("reduction", ["target_reduction"])
    @alias("weights", ["target_weights"])
    def __init__(  # noqa: C901
        self,
        binary_config: Optional[List[Union[int, Tuple[int, float]]]] = None,
        multiclass_config: Optional[
            List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]
        ] = None,
        weights: Optional[List[float]] = None,
        reduction: Literal["mean", "sum"] = "mean",
        binary_trick: bool = False,
    ):
        super(MultiTargetClassificationLoss, self).__init__()

        if reduction not in ["mean", "sum"]:
            raise ValueError("reduction must be either 'mean' or 'sum'")

        self.binary_config = binary_config
        self.multiclass_config = multiclass_config
        self.weights = weights
        self.reduction = reduction
        self.binary_trick = binary_trick

        if self.weights is not None:
            # one weight per configured target (binary + multiclass)
            n_targets = (
                len(self.binary_config) if self.binary_config is not None else 0
            ) + (
                len(self.multiclass_config) if self.multiclass_config is not None else 0
            )
            if len(self.weights) != n_targets:
                raise ValueError(
                    "The number of weights must match the number of binary and multiclass targets"
                )

        if self.binary_trick:
            self._check_inputs_with_binary_trick()
            self._binary_config: List[int] = binary_config  # type: ignore[assignment]
            self._multiclass_config: List[Tuple[int, int]] = self.multiclass_config  # type: ignore[assignment]
        else:
            # normalise the user-provided configs into homogeneous tuples:
            # (idx, pos_weight) and (idx, n_classes, class_weights)
            self.binary_config_with_pos_weights = (
                (self._set_binary_config_without_binary_trick())
                if self.binary_config is not None
                else None
            )
            self.multiclass_config_with_weights = (
                (self._set_multiclass_config_without_binary_trick())
                if self.multiclass_config is not None
                else None
            )

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        if self.binary_trick:
            return self._forward_binary_trick(input, target)
        else:
            return self._forward_without_binary_trick(input, target)

    def _forward_binary_trick(self, input: Tensor, target: Tensor) -> Tensor:
        # One-hot encode every target (binary targets become 2 classes) and
        # treat the whole problem as a single multi-label BCE over logits
        binary_target_tensors: List[Tensor] = []
        if self._binary_config:
            for idx in self._binary_config:
                binary_target_tensors.append(
                    torch.eye(2)[target[:, idx].long()].to(input.device)
                )
        if self._multiclass_config:
            for idx, n_classes in self._multiclass_config:
                binary_target_tensors.append(
                    torch.eye(n_classes)[target[:, idx].long()].to(input.device)
                )
        binary_target = torch.cat(binary_target_tensors, 1)
        return F.binary_cross_entropy_with_logits(input, binary_target)

    def _forward_without_binary_trick(self, input: Tensor, target: Tensor) -> Tensor:
        # One loss per configured target: BCE-with-logits for binary targets,
        # cross-entropy over `n_classes` logit columns for multiclass targets
        losses: List[Tensor] = []
        if self.binary_config_with_pos_weights:
            for idx, bpos_weight in self.binary_config_with_pos_weights:
                _loss = F.binary_cross_entropy_with_logits(
                    input[:, idx],
                    target[:, idx].float(),
                    pos_weight=(
                        torch.tensor(bpos_weight).to(input.device)
                        if bpos_weight is not None
                        else None
                    ),
                )
                losses.append(_loss)
        if self.multiclass_config_with_weights:
            for idx, n_classes, mpos_weight in self.multiclass_config_with_weights:
                _loss = F.cross_entropy(
                    input[:, idx : idx + n_classes],
                    target[:, idx].long(),
                    weight=(
                        torch.tensor(mpos_weight).to(input.device)
                        if mpos_weight is not None
                        else None
                    ),
                )
                losses.append(_loss)

        # FIX: this used to be nested inside the multiclass branch above, so
        # per-target weights were silently ignored when only binary targets
        # were configured. Weights apply to ALL per-target losses.
        if self.weights is not None:
            losses = [l * w for l, w in zip(losses, self.weights)]  # noqa: E741

        return (
            torch.stack(losses).sum()
            if self.reduction == "sum"
            else torch.stack(losses).mean()
        )

    def _check_inputs_with_binary_trick(self):
        # FIX: previously 'weights' was silently ignored with
        # binary_trick=True even though the documented contract says they
        # are incompatible. Fail loudly instead.
        if self.weights is not None:
            raise ValueError(
                "the 'weights' parameter is not compatible with binary_trick=True"
            )

        if self.binary_config is not None:
            if any(isinstance(bc, tuple) for bc in self.binary_config):
                raise ValueError(
                    "binary_trick=True is only compatible with binary_config as a list of integers"
                )

        if self.multiclass_config is not None:
            if not all(len(mc) == 2 for mc in self.multiclass_config):
                raise ValueError(
                    "binary_trick=True is only compatible with multiclass_config as a list of "
                    "tuples with two integers: the index of the target and the number of classes"
                )

        if self.binary_config is not None and self.multiclass_config is not None:
            last_binary_idx = (
                self.binary_config[-1][0]
                if isinstance(self.binary_config[-1], tuple)
                else self.binary_config[-1]
            )
            if last_binary_idx >= self.multiclass_config[0][0]:
                raise ValueError(
                    "When using binary_trick=True, the binary targets must be the first targets"
                    " in the target tensor"
                )

    def _set_binary_config_without_binary_trick(
        self,
    ) -> List[Tuple[int, Optional[float]]]:
        # normalise entries to (idx, pos_weight), with pos_weight=None when
        # the user gave a bare index
        binary_config_with_pos_weights: List[Tuple[int, Optional[float]]] = []
        for bc in self.binary_config:
            if isinstance(bc, tuple):
                binary_config_with_pos_weights.append(bc)
            else:
                binary_config_with_pos_weights.append((bc, None))
        return binary_config_with_pos_weights

    def _set_multiclass_config_without_binary_trick(
        self,
    ) -> List[Tuple[int, int, Optional[List[float]]]]:
        # normalise entries to (idx, n_classes, class_weights), with
        # class_weights=None when the user gave a 2-tuple
        multiclass_config_with_weights: List[Tuple[int, int, Optional[List[float]]]] = (
            []
        )
        for mc in self.multiclass_config:
            if len(mc) == 3:
                multiclass_config_with_weights.append(mc)  # type: ignore[arg-type]
            else:
                multiclass_config_with_weights.append((mc[0], mc[1], None))
        return multiclass_config_with_weights

MutilTargetRegressionAndClassificationLoss

Bases: Module

This class is a wrapper around the MultiTargetRegressionLoss and the MultiTargetClassificationLoss. It allows for multi-target regression and classification problems. The user can provide a list of weights to apply to each target. The loss can be either the sum or the mean of the individual losses

Parameters:

Name Type Description Default
regression_config List[int]

List of integers with the indices of the regression targets

[]
binary_config Optional[List[Union[int, Tuple[int, float]]]]

List of integers with the index of the target for binary classification or tuples with two elements: the index of the targets or binary classification and the positive weight for binary classification

None
multiclass_config Optional[List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]]

List of tuples with two or three elements: the index of the target and the number of classes for multiclass classification, or a tuple with the index of the target, the number of classes and a list of weights to apply to each class (i.e. the 'weight' parameter in the cross_entropy loss)

None
weights Optional[List[float]]

List of weights to apply to the loss associated to each target. The length of the list must match the number of targets. Alias: 'target_weights'

None
reduction Literal[mean, sum]

Specifies the reduction to apply to the output: 'mean' | 'sum'. Note that this is NOT the same as the reduction in the cross_entropy loss, the binary_cross_entropy_with_logits or the MSELoss. This reduction is applied after each target has been computed. Alias: 'target_reduction'

'mean'
binary_trick bool

If True, each target will be considered independently and the loss will be computed as binary_cross_entropy_with_logits. This is a faster implementation. Note that the 'weights' parameter is not compatible with binary_trick=True. Also note that if binary_trick=True, the 'binary_config' must be a list of integers and the 'multiclass_config' must be a list of tuples with two integers: the index of the target and the number of classes. Finally, if binary_trick=True, the binary targets must be the first targets in the target tensor.

ℹ️ NOTE: When using the binary_trick, the binary targets are considered as 2 classes. Therefore, the pred_dim parameter of the WideDeep class should be adjusted accordingly (adding 2 per binary target). For example, in a problem with a binary target and a 4 class multiclassification target, the pred_dim should be 6.

False

Examples:

>>> import torch
>>> from pytorch_widedeep.losses_multitarget import MutilTargetRegressionAndClassificationLoss
>>> input = torch.randn(5, 5)
>>> target = torch.stack([torch.randn(5), torch.tensor([0, 1, 0, 1, 1]), torch.tensor([0, 1, 2, 0, 2])], 1)
>>> loss = MutilTargetRegressionAndClassificationLoss(regression_config=[0], binary_config=[2],
... multiclass_config=[(2, 3)], reduction="mean")
>>> output = loss(input, target)
Source code in pytorch_widedeep/losses_multitarget.py
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
class MutilTargetRegressionAndClassificationLoss(nn.Module):
    """
    This class is a wrapper around the MultiTargetRegressionLoss and the
    MultiTargetClassificationLoss. It allows for multi-target regression and
    classification problems. The user can provide a list of weights to apply
    to each target. The loss can be either the sum or the mean of the
    individual losses

    Parameters
    ----------
    regression_config: Optional[List[int]], default = None
        List of integers with the indices of the regression targets. If
        None, an empty list is used internally (no regression targets)
    binary_config: Optional[List[int | Tuple[int, float]]], default = None
        List of integers with the index of the target for binary
        classification or tuples with two elements: the index of the targets
        or binary classification and the positive weight for binary
        classification
    multiclass_config: Optional[Tuple[int, int] | Tuple[int, int, List[float]]], default = None
        List of tuples with two or three elements: the index of the target and the
        number of classes for multiclass classification, or a tuple with the index of
        the target, the number of classes and a list of weights to apply to each class
        (i.e. the 'weight' parameter in the cross_entropy loss)
    weights: Optional[List[float]], default = None
        List of weights to apply to the loss associated to each target. The
        length of the list must match the number of targets.
        Alias: 'target_weights'
    reduction: Literal["mean", "sum"], default = "mean"
        Specifies the reduction to apply to the output: 'mean' | 'sum'. Note
        that this is NOT the same as the reduction in the cross_entropy loss,
        the binary_cross_entropy_with_logits or the MSELoss. This reduction
        is applied after each target has been computed. Alias: 'target_reduction'
    binary_trick: bool, default = False
        If True, each target will be considered independently and the loss
        will be computed as binary_cross_entropy_with_logits. This is a
        faster implementation. Note that the 'weights' parameter is not
        compatible with binary_trick=True. Also note that if
        binary_trick=True, the 'binary_config' must be a list of integers and
        the 'multiclass_config' must be a list of tuples with two integers:
        the index of the target and the number of classes. Finally, if
        binary_trick=True, the targets must be ordered as regression, binary
        and multiclass in the target tensor.

        :information_source: **NOTE**: When using the binary_trick, the binary targets are
          considered as 2 classes. Therefore, the pred_dim parameter of the
          WideDeep class should be adjusted accordingly (adding 2 per
          binary target). For example, in a problem with a binary target and
          a 4 class multiclassification target, the pred_dim should be 6.

    Examples
    --------
    >>> import torch
    >>> from pytorch_widedeep.losses_multitarget import MutilTargetRegressionAndClassificationLoss
    >>> input = torch.randn(5, 5)
    >>> target = torch.stack([torch.randn(5), torch.tensor([0, 1, 0, 1, 1]), torch.tensor([0, 1, 2, 0, 2])], 1)
    >>> loss = MutilTargetRegressionAndClassificationLoss(regression_config=[0], binary_config=[1],
    ... multiclass_config=[(2, 3)], reduction="mean")
    >>> output = loss(input, target)
    """

    @alias("reduction", ["target_reduction"])
    @alias("weights", ["target_weights"])
    def __init__(  # noqa: C901
        self,
        regression_config: Optional[List[int]] = None,
        binary_config: Optional[List[Union[int, Tuple[int, float]]]] = None,
        multiclass_config: Optional[
            List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]
        ] = None,
        weights: Optional[List[float]] = None,
        reduction: Literal["mean", "sum"] = "mean",
        binary_trick: bool = False,
    ):

        super(MutilTargetRegressionAndClassificationLoss, self).__init__()

        # fix: 'None' replaces the previous mutable default argument '[]'
        # while remaining backward-compatible for callers
        self.regression_config: List[int] = (
            [] if regression_config is None else regression_config
        )

        assert binary_config is not None or multiclass_config is not None, (
            "Either binary_config or multiclass_config must be provided. "
            "Otherwise, use the MultiTargetRegressionLoss"
        )

        if binary_trick:
            self._check_inputs_with_binary_trick(
                self.regression_config, binary_config, multiclass_config
            )

        if weights is not None:
            # one weight per target, across the three config lists combined
            if len(weights) != (
                len(self.regression_config)
                + (len(binary_config) if binary_config is not None else 0)
                + (len(multiclass_config) if multiclass_config is not None else 0)
            ):
                raise ValueError(
                    "The number of weights must match the number of regression, binary and multiclass targets"
                )

            self.weights_regression = self._prepare_weights_for_regression_targets(
                weights, self.regression_config
            )
            self.weights_binary = self._prepare_weights_per_binary_targets(
                weights, binary_config
            )
            self.weights_multiclass = self._prepare_weights_per_multiclass_targets(
                weights, multiclass_config
            )
            self.weights = weights
        else:
            self.weights_regression = None
            self.weights_binary = None
            self.weights_multiclass = None
            # fix: 'weights' is now defined whichever branch is taken
            self.weights = None

        self.multi_target_regression_loss = MultiTargetRegressionLoss(
            weights=self.weights_regression, reduction=reduction
        )

        # the classification loss receives the binary and multiclass weights
        # concatenated (binary first), or whichever of the two was provided
        self.multi_target_classification_loss = MultiTargetClassificationLoss(
            binary_config=binary_config,
            multiclass_config=multiclass_config,
            weights=(
                self.weights_binary + self.weights_multiclass
                if self.weights_binary is not None
                and self.weights_multiclass is not None
                else (
                    self.weights_binary
                    if self.weights_binary is not None
                    else self.weights_multiclass
                )
            ),
            reduction=reduction,
            binary_trick=binary_trick,
        )

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        r"""
        Parameters
        ----------
        input: Tensor
            Input tensor with the predictions for all targets
        target: Tensor
            Target tensor with the actual values for all targets
        """
        regression_loss = self.multi_target_regression_loss(
            input[:, self.regression_config],
            target[:, self.regression_config],
        )

        if self.multi_target_classification_loss.binary_trick:
            # with the binary trick, the classification loss consumes every
            # logit column that comes after the regression columns
            classification_loss = self.multi_target_classification_loss(
                input[:, len(self.regression_config) :], target
            )
        else:
            classification_loss = self.multi_target_classification_loss(input, target)

        return regression_loss + classification_loss

    def _check_inputs_with_binary_trick(
        self,
        regression_config: List[int],
        binary_config: Optional[List[Union[int, Tuple[int, float]]]],
        multiclass_config: Optional[
            List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]
        ],
    ) -> None:
        # Validate that, when binary_trick=True, the target indices are laid
        # out as regression -> binary -> multiclass with no gaps

        error_msg = "When using binary_trick=True, the targets order must be: regression, binary and multiclass"

        # fix: an empty regression_config previously raised a bare
        # IndexError; with no regression targets the classification targets
        # simply have to start at index 0 (last_regression_idx == -1)
        if regression_config:
            if regression_config[0] != 0:
                raise ValueError(error_msg)
            last_regression_idx = regression_config[-1]
        else:
            last_regression_idx = -1

        if binary_config is not None and multiclass_config is not None:
            first_binary_idx = (
                binary_config[0][0]
                if isinstance(binary_config[0], tuple)
                else binary_config[0]
            )
            last_binary_idx = (
                binary_config[-1][0]
                if isinstance(binary_config[-1], tuple)
                else binary_config[-1]
            )
            first_multiclass_idx = multiclass_config[0][0]

            if (first_binary_idx != last_regression_idx + 1) or (
                last_binary_idx >= first_multiclass_idx
            ):
                raise ValueError(error_msg)
        elif binary_config is not None:
            first_binary_idx = (
                binary_config[0][0]
                if isinstance(binary_config[0], tuple)
                else binary_config[0]
            )
            if first_binary_idx != last_regression_idx + 1:
                raise ValueError(error_msg)
        elif multiclass_config is not None:
            first_multiclass_idx = multiclass_config[0][0]
            if first_multiclass_idx != last_regression_idx + 1:
                raise ValueError(error_msg)
        else:
            raise ValueError(
                "Either binary_config or multiclass_config must be provided. "
                "Otherwise, use the MultiTargetRegressionLoss"
            )

    def _prepare_weights_for_regression_targets(
        self,
        weights: List[float],
        regression_config: List[int],
    ) -> List[float]:
        # keep only the weights whose position matches a regression index
        weights_regression = [
            w for idx, w in enumerate(weights) if idx in regression_config
        ]

        return weights_regression

    def _prepare_weights_per_binary_targets(
        self,
        weights: List[float],
        binary_config: Optional[List[Union[int, Tuple[int, float]]]],
    ) -> Optional[List[float]]:
        # binary_config entries are either a bare index or an
        # (index, pos_weight) tuple; extract the indices first
        if binary_config is not None:
            binary_idx: List[int] = []
            for bc in binary_config:
                if isinstance(bc, tuple):
                    binary_idx.append(bc[0])
                else:
                    binary_idx.append(bc)
            weights_binary = [w for idx, w in enumerate(weights) if idx in binary_idx]
        else:
            weights_binary = None

        return weights_binary

    def _prepare_weights_per_multiclass_targets(
        self,
        weights: List[float],
        multiclass_config: Optional[
            List[Union[Tuple[int, int], Tuple[int, int, List[float]]]]
        ],
    ) -> Optional[List[float]]:
        # multiclass_config entries are tuples whose first element is the
        # target index
        if multiclass_config is not None:
            multiclass_idx: List[int] = [mc[0] for mc in multiclass_config]
            weights_multiclass = [
                w for idx, w in enumerate(weights) if idx in multiclass_idx
            ]
        else:
            weights_multiclass = None

        return weights_multiclass