diff --git a/go_emotion_of_transformers_multilabel_text_classification_v2.ipynb b/go_emotion_of_transformers_multilabel_text_classification_v2.ipynb index 834af36..68aaa9c 100644 --- a/go_emotion_of_transformers_multilabel_text_classification_v2.ipynb +++ b/go_emotion_of_transformers_multilabel_text_classification_v2.ipynb @@ -1651,8 +1651,8 @@ " loss = None\n", " if labels is not None:\n", " loss_fct = torch.nn.BCEWithLogitsLoss()\n", - " loss = loss_fct(logits.view(-1, self.num_labels), \n", - " labels.float().view(-1, self.num_labels))\n", + " loss = loss_fct(logits.view(-1, self.num_labels).cuda(), \n", + " labels.float().view(-1, self.num_labels).cuda() )\n", "\n", " if not return_dict:\n", " output = (logits,) + outputs[2:]\n", @@ -2425,4 +2425,4 @@ "outputs": [] } ] -} \ No newline at end of file +}