diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 0586271..9e58733 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -1,1192 +1,1126 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Environment " - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ + "cells": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: torch in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 1)) (2.6.0+cpu)\n", - "Requirement already satisfied: pytorch_lightning in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 2)) (2.5.0.post0)\n", - "Requirement already satisfied: numpy in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 3)) (2.1.2)\n", - "Requirement already satisfied: pandas in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 4)) (2.2.3)\n", - "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 5)) (1.6.1)\n", - "Requirement already satisfied: pyarrow in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 6)) (19.0.0)\n", - "Requirement already satisfied: nltk in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 7)) (3.9.1)\n", - "Requirement already satisfied: unidecode in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 8)) (1.3.8)\n", - "Requirement already satisfied: captum in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 9)) (0.7.0)\n", - "Requirement already satisfied: ipywidgets in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 10)) (8.1.5)\n", - "Requirement already satisfied: seaborn in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 11)) (0.13.2)\n", - "Requirement already satisfied: 
ruff>=0.7.1 in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 12)) (0.9.7)\n", - "Requirement already satisfied: pre-commit in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 13)) (4.1.0)\n", - "Requirement already satisfied: pytest in /opt/conda/lib/python3.12/site-packages (from -r ../requirements.txt (line 14)) (8.3.4)\n", - "Requirement already satisfied: filelock in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (3.13.1)\n", - "Requirement already satisfied: typing-extensions>=4.10.0 in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (4.12.2)\n", - "Requirement already satisfied: networkx in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (3.3)\n", - "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (3.1.4)\n", - "Requirement already satisfied: fsspec in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (2025.2.0)\n", - "Requirement already satisfied: setuptools in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (75.8.0)\n", - "Requirement already satisfied: sympy==1.13.1 in /opt/conda/lib/python3.12/site-packages (from torch->-r ../requirements.txt (line 1)) (1.13.1)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /opt/conda/lib/python3.12/site-packages (from sympy==1.13.1->torch->-r ../requirements.txt (line 1)) (1.3.0)\n", - "Requirement already satisfied: tqdm>=4.57.0 in /opt/conda/lib/python3.12/site-packages (from pytorch_lightning->-r ../requirements.txt (line 2)) (4.67.1)\n", - "Requirement already satisfied: PyYAML>=5.4 in /opt/conda/lib/python3.12/site-packages (from pytorch_lightning->-r ../requirements.txt (line 2)) (6.0.2)\n", - "Requirement already satisfied: torchmetrics>=0.7.0 in 
/opt/conda/lib/python3.12/site-packages (from pytorch_lightning->-r ../requirements.txt (line 2)) (1.6.1)\n", - "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.12/site-packages (from pytorch_lightning->-r ../requirements.txt (line 2)) (24.2)\n", - "Requirement already satisfied: lightning-utilities>=0.10.0 in /opt/conda/lib/python3.12/site-packages (from pytorch_lightning->-r ../requirements.txt (line 2)) (0.12.0)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /opt/conda/lib/python3.12/site-packages (from pandas->-r ../requirements.txt (line 4)) (2.9.0.post0)\n", - "Requirement already satisfied: pytz>=2020.1 in /opt/conda/lib/python3.12/site-packages (from pandas->-r ../requirements.txt (line 4)) (2025.1)\n", - "Requirement already satisfied: tzdata>=2022.7 in /opt/conda/lib/python3.12/site-packages (from pandas->-r ../requirements.txt (line 4)) (2025.1)\n", - "Requirement already satisfied: scipy>=1.6.0 in /opt/conda/lib/python3.12/site-packages (from scikit-learn->-r ../requirements.txt (line 5)) (1.15.2)\n", - "Requirement already satisfied: joblib>=1.2.0 in /opt/conda/lib/python3.12/site-packages (from scikit-learn->-r ../requirements.txt (line 5)) (1.4.2)\n", - "Requirement already satisfied: threadpoolctl>=3.1.0 in /opt/conda/lib/python3.12/site-packages (from scikit-learn->-r ../requirements.txt (line 5)) (3.5.0)\n", - "Requirement already satisfied: click in /opt/conda/lib/python3.12/site-packages (from nltk->-r ../requirements.txt (line 7)) (8.1.8)\n", - "Requirement already satisfied: regex>=2021.8.3 in /opt/conda/lib/python3.12/site-packages (from nltk->-r ../requirements.txt (line 7)) (2024.11.6)\n", - "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.12/site-packages (from captum->-r ../requirements.txt (line 9)) (3.10.0)\n", - "Requirement already satisfied: comm>=0.1.3 in /opt/conda/lib/python3.12/site-packages (from ipywidgets->-r ../requirements.txt (line 10)) (0.2.2)\n", - "Requirement 
already satisfied: ipython>=6.1.0 in /opt/conda/lib/python3.12/site-packages (from ipywidgets->-r ../requirements.txt (line 10)) (8.32.0)\n", - "Requirement already satisfied: traitlets>=4.3.1 in /opt/conda/lib/python3.12/site-packages (from ipywidgets->-r ../requirements.txt (line 10)) (5.14.3)\n", - "Requirement already satisfied: widgetsnbextension~=4.0.12 in /opt/conda/lib/python3.12/site-packages (from ipywidgets->-r ../requirements.txt (line 10)) (4.0.13)\n", - "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /opt/conda/lib/python3.12/site-packages (from ipywidgets->-r ../requirements.txt (line 10)) (3.0.13)\n", - "Requirement already satisfied: cfgv>=2.0.0 in /opt/conda/lib/python3.12/site-packages (from pre-commit->-r ../requirements.txt (line 13)) (3.4.0)\n", - "Requirement already satisfied: identify>=1.0.0 in /opt/conda/lib/python3.12/site-packages (from pre-commit->-r ../requirements.txt (line 13)) (2.6.8)\n", - "Requirement already satisfied: nodeenv>=0.11.1 in /opt/conda/lib/python3.12/site-packages (from pre-commit->-r ../requirements.txt (line 13)) (1.9.1)\n", - "Requirement already satisfied: virtualenv>=20.10.0 in /opt/conda/lib/python3.12/site-packages (from pre-commit->-r ../requirements.txt (line 13)) (20.29.2)\n", - "Requirement already satisfied: iniconfig in /opt/conda/lib/python3.12/site-packages (from pytest->-r ../requirements.txt (line 14)) (2.0.0)\n", - "Requirement already satisfied: pluggy<2,>=1.5 in /opt/conda/lib/python3.12/site-packages (from pytest->-r ../requirements.txt (line 14)) (1.5.0)\n", - "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /opt/conda/lib/python3.12/site-packages (from fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (3.11.12)\n", - "Requirement already satisfied: decorator in /opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (5.1.1)\n", - "Requirement already satisfied: jedi>=0.16 in 
/opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.19.2)\n", - "Requirement already satisfied: matplotlib-inline in /opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.1.7)\n", - "Requirement already satisfied: pexpect>4.3 in /opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (4.9.0)\n", - "Requirement already satisfied: prompt_toolkit<3.1.0,>=3.0.41 in /opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (3.0.50)\n", - "Requirement already satisfied: pygments>=2.4.0 in /opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (2.19.1)\n", - "Requirement already satisfied: stack_data in /opt/conda/lib/python3.12/site-packages (from ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.6.3)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.12/site-packages (from matplotlib->captum->-r ../requirements.txt (line 9)) (1.3.1)\n", - "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.12/site-packages (from matplotlib->captum->-r ../requirements.txt (line 9)) (0.12.1)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.12/site-packages (from matplotlib->captum->-r ../requirements.txt (line 9)) (4.56.0)\n", - "Requirement already satisfied: kiwisolver>=1.3.1 in /opt/conda/lib/python3.12/site-packages (from matplotlib->captum->-r ../requirements.txt (line 9)) (1.4.8)\n", - "Requirement already satisfied: pillow>=8 in /opt/conda/lib/python3.12/site-packages (from matplotlib->captum->-r ../requirements.txt (line 9)) (11.0.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.12/site-packages (from matplotlib->captum->-r ../requirements.txt (line 9)) (3.2.1)\n", - "Requirement 
already satisfied: six>=1.5 in /opt/conda/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas->-r ../requirements.txt (line 4)) (1.17.0)\n", - "Requirement already satisfied: distlib<1,>=0.3.7 in /opt/conda/lib/python3.12/site-packages (from virtualenv>=20.10.0->pre-commit->-r ../requirements.txt (line 13)) (0.3.9)\n", - "Requirement already satisfied: platformdirs<5,>=3.9.1 in /opt/conda/lib/python3.12/site-packages (from virtualenv>=20.10.0->pre-commit->-r ../requirements.txt (line 13)) (4.3.6)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.12/site-packages (from jinja2->torch->-r ../requirements.txt (line 1)) (2.1.5)\n", - "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (2.4.6)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (1.3.2)\n", - "Requirement already satisfied: attrs>=17.3.0 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (25.1.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (1.5.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (6.1.0)\n", - "Requirement already satisfied: propcache>=0.2.0 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (0.2.1)\n", - "Requirement 
already satisfied: yarl<2.0,>=1.17.0 in /opt/conda/lib/python3.12/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (1.18.3)\n", - "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /opt/conda/lib/python3.12/site-packages (from jedi>=0.16->ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.8.4)\n", - "Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/lib/python3.12/site-packages (from pexpect>4.3->ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.7.0)\n", - "Requirement already satisfied: wcwidth in /opt/conda/lib/python3.12/site-packages (from prompt_toolkit<3.1.0,>=3.0.41->ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.2.13)\n", - "Requirement already satisfied: executing>=1.2.0 in /opt/conda/lib/python3.12/site-packages (from stack_data->ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (2.1.0)\n", - "Requirement already satisfied: asttokens>=2.1.0 in /opt/conda/lib/python3.12/site-packages (from stack_data->ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (3.0.0)\n", - "Requirement already satisfied: pure_eval in /opt/conda/lib/python3.12/site-packages (from stack_data->ipython>=6.1.0->ipywidgets->-r ../requirements.txt (line 10)) (0.2.3)\n", - "Requirement already satisfied: idna>=2.0 in /opt/conda/lib/python3.12/site-packages (from yarl<2.0,>=1.17.0->aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2022.5.0->pytorch_lightning->-r ../requirements.txt (line 2)) (3.10)\n" - ] - } - ], - "source": [ - "!pip install -r ../requirements.txt" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "import sys\n", - "\n", - "import numpy as np\n", - "import pandas as pd\n", - "import pyarrow.parquet as pq\n", - "import s3fs\n", - "from sklearn.model_selection import train_test_split\n", - "from sklearn.preprocessing import LabelEncoder\n" - ] - 
}, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Temporary (torchFastText in active development)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "sys.path.append(\"../\")\n", - "from torchFastText import torchFastText\n", - "from torchFastText.preprocess import clean_text_feature\n", - "from torchFastText.datasets import NGramTokenizer\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install torchFastText" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Some useful functions that will help us format our dataset" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "sys.path.append(\"notebooks/\")\n", - "from utils import categorize_surface, clean_and_tokenize_df, stratified_split_rare_labels, add_libelles" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Load and preprocess data" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Exemple d’utilisation de la librairie `TorchFastText`\n", + "\n", + "*Warning*\n", + "\n", + "*`TorchFastText` library is still under active development. Have a\n", + "regular look to for\n", + "latest information.*\n", + "\n", + "To install package, you can run the following snippet" + ], + "id": "a01b1526-51df-4bf9-9fd4-11ef22ffcc79" + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-02-26 10:34:04 - botocore.httpchecksum - Skipping checksum validation. Response did not contain one of the following algorithms: ['crc32', 'sha1', 'sha256'].\n", - "2025-02-26 10:34:04 - botocore.httpchecksum - Skipping checksum validation. 
Response did not contain one of the following algorithms: ['crc32', 'sha1', 'sha256'].\n", - "2025-02-26 10:34:04 - botocore.httpchecksum - Skipping checksum validation. Response did not contain one of the following algorithms: ['crc32', 'sha1', 'sha256'].\n", - "2025-02-26 10:34:04 - botocore.httpchecksum - Skipping checksum validation. Response did not contain one of the following algorithms: ['crc32', 'sha1', 'sha256'].\n", - "2025-02-26 10:34:05 - botocore.httpchecksum - Skipping checksum validation. Response did not contain one of the following algorithms: ['crc32', 'sha1', 'sha256'].\n" - ] - } - ], - "source": [ - "fs = s3fs.S3FileSystem(\n", - " client_kwargs={\"endpoint_url\": \"https://minio.lab.sspcloud.fr\"},\n", - " anon=True,\n", - ")\n", - "df = (\n", - " pq.ParquetDataset(\n", - " \"projet-ape/extractions/20241027_sirene4.parquet\",\n", - " filesystem=fs,\n", - " )\n", - " .read_pandas()\n", - " .to_pandas()\n", - ").sample(frac=0.001).fillna(np.nan)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# Stable version\n", + "pip install torchFastText \n", + "# Development version\n", + "# pip install !https://github.com/InseeFrLab/torch-fastText.git" + ], + "id": "a00a2856" + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-02-26 10:34:09 - botocore.httpchecksum - Skipping checksum validation. Response did not contain one of the following algorithms: ['crc32', 'sha1', 'sha256'].\n" - ] + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load and preprocess data\n", + "\n", + "In that guide, we propose to illustrate main package functionalities\n", + "using that `DataFrame`:" + ], + "id": "b292ea76-57a1-4d4e-9bde-dcc9656dc447" }, { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
codelibelle
00111ZCulture de céréales (à l'exception du riz), de...
10112ZCulture du riz
20113ZCulture de légumes, de melons, de racines et d...
30114ZCulture de la canne à sucre
40115ZCulture du tabac
.........
7279609ZAutres services personnels n.c.a.
7289700ZActivités des ménages en tant qu'employeurs de...
7299810ZActivités indifférenciées des ménages en tant ...
7309820ZActivités indifférenciées des ménages en tant ...
7319900ZActivités des organisations et organismes extr...
\n", - "

732 rows × 2 columns

\n", - "
" - ], - "text/plain": [ - " code libelle\n", - "0 0111Z Culture de céréales (à l'exception du riz), de...\n", - "1 0112Z Culture du riz\n", - "2 0113Z Culture de légumes, de melons, de racines et d...\n", - "3 0114Z Culture de la canne à sucre\n", - "4 0115Z Culture du tabac\n", - ".. ... ...\n", - "727 9609Z Autres services personnels n.c.a.\n", - "728 9700Z Activités des ménages en tant qu'employeurs de...\n", - "729 9810Z Activités indifférenciées des ménages en tant ...\n", - "730 9820Z Activités indifférenciées des ménages en tant ...\n", - "731 9900Z Activités des organisations et organismes extr...\n", - "\n", - "[732 rows x 2 columns]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "with fs.open(\"projet-ape/data/naf2008.csv\") as file:\n", - " naf2008 = pd.read_csv(file, sep=\";\")\n", - "naf2008" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "df = pd.read_parquet(\"https://minio.lab.sspcloud.fr/projet-ape/extractions/20241027_sirene4.parquet\")\n", + "df = df.sample(10000)" + ], + "id": "37c042fe" + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "\t*** 732 codes have been added in the database...\n", - "\n" - ] - } - ], - "source": [ - "categorical_features = [\"evenement_type\", \"cj\", \"activ_nat_et\", \"liasse_type\", \"activ_surf_et\", \"activ_perm_et\"]\n", - "text_feature = \"libelle\"\n", - "y = \"apet_finale\"\n", - "textual_features = None\n", - "\n", - "df = add_libelles(df, naf2008, y, text_feature, textual_features, categorical_features)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Preprocess text and target" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We make available our processing function clean_text_feature for the text." 
- ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "df[\"libelle_processed\"] = clean_text_feature(df[\"libelle\"])" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - "encoder = LabelEncoder()\n", - "df[\"apet_finale\"] = encoder.fit_transform(df[\"apet_finale\"])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Put the columns in the right format:\n", - " - First column contains the processed text (str)\n", - " - Next ones contain the \"encoded\" categorical (discrete) variables in int format\n", - "\n", - "X and y are arrays." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Our goal will be to build multilabel classification for the `code`\n", + "variable using `libelle` as feature.\n", + "\n", + "## Enriching our test dataset\n", + "\n", + "Unlike `Fasttext`, this package offers the possibility of having several\n", + "feature columns of different types (string for the text column and\n", + "additional variables in numeric form, for example). To illustrate that,\n", + "we propose the following enrichment of the example dataset:" + ], + "id": "c399b4b0-a9cb-450e-9a5e-480e0e657b8e" + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/onyxia/work/torch-fastText/notebooks/utils.py:60: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. 
Value 'nan' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n", - " df.fillna(\"nan\", inplace=True)\n" - ] - } - ], - "source": [ - "df, _ = clean_and_tokenize_df(df, text_feature=\"libelle_processed\")\n", - "X = df[[\"libelle_processed\", \"EVT\", \"CJ\", \"NAT\", \"TYP\", \"CRT\", \"SRF\"]].values\n", - "y = df[\"apet_finale\"].values\n" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "def categorize_surface(\n", + " df: pd.DataFrame, surface_feature_name: int, like_sirene_3: bool = True\n", + ") -> pd.DataFrame:\n", + " \"\"\"\n", + " Categorize the surface of the activity.\n", + "\n", + " Args:\n", + " df (pd.DataFrame): DataFrame to categorize.\n", + " surface_feature_name (str): Name of the surface feature.\n", + " like_sirene_3 (bool): If True, categorize like Sirene 3.\n", + "\n", + " Returns:\n", + " pd.DataFrame: DataFrame with a new column \"surf_cat\".\n", + " \"\"\"\n", + " df_copy = df.copy()\n", + " df_copy[surface_feature_name] = df_copy[surface_feature_name].replace(\"nan\", np.nan)\n", + " df_copy[surface_feature_name] = df_copy[surface_feature_name].astype(float)\n", + " # Check surface feature exists\n", + " if surface_feature_name not in df.columns:\n", + " raise ValueError(f\"Surface feature {surface_feature_name} not found in DataFrame.\")\n", + " # Check surface feature is a float variable\n", + " if not (pd.api.types.is_float_dtype(df_copy[surface_feature_name])):\n", + " raise ValueError(f\"Surface feature {surface_feature_name} must be a float variable.\")\n", + "\n", + " if like_sirene_3:\n", + " # Categorize the surface\n", + " df_copy[\"surf_cat\"] = pd.cut(\n", + " 
df_copy[surface_feature_name],\n", + " bins=[0, 120, 400, 2500, np.inf],\n", + " labels=[\"1\", \"2\", \"3\", \"4\"],\n", + " ).astype(str)\n", + " else:\n", + " # Log transform the surface\n", + " df_copy[\"surf_log\"] = np.log(df[surface_feature_name])\n", + "\n", + " # Categorize the surface\n", + " df_copy[\"surf_cat\"] = pd.cut(\n", + " df_copy.surf_log,\n", + " bins=[0, 3, 4, 5, 12],\n", + " labels=[\"1\", \"2\", \"3\", \"4\"],\n", + " ).astype(str)\n", + "\n", + " df_copy[surface_feature_name] = df_copy[\"surf_cat\"].replace(\"nan\", \"0\")\n", + " df_copy[surface_feature_name] = df_copy[surface_feature_name].astype(int)\n", + " df_copy = df_copy.drop(columns=[\"surf_log\", \"surf_cat\"], errors=\"ignore\")\n", + " return df_copy\n", + "\n", + "\n", + "def clean_and_tokenize_df(\n", + " df,\n", + " categorical_features=[\"EVT\", \"CJ\", \"NAT\", \"TYP\", \"CRT\"],\n", + " text_feature=\"libelle_processed\",\n", + " label_col=\"apet_finale\",\n", + "):\n", + " df.fillna(\"nan\", inplace=True)\n", + "\n", + " df = df.rename(\n", + " columns={\n", + " \"evenement_type\": \"EVT\",\n", + " \"cj\": \"CJ\",\n", + " \"activ_nat_et\": \"NAT\",\n", + " \"liasse_type\": \"TYP\",\n", + " \"activ_surf_et\": \"SRF\",\n", + " \"activ_perm_et\": \"CRT\",\n", + " }\n", + " )\n", + "\n", + " les = []\n", + " for col in categorical_features:\n", + " le = LabelEncoder()\n", + " df[col] = le.fit_transform(df[col])\n", + " les.append(le)\n", + "\n", + " df = categorize_surface(df, \"SRF\", like_sirene_3=True)\n", + " df = df[[text_feature, \"EVT\", \"CJ\", \"NAT\", \"TYP\", \"SRF\", \"CRT\", label_col]]\n", + "\n", + " return df, les\n", + "\n", + "\n", + "def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1):\n", + " # Get unique labels and their frequencies\n", + " unique_labels, label_counts = np.unique(y, return_counts=True)\n", + "\n", + " # Separate rare and common labels\n", + " rare_labels = unique_labels[label_counts == 1]\n", + "\n", + " # Create 
initial mask for rare labels to go into training set\n", + " rare_label_mask = np.isin(y, rare_labels)\n", + "\n", + " # Separate data into rare and common label datasets\n", + " X_rare = X[rare_label_mask]\n", + " y_rare = y[rare_label_mask]\n", + " X_common = X[~rare_label_mask]\n", + " y_common = y[~rare_label_mask]\n", + "\n", + " # Split common labels stratified\n", + " X_common_train, X_common_test, y_common_train, y_common_test = train_test_split(\n", + " X_common, y_common, test_size=test_size, stratify=y_common\n", + " )\n", + "\n", + " # Combine rare labels with common labels split\n", + " X_train = np.concatenate([X_rare, X_common_train])\n", + " y_train = np.concatenate([y_rare, y_common_train])\n", + " X_test = X_common_test\n", + " y_test = y_common_test\n", + "\n", + " return X_train, X_test, y_train, y_test\n", + "\n", + "def add_libelles(\n", + " df: pd.DataFrame,\n", + " df_naf: pd.DataFrame,\n", + " y: str,\n", + " text_feature: str,\n", + " textual_features: list,\n", + " categorical_features: list,\n", + "):\n", + " missing_codes = set(df_naf[\"code\"])\n", + " fake_obs = df_naf[df_naf[\"code\"].isin(missing_codes)]\n", + " fake_obs[y] = fake_obs[\"code\"]\n", + " fake_obs[text_feature] = fake_obs[[text_feature]].apply(\n", + " lambda row: \" \".join(f\"[{col}] {val}\" for col, val in row.items() if val != \"\"), axis=1\n", + " )\n", + " df = pd.concat([df, fake_obs[[col for col in fake_obs.columns if col in df.columns]]])\n", + "\n", + " if textual_features is not None:\n", + " for feature in textual_features:\n", + " df[feature] = df[feature].fillna(value=\"\")\n", + " if categorical_features is not None:\n", + " for feature in categorical_features:\n", + " df[feature] = df[feature].fillna(value=\"NaN\")\n", + "\n", + " print(f\"\\t*** {len(missing_codes)} codes have been added in the database...\\n\")\n", + " return df" + ], + "id": "92402df7" + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "Features for the 3 first 
obs:\n", - "\n", - "[['vent lign' 2 26 7 1 1 0]\n", - " ['nettoyag murs, murets, bat' 6 26 7 7 1 0]\n", - " [\"creation vent bijoux fantais fait main soin (colliers, bracelets, boucl d'oreil etc).\"\n", - " 2 26 0 7 1 0]]\n", - "\n", - "\n", - "NAF codes (labels) for the 3 first obs:\n", - "\n", - "[480 628 293]\n" - ] - } - ], - "source": [ - "print(\"Features for the 3 first obs:\\n\")\n", - "print(X[:3])\n", - "print(\"\\n\")\n", - "print(\"NAF codes (labels) for the 3 first obs:\\n\")\n", - "print(y[:3])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We split the data into train and test sets. We especially take care that: \n", - "- classes with only one instance appear in the train set (instead of the test set)\n", - "- all classes are represented in the train set\n", - "\n", - "The `stratified_split_rare_labels` function from the `preprocess` subpackage is used to carefully split the data." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "X_train, X_test, y_train, y_test = stratified_split_rare_labels(X, y)\n", - "assert set(range(len(naf2008[\"code\"]))) == set(np.unique(y_train))" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + " *** 732 codes have been added in the database...\n" + ] + } + ], + "source": [ + "categorical_features = [\"evenement_type\", \"cj\", \"activ_nat_et\", \"liasse_type\", \"activ_surf_et\", \"activ_perm_et\"]\n", + "text_feature = \"libelle\"\n", + "y = \"apet_finale\"\n", + "textual_features = None\n", + "\n", + "naf2008 = pd.read_csv(\"https://minio.lab.sspcloud.fr/projet-ape/data/naf2008.csv\", sep=\";\")\n", + "df = add_libelles(df, naf2008, y, text_feature, textual_features, categorical_features)" + ], + "id": "1fd02895" + }, { - "data": { - 
"text/plain": [ - "(2793, 7)" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X_train.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Build the torch-fastText model (without training it)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We first initialize the model (without building it)." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [], - "source": [ - "# Parameters for model building\n", - "NUM_TOKENS= int(1e5) # Number of rows in the embedding matrix = size of the embedded vocabulary\n", - "EMBED_DIM = 50 # Dimension of the embedding = number of columns in the embedding matrix\n", - "SPARSE = False # Whether to use sparse Embedding layer for fast computation (see PyTorch documentation)\n", - "CAT_EMBED_DIM = 10 # Dimension of the embedding for categorical features\n", - "\n", - "# Parameters for tokenizer\n", - "MIN_COUNT = 1 # Minimum number of occurrences of a word in the corpus to be included in the vocabulary\n", - "MIN_N = 3 # Minimum length of char n-grams\n", - "MAX_N = 6 # Maximum length of char n-grams\n", - "LEN_WORD_NGRAMS = 3 # Length of word n-grams\n", - "\n", - "# Parameters for training - not useful immediately\n", - "LR = 4e-3 # Learning rate\n", - "NUM_EPOCHS = 1\n", - "BATCH_SIZE = 256\n", - "PATIENCE = 3\n" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "model = torchFastText(\n", - " num_tokens=NUM_TOKENS,\n", - " embedding_dim=EMBED_DIM,\n", - " categorical_embedding_dims=CAT_EMBED_DIM,\n", - " min_count=MIN_COUNT,\n", - " min_n=MIN_N,\n", - " max_n=MAX_N,\n", - " len_word_ngrams=LEN_WORD_NGRAMS,\n", - " sparse = SPARSE\n", - ")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We can save these parameters to a JSON file. Initialization can also be done providing a JSON file path." 
- ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "model.to_json('torchFastText_config.json')" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "model = torchFastText.from_json('torchFastText_config.json')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "We build the model using the training data. \n", - "We have now access to the tokenizer, the PyTorch model as well as a PyTorch Lightning module ready to be trained.\n", - "Note that Lightning is high-level framework for PyTorch that simplifies the process of training, validating, and deploying machine learning models." - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocessing\n", + "\n", + "To reduce noise in text fields, we recommend pre-processing before\n", + "training a model with our package. We assume this preprocessing is\n", + "handled by the package user : this gives him the opportunity to control\n", + "data cleansing.\n", + "\n", + "Here’s an example of the type of preprocessing that can be carried out\n", + "before moving on to the modeling phase" + ], + "id": "67f4160d-0c98-4700-80f4-1ba454e6a2df" + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-02-26 10:34:11 - torchFastText.model.pytorch_model - num_rows is different from the number of tokens in the tokenizer. Using provided num_rows.\n", - "2025-02-26 10:34:11 - torchFastText.torchFastText - No scheduler parameters provided. 
Using default parameters (suited for ReduceLROnPlateau).\n" - ] - } - ], - "source": [ - "model.build(X_train, y_train, lightning=True, lr = LR)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "from torchFastText.preprocess import clean_text_feature\n", + "df[\"libelle_processed\"] = clean_text_feature(df[\"libelle\"])" + ], + "id": "61b0252e" + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "==== Model\n", - "\n", - "FastTextModel(\n", - " (embeddings): EmbeddingBag(103957, 50, mode='mean')\n", - " (emb_0): Embedding(24, 10)\n", - " (emb_1): Embedding(27, 10)\n", - " (emb_2): Embedding(9, 10)\n", - " (emb_3): Embedding(13, 10)\n", - " (emb_4): Embedding(3, 10)\n", - " (emb_5): Embedding(4, 10)\n", - " (fc): Linear(in_features=60, out_features=732, bias=True)\n", - ")\n", - "\n", - "==== Tokenizer\n", - "\n", - "\n", - "\n", - "==== Lightning Module\n", - "\n", - "FastTextModule(\n", - " (model): FastTextModel(\n", - " (embeddings): EmbeddingBag(103957, 50, mode='mean')\n", - " (emb_0): Embedding(24, 10)\n", - " (emb_1): Embedding(27, 10)\n", - " (emb_2): Embedding(9, 10)\n", - " (emb_3): Embedding(13, 10)\n", - " (emb_4): Embedding(3, 10)\n", - " (emb_5): Embedding(4, 10)\n", - " (fc): Linear(in_features=60, out_features=732, bias=True)\n", - " )\n", - " (loss): CrossEntropyLoss()\n", - " (accuracy_fn): MulticlassAccuracy()\n", - ")\n" - ] - } - ], - "source": [ - "print(\"==== Model\\n\")\n", - "print(model.pytorch_model)\n", - "print(\"\\n==== Tokenizer\\n\")\n", - "print(model.tokenizer)\n", - "print(\"\\n==== Lightning Module\\n\")\n", - "print(model.lightning_module)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This step is useful to initialize the full torchFastText model without training it, if needed for some reason. 
\n", - "\n", - "But if it is not necessary, we could have directly launched the training (building is then handled automatically)." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You can play with the tokenizer." - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Right now, the model requires the label (variable y) to be a numerical\n", + "variable. If the label variable is a text variable, we recommend using\n", + "Scikit Learn’s\n", + "[LabelEncoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html)\n", + "to convert into a numeric variable. Using that function will give user\n", + "the possibility to get back labels from the encoder after running\n", + "predictions." + ], + "id": "acde2929-fe92-4107-8066-a5c8ac5d6428" + }, { - "name": "stdout", - "output_type": "stream", - "text": [ - "{0: '',\n", - " 4137: '',\n", - " 11086: 'lorem>',\n", - " 13334: 'ame',\n", - " 14893: '',\n", - " 17320: '',\n", - " 29346: 'or>',\n", - " 33345: '',\n", - " 40359: '',\n", - " 41703: 'sit>',\n", - " 41836: 'ipsu',\n", - " 44743: 'psu',\n", - " 44896: 'orem>',\n", - " 45751: '',\n", - " 53310: 'it>',\n", - " 53955: 'olor>',\n", - " 56480: 'lor>',\n", - " 56487: 'ore',\n", - " 58774: 'sum>',\n", - " 61437: 'met',\n", - " 61524: '',\n", - " 63743: 'olor',\n", - " 63950: 'orem',\n", - " 64494: 'psum',\n", - " 65147: 'ipsum dolor',\n", - " 65285: '',\n", - " 66014: 'lorem ipsum',\n", - " 67137: 'dolor sit amet',\n", - " 68123: 'rem',\n", - " 69783: 'ipsum>',\n", - " 72558: 'lor',\n", - " 73559: '',\n", - " 75729: 'dolor sit',\n", - " 78383: 'rem>',\n", - " 83592: '',\n", - " 88077: 'sit amet',\n", - " 88736: 'psum>',\n", - " 88774: '\n", - "\n", - "==== Lightning Module\n", - "\n", - "FastTextModule(\n", - " (model): FastTextModel(\n", - " (embeddings): EmbeddingBag(103957, 50, mode='mean')\n", - " (emb_0): 
Embedding(24, 10)\n", - " (emb_1): Embedding(27, 10)\n", - " (emb_2): Embedding(9, 10)\n", - " (emb_3): Embedding(13, 10)\n", - " (emb_4): Embedding(3, 10)\n", - " (emb_5): Embedding(4, 10)\n", - " (fc): Linear(in_features=60, out_features=732, bias=True)\n", - " )\n", - " (loss): CrossEntropyLoss()\n", - " (accuracy_fn): MulticlassAccuracy()\n", - ")\n" - ] - } - ], - "source": [ - "print(\"==== Model\\n\")\n", - "print(model.pytorch_model)\n", - "print(\"\\n==== Tokenizer\\n\")\n", - "print(model.tokenizer)\n", - "print(\"\\n==== Lightning Module\\n\")\n", - "print(model.lightning_module)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "If the PyTorch model building did not use the training data, please keep in mind that its architecture (that you customize here) should match the vocabulary size of the categorical variables and the total number of class, otherwise the model will raise an error during training." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Train a torchFastText model" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Splitting in train-test sets\n", + "\n", + "As usual in a learning approach, you need to break down your data into\n", + "learning and test/validation samples to obtain robust performance\n", + "statistics.\n", + "\n", + "This work is the responsibility of the package’s users. Here’s an\n", + "example of how to do it, using the\n", + "[`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)\n", + "function in `Scikit`." + ], + "id": "e70de831-dbc9-49be-b0c4-d70dd6479d03" + }, { - "name": "stderr", - "output_type": "stream", - "text": [ - "2025-02-26 10:34:12 - torchFastText.torchFastText - Checking inputs...\n", - "2025-02-26 10:34:12 - torchFastText.torchFastText - Inputs successfully checked. 
Starting the training process..\n", - "2025-02-26 10:34:12 - torchFastText.torchFastText - Running on: cpu\n", - "2025-02-26 10:34:12 - torchFastText.torchFastText - Lightning module successfully created.\n", - "GPU available: False, used: False\n", - "TPU available: False, using: 0 TPU cores\n", - "HPU available: False, using: 0 HPUs\n", - "/opt/conda/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default\n", - "2025-02-26 10:34:12 - torchFastText.torchFastText - Launching training...\n", - "\n", - " | Name | Type | Params | Mode \n", - "-----------------------------------------------------------\n", - "0 | model | FastTextModel | 5.2 M | train\n", - "1 | loss | CrossEntropyLoss | 0 | train\n", - "2 | accuracy_fn | MulticlassAccuracy | 0 | train\n", - "-----------------------------------------------------------\n", - "5.2 M Trainable params\n", - "0 Non-trainable params\n", - "5.2 M Total params\n", - "20.973 Total estimated model params size (MB)\n", - "11 Modules in train mode\n", - "0 Modules in eval mode\n" - ] + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from sklearn.model_selection import train_test_split \n", + "X_train, X_test, y_train, y_test = train_test_split(X, y)" + ], + "id": "b593fd75" }, { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "73a0fb9624054797a6ee55f0b804ce71", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Sanity Checking: | | 0/? 
[00:00" - ] - }, - "metadata": {}, - "output_type": "display_data" + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "metadata": {}, + "data": { + "text/plain": [ + "" + ] + } + } + ], + "source": [ + "model.tokenizer" + ], + "id": "d983b113" + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "metadata": {}, + "data": { + "text/plain": [ + "FastTextModule(\n", + " (model): FastTextModel(\n", + " (embeddings): EmbeddingBag(107992, 50, mode='mean', padding_idx=107991)\n", + " (emb_0): Embedding(24, 10)\n", + " (emb_1): Embedding(40, 10)\n", + " (emb_2): Embedding(8, 10)\n", + " (emb_3): Embedding(13, 10)\n", + " (emb_4): Embedding(3, 10)\n", + " (emb_5): Embedding(4, 10)\n", + " (fc): Linear(in_features=60, out_features=646, bias=True)\n", + " )\n", + " (loss): CrossEntropyLoss()\n", + " (accuracy_fn): MulticlassAccuracy()\n", + ")" + ] + } + } + ], + "source": [ + "model.lightning_module" + ], + "id": "9b23f1ba" }, { - "data": { - "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAbUAAAGsCAYAAABaczmOAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJUdJREFUeJzt3X10VOWBx/HfTJJJAJnEhJAhkAAqR0KhpA0GguyGI1lTV7tF4y5NURCzslreJCovykvV9sTiIi9FNge3lXKEowdrPcoqbhorWokBEqzlxay1YFKSCaGYGV6TkLn7h4fRaUKYhAxJnnw/59yj3Dz33ufmjvP1zkyCzbIsSwAAGMDe1RMAAKCzEDUAgDGIGgDAGEQNAGAMogYAMAZRAwAYg6gBAIwR3tUT6Ao+n0/V1dXq37+/bDZbV08HANAGy7J06tQpJSYmym5v+16sV0aturpaSUlJXT0NAEA7VFVVaciQIW2O6ZVR69+/v6SvvkFOp7OLZwMAaIvX61VSUpL/ubstvTJqF19ydDqdRA0Aeohg3i7igyIAAGMQNQCAMYgaAMAYvfI9NQAIRnNzs5qamrp6Gr2Cw+G47Mf1g0HUAODvWJYlt9ut+vr6rp5Kr2G32zV8+HA5HI4r2g9RA4C/czFoAwcOVN++ffklDSF28Rdi1NTUKDk5+Yq+30QNAL6hubnZH7S4uLiunk6vER8fr+rqal24cEEREREd3g8fFAGAb7j4Hlrfvn27eCa9y8WXHZubm69oP0QNAFrBS45XV2d9v4kaAMAYRA0AYAw+KAIAQUhf9tJVPd6en95zVY9nCu7UAMAAw4YNk81ma7HMmTNHknT+/HnNmTNHcXFxuuaaa5STk6Pa2tqAfcyfP19paWmKjIxUampqq8f55JNP9A//8A+KiopSUlKSVq1a5f/a0aNHW53DxWX48OEhO/+LiBoAGGDv3r2qqanxL0VFRZKkf/3Xf5UkLVy4UG+++aa2b9+uXbt2qbq6WnfddVeL/dx///2aNm1aq8fwer269dZbNXToUJWVlenZZ5/VT37yE23atEmSlJSUFDCHi8ubb76psLAwf2BDiZcfAcAA8fHxAX9+5plndP311yszM1Mej0e//OUvtW3bNt1yyy2SpBdffFEpKSn66KOPNGHCBEnS+vXrJUl1dXX65JNPWhxj69atamxs1K9+9Ss5HA5961vf0scff6znnntOs2fPVlhYmFwuV8A2tbW1euihh5Sbm6tHH300FKcegDs1ADBMY2OjXnrpJd1///2y2WwqKytTU1OTsrKy/GNGjhyp5ORklZSUBL3fkpIS/eM//mPAr7LKzs5WRUWFvvzyyxbjm5qalJOTI5fLpRdeeOHKTipI3KkBgGFef/111dfX67777pP01a/9cjgciomJCRiXkJAgt9sd9H7dbneL98USEhL8X7v22msDvjZ37lx9/vnn2rt3r6Kiotp/Ih1A1ADAML/85S912223KTExscvmUFhYqM2bN+v3v/+9hgwZctWOS9QAwCBffPGFfve73+m1117zr3O5XGpsbFR9fX3A3VptbW2L98Da4nK5Wnxi8uKfv7mfP/zhD5o/f742btyoiRMndvBMOob31ADAIC+++KIGDhyo22+/3b8uLS1NERERKi4u9q+rqKhQZWWlMjIygt53RkaG3n///YC/Y66oqEg33nij/6XHqqoq5eTkaPbs2fr3f//3Tjij9uFODQAM4fP59OKLL2rmzJkKD//66T06Olp5eXnKz89XbGysnE6n5s2bp4yMDP8nHyXpz3/+s06fPi23261z587p448/liSNGjVKDodDP/rRj/Tkk08qLy9Pixcv1oEDB7Ru3TqtWbNG0lc/C3fnnXdq8ODBWrJkSavv17XnzrBDrF7I4/FYkiyPx9PVUwHQzZw7d846dOiQde7cua6eSru98847liSroqKixdfOnTtn/fjHP7auvfZaq2/fvtadd95p1dTUBIzJzMy
0JLVYjhw54h/zxz/+0Zo0aZIVGRlpDR482HrmmWf8X3vvvfda3f6by6W09X1vz3O2zbIsK7TZ7H68Xq+io6Pl8XjkdDq7ejoAupHz58/ryJEjGj58+FX7xB7a/r635zmb99QAAMYgagAAYxA1AIAxiBoAwBhEDQBa4fP5unoKvUpnfWaRn1MDgG9wOByy2+2qrq5WfHy8HA6HbDZbV0/LaJZlqa6uTjabTREREVe0L6IGAN9gt9s1fPhw1dTUqLq6uqun02vYbDYNGTJEYWFhV7QfogYAf8fhcCg5OVkXLlxQc3NzV0+nV4iIiLjioElEDQBadfGlsCt9OQxXFx8UAQAYg6gBAIxB1AAAxiBqAABjEDUAgDGIGgDAGEQNAGAMogYAMAZRAwAYg6gBAIxB1AAAxiBqAABjEDUAgDGIGgDAGCGP2vPPP69hw4YpKipK48eP1549e9ocv337do0cOVJRUVEaM2aM3nrrrUuOffDBB2Wz2bR27dpOnjUAoCcKadReeeUV5efna+XKlSovL9fYsWOVnZ2t48ePtzp+9+7dys3NVV5envbv36+pU6dq6tSpOnDgQIuxv/3tb/XRRx8pMTExlKcAAOhBQhq15557Tg888IBmzZqlUaNGqbCwUH379tWvfvWrVsevW7dO3/ve9/TYY48pJSVFTz/9tL773e9qw4YNAeOOHTumefPmaevWrfwFfgAAv5BFrbGxUWVlZcrKyvr6YHa7srKyVFJS0uo2JSUlAeMlKTs7O2C8z+fTvffeq8cee0zf+ta3gppLQ0ODvF5vwAIAME/IonbixAk1NzcrISEhYH1CQoLcbner27jd7suO//nPf67w8HDNnz8/6LkUFBQoOjravyQlJbXjTAAAPUWP+vRjWVmZ1q1bp82bN8tmswW93dKlS+XxePxLVVVVCGcJAOgqIYvagAEDFBYWptra2oD1tbW1crlcrW7jcrnaHP/BBx/o+PHjSk5OVnh4uMLDw/XFF1/okUce0bBhwy45l8jISDmdzoAFAGCekEXN4XAoLS1NxcXF/nU+n0/FxcXKyMhodZuMjIyA8ZJUVFTkH3/vvffqk08+0ccff+xfEhMT9dhjj+mdd94J1akAAHqI8FDuPD8/XzNnztS4ceOUnp6utWvX6syZM5o1a5YkacaMGRo8eLAKCgokSQsWLFBmZqZWr16t22+/XS+//LL27dunTZs2SZLi4uIUFxcXcIyIiAi5XC7deOONoTwVAEAPENKoTZs2TXV1dVqxYoXcbrdSU1O1c+dO/4dBKisrZbd/fbM4ceJEbdu2TcuWLdPjjz+uESNG6PXXX9fo0aNDOU0AgCFslmVZXT2Jq83r9So6Oloej4f31wCgm2vPc3aP+vQjAABtIWoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMUIeteeff17Dhg1TVFSUxo8frz179rQ5fvv27Ro5cqSioqI0ZswYvfXWW/6vNTU1afHixRozZoz69eunxMREzZgxQ9XV1aE+DQBADxDSqL3yyivKz8/XypUrVV5errFjxyo7O1vHjx9vdfzu3buVm5urvLw87d+/X1OnTtXUqVN14MABSdLZs2dVXl6u5cuXq7y8XK+99poqKir0L//yL6E8DQBAD2GzLMsK1c7Hjx+
vm266SRs2bJAk+Xw+JSUlad68eVqyZEmL8dOmTdOZM2e0Y8cO/7oJEyYoNTVVhYWFrR5j7969Sk9P1xdffKHk5OSg5uX1ehUdHS2PxyOn09mBMwMAXC3tec4O2Z1aY2OjysrKlJWV9fXB7HZlZWWppKSk1W1KSkoCxktSdnb2JcdLksfjkc1mU0xMzCXHNDQ0yOv1BiwAAPOELGonTpxQc3OzEhISAtYnJCTI7Xa3uo3b7W7X+PPnz2vx4sXKzc1ts94FBQWKjo72L0lJSe08GwBAT9BjP/3Y1NSkf/u3f5NlWfqv//qvNscuXbpUHo/Hv1RVVV2lWQIArqbwUO14wIABCgsLU21tbcD62tpauVyuVrdxuVxBjb8YtC+++ELvvvvuZV9jjYyMVGRkZAfOAgDQk4TsTs3hcCgtLU3FxcX+dT6fT8XFxcrIyGh1m4yMjIDxklRUVBQw/mLQPvvsM/3ud79TXFxcaE4AANDjhOxOTZLy8/M1c+ZMjRs3Tunp6Vq7dq3OnDmjWbNmSZJmzJihwYMHq6CgQJK0YMECZWZmavXq1br99tv18ssva9++fdq0aZOkr4J29913q7y8XDt27FBzc7P//bbY2Fg5HI5Qng4AoJsLadSmTZumuro6rVixQm63W6mpqdq5c6f/wyCVlZWy27++WZw4caK2bdumZcuW6fHHH9eIESP0+uuva/To0ZKkY8eO6Y033pAkpaamBhzr97//vSZPnhzK0wEAdHMh/Tm17oqfUwOAnqNb/JwaAABXG1EDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjBHyqD3//PMaNmyYoqKiNH78eO3Zs6fN8du3b9fIkSMVFRWlMWPG6K233gr4umVZWrFihQYNGqQ+ffooKytLn332WShPAQDQQ4Q0aq+88ory8/O1cuVKlZeXa+zYscrOztbx48dbHb97927l5uYqLy9P+/fv19SpUzV16lQdOHDAP2bVqlVav369CgsLVVpaqn79+ik7O1vnz58P5akAAHoAm2VZVqh2Pn78eN10003asGGDJMnn8ykpKUnz5s3TkiVLWoyfNm2azpw5ox07dvjXTZgwQampqSosLJRlWUpMTNQjjzyiRx99VJLk8XiUkJCgzZs364c//GFQ8/J6vYqOjpbH45HT6eyEMwUAhEp7nrNDdqfW2NiosrIyZWVlfX0wu11ZWVkqKSlpdZuSkpKA8ZKUnZ3tH3/kyBG53e6AMdHR0Ro/fvwl9ylJDQ0N8nq9AQsAwDwhi9qJEyfU3NyshISEgPUJCQlyu92tbuN2u9scf/Gf7dmnJBUUFCg6Otq/JCUltft8AADdX6/49OPSpUvl8Xj8S1VVVVdPCQAQAiGL2oABAxQWFqba2tqA9bW1tXK5XK1u43K52hx/8Z/t2ackRUZGyul0BiwAAPOELGoOh0NpaWkqLi72r/P5fCouLlZGRkar22RkZASMl6SioiL/+OHDh8vlcgWM8Xq9Ki0tveQ+AQC9R3god56fn6+ZM2dq3LhxSk9P19q1a3XmzBnNmjVLkjRjxgwNHjxYBQUFkqQFCxYoMzNTq1ev1u23366XX35Z+/bt06ZNmyRJNptNDz/8sH76059qxIgRGj58uJYvX67ExERNnTo1lKcCAOgBQhq1adO
mqa6uTitWrJDb7VZqaqp27tzp/6BHZWWl7PavbxYnTpyobdu2admyZXr88cc1YsQIvf766xo9erR/zKJFi3TmzBnNnj1b9fX1mjRpknbu3KmoqKhQngoAoAcI6c+pdVf8nBoA9Bzd4ufUAAC42ogaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYAyiBgAwBlEDABiDqAEAjEHUAADGIGoAAGMQNQCAMYgaAMAYRA0AYIyQRe3kyZOaPn26nE6nYmJilJeXp9OnT7e5zfnz5zVnzhzFxcXpmmuuUU5Ojmpra/1f/+Mf/6jc3FwlJSWpT58+SklJ0bp160J1CgCAHiZkUZs+fboOHjyooqIi7dixQ++//75mz57d5jYLFy7Um2++qe3bt2vXrl2qrq7WXXfd5f96WVmZBg4cqJdeekkHDx7UE088oaVLl2rDhg2hOg0AQA9isyzL6uydHj58WKNGjdLevXs1btw4SdLOnTv1z//8z/rrX/+qxMTEFtt4PB7Fx8dr27ZtuvvuuyVJn376qVJSUlRSUqIJEya0eqw5c+bo8OHDevfdd4Oen9frVXR0tDwej5xOZwfOEABwtbTnOTskd2olJSWKiYnxB02SsrKyZLfbVVpa2uo2ZWVlampqUlZWln/dyJEjlZycrJKSkksey+PxKDY2ts35NDQ0yOv1BiwAAPOEJGput1sDBw4MWBceHq7Y2Fi53e5LbuNwOBQTExOwPiEh4ZLb7N69W6+88splX9YsKChQdHS0f0lKSgr+ZAAAPUa7orZkyRLZbLY2l08//TRUcw1w4MAB/eAHP9DKlSt16623tjl26dKl8ng8/qWqquqqzBEAcHWFt2fwI488ovvuu6/NMdddd51cLpeOHz8esP7ChQs6efKkXC5Xq9u5XC41Njaqvr4+4G6ttra2xTaHDh3SlClTNHv2bC1btuyy846MjFRkZORlxwEAerZ2RS0+Pl7x8fGXHZeRkaH6+nqVlZUpLS1NkvTuu+/K5/Np/PjxrW6TlpamiIgIFRcXKycnR5JUUVGhyspKZWRk+McdPHhQt9xyi2bOnKmf/exn7Zk+AMBwIfn0oyTddtttqq2tVWFhoZqamjRr1iyNGzdO27ZtkyQdO3ZMU6ZM0ZYtW5Seni5Jeuihh/TWW29p8+bNcjqdmjdvnqSv3juTvnrJ8ZZbblF2draeffZZ/7HCwsKCiu1FfPoRAHqO9jxnt+tOrT22bt2quXPnasqUKbLb7crJydH69ev9X29qalJFRYXOnj3rX7dmzRr/2IaGBmVnZ2vjxo3+r7/66quqq6vTSy+9pJdeesm/fujQoTp69GioTgUA0EOE7E6tO+NODQB6ji7/OTUAALoCUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIA
xiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMEbKonTx5UtOnT5fT6VRMTIzy8vJ0+vTpNrc5f/685syZo7i4OF1zzTXKyclRbW1tq2P/9re/aciQIbLZbKqvrw/BGQAAepqQRW369Ok6ePCgioqKtGPHDr3//vuaPXt2m9ssXLhQb775prZv365du3apurpad911V6tj8/Ly9O1vfzsUUwcA9FRWCBw6dMiSZO3du9e/7u2337ZsNpt17NixVrepr6+3IiIirO3bt/vXHT582JJklZSUBIzduHGjlZmZaRUXF1uSrC+//LJd8/N4PJYky+PxtGs7AMDV157n7JDcqZWUlCgmJkbjxo3zr8vKypLdbldpaWmr25SVlampqUlZWVn+dSNHjlRycrJKSkr86w4dOqSnnnpKW7Zskd0e3PQbGhrk9XoDFgCAeUISNbfbrYEDBwasCw8PV2xsrNxu9yW3cTgciomJCVifkJDg36ahoUG5ubl69tlnlZycHPR8CgoKFB0d7V+SkpLad0IAgB6hXVFbsmSJbDZbm8unn34aqrlq6dKlSklJ0T333NPu7Twej3+pqqoK0QwBAF0pvD2DH3nkEd13331tjrnuuuvkcrl0/PjxgPUXLlzQyZMn5XK5Wt3O5XKpsbFR9fX1AXdrtbW1/m3effdd/elPf9Krr74qSbIsS5I0YMAAPfHEE3ryySdb3XdkZKQiIyODOUUAQA/WrqjFx8crPj7+suMyMjJUX1+vsrIypaWlSfoqSD6fT+PHj291m7S0NEVERKi4uFg5OTmSpIqKClVWViojI0OS9Jvf/Ebnzp3zb7N3717df//9+uCDD3T99de351QAAAZqV9SClZKSou9973t64IEHVFhYqKamJs2dO1c//OEPlZiYKEk6duyYpkyZoi1btig9PV3R0dHKy8tTfn6+YmNj5XQ6NW/ePGVkZGjChAmS1CJcJ06c8B/v79+LAwD0PiGJmiRt3bpVc+fO1ZQpU2S325WTk6P169f7v97U1KSKigqdPXvWv27NmjX+sQ0NDcrOztbGjRtDNUUAgGFs1sU3pnoRr9er6OhoeTweOZ3Orp4OAKAN7XnO5nc/AgCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxiBoAwBhEDQBgDKIGADAGUQMAGIOoAQCMQdQAAMYgagAAYxA1AIAxwrt6Al3BsixJktfr7eKZAAAu5+Jz9cXn7rb0yqidOnVKkpSUlNTFMwEABOvUqVOKjo5uc4zNCiZ9hvH5fKqurlb//v1ls9muaF9er1dJSUmqqqqS0+m84nGh2GdXjWOOnEt3OnZvPJeeMsfLsSxLp06dUmJiouz2tt8165V3ana7XUOGDOnUfTqdzqAuXLDjQrHPrhrHHK/uOOZ4dccxx86bY1sud4d2ER8UAQAYg6gBAIxB1K5QZGSkVq5cqcjIyE4ZF4p9dtU45si5dKdj98Zz6Slz7Ey98oMiAAAzcacGADAGUQMAGIOoAQCMQdQA4O+8/fbb2rJlS1dPAx1A1LqhyZMn6+GHH77sOMuyNHv2bMXGxspms+njjz++onEmCPZ7Z/ocuoIp5/3Xv/5VDz/8sNauXasPP/ywU/ZpyvdG6v7n0it/o0h399prrykiIuKy43bu3KnNmzfrvffe03XXXacBAwZc0TjgSgT7uO3ufvzjH2vjxo1KSkrSj370I/3hD39QVFRUi3GTJ09Wamqq1q5de9l9mvK9kbr/uRC1big2NjaocZ9//rkGDRqkiRMndsq4UGtsbJTD4ejSOaDjLnf9gn3cdnd
vvPGG/9/37dvXKfvsad+btq51dz8XXn68Aj6fT6tWrdINN9ygyMhIJScn62c/+1mLcQ0NDZo/f74GDhyoqKgoTZo0SXv37r3kfoO5vb/vvvs0b948VVZWymazadiwYVc0zufzqaCgQMOHD1efPn00duxYvfrqqwFjhg0b1uL/SlNTU/WTn/zkkucxd+5cPfzwwxowYICys7PbPKfOcuHCBc2dO1fR0dEaMGCAli9f3upfWRHsdWnveQdj586dmjRpkmJiYhQXF6c77rhDn3/+ecCYyZMna/78+Vq0aJFiY2PlcrkuecxTp05p+vTp6tevnwYNGqQ1a9a0+jgK5jp/8/jtuX6Xe9y259id6eJ5BPOYCOa6SF/9d7Vr1y6tW7dONptNNptNR48ebXMOl/tv+tVXX9WYMWPUp08fxcXFKSsrS2fOnOnwuFBd62DOpauutUTUrsjSpUv1zDPPaPny5Tp06JC2bdumhISEFuMWLVqk3/zmN/r1r3+t8vJy3XDDDcrOztbJkyc7fOx169bpqaee0pAhQ1RTU3PJSAY7rqCgQFu2bFFhYaEOHjyohQsX6p577tGuXbs6PEdJ+vWvfy2Hw6EPP/xQhYWFV7Sv9hwzPDxce/bs0bp16/Tcc8/pv//7v1uMC8V1CdaZM2eUn5+vffv2qbi4WHa7XXfeead8Pl+Lc+nXr59KS0u1atUqPfXUUyoqKmqxv/z8fH344Yd64403VFRUpA8++EDl5eUtxrX3Onfm9QvVYywYwT4mgr0u69atU0ZGhh544AHV1NSopqbmiv4qq5qaGuXm5ur+++/X4cOH9d577+muu+5qEd5gx0m991rLQod4vV4rMjLSeuGFF9ocd/r0aSsiIsLaunWrf11jY6OVmJhorVq1qtVtMjMzrQULFlx2DmvWrLGGDh16xePOnz9v9e3b19q9e3fA+ry8PCs3N9f/56FDh1pr1qwJGDN27Fhr5cqVre43MzPT+s53vnPZ+XWmzMxMKyUlxfL5fP51ixcvtlJSUgLGtee6dOS8g7l+31RXV2dJsv70pz8F7GfSpEkB42666SZr8eLFAeu8Xq8VERFhbd++3b+uvr7e6tu3b8A8gr3O3zx+e65fW+fd3mN3pmAfE61p7bp8c7/BXufLjS0rK7MkWUePHm1zP8GOC+W1vty5dOW1tizL4j21Djp8+LAaGho0ZcqUNsd9/vnnampq0s033+xfFxERofT0dB0+fDjU0wzKn//8Z509e1b/9E//FLC+sbFR3/nOd65o32lpaVe0fUdMmDAh4O/Jy8jI0OrVq9Xc3KywsDBJXX9dPvvsM61YsUKlpaU6ceKE/06gsrJSo0eP9o/79re/HbDdoEGDdPz48YB1f/nLX9TU1KT09HT/uujoaN14440B4zpynTvr+oXyMRaMYB4TUvDXpbONHTtWU6ZM0ZgxY5Sdna1bb71Vd999t6699toOjevN15qodVCfPn26egqd5vTp05Kk//mf/9HgwYMDvvbNX0Zqt9tbvMzR1NTU5r779evXSbPsOh0578v5/ve/r6FDh+qFF15QYmKifD6fRo8ercbGxoBxf/8pM5vN1uKlsGAFe52/qbOuX0eO3RWCvS6dLSwsTEVFRdq9e7f+93//V7/4xS/0xBNPqLS0VMOHD2/3uN58rXlPrYNGjBihPn36qLi4uM1x119/vf916ouampq0d+9ejRo1KtTTDMqoUaMUGRmpyspK3XDDDQHLN98niI+PV01Njf/PXq9XR44c6Yopt6m0tDTgzx999JFGjBgR8H/k7bkunX3ef/vb31RRUaFly5ZpypQpSklJ0Zdfftnh/V133XWKiIgIeL/U4/Ho//7v/wLGBXudQ6Erjy0F95ho73VxOBxqbm7utDnabDbdfPPNevLJJ7V//345HA799re/7dC43nytuVProKioKC1evFiLFi2Sw+HQzTffrLq6Oh08eFB5eXn+cf369dNDDz2kxx57TLGxsUp
OTtaqVat09uzZgHFdqX///nr00Ue1cOFC+Xw+TZo0SR6PRx9++KGcTqdmzpwpSbrlllu0efNmff/731dMTIxWrFgR8KTQXVRWVio/P1//8R//ofLycv3iF7/Q6tWrA8a057p09nlfe+21iouL06ZNmzRo0CBVVlZqyZIlHd5f//79NXPmTP+5DBw4UCtXrpTdbg94yS3Y6xwKXXlsKbjHRHuvy7Bhw1RaWqqjR4/qmmuuUWxsrOz2jt0nlJaWqri4WLfeeqsGDhyo0tJS1dXVKSUlpUPjevO1JmpXYPny5QoPD9eKFStUXV2tQYMG6cEHH2wx7plnnpHP59O9996rU6dOady4cXrnnXdavA7elZ5++mnFx8eroKBAf/nLXxQTE6Pvfve7evzxx/1jli5dqiNHjuiOO+5QdHS0nn766U67U9u8ebNmzZrV6qe42mvGjBk6d+6c0tPTFRYWpgULFmj27NktxgV7XTr7vO12u15++WXNnz9fo0eP1o033qj169dr8uTJHd7nc889pwcffFB33HGHnE6nFi1apKqqqhY/NBzMdQ6V9h77aj8m2ntdHn30Uc2cOVOjRo3SuXPndOTIkUv+yMzlOJ1Ovf/++1q7dq28Xq+GDh2q1atX67bbbuvQOKlnXevOxN+nhm5h5cqV2rVrl957772unooRzpw5o8GDB2v16tXd5hWB9uqsx0R7fvMHej7u1NAtvP3229qwYUNXT6PH2r9/vz799FOlp6fL4/HoqaeekiT94Ac/6OKZdRyPCXQEUUO3sGfPnq6eQo/3n//5n6qoqJDD4VBaWpo++OCDHv17PnlMoCN4+REAYAw+0g8AMAZRAwAYg6gBAIxB1AAAxiBqAABjEDUAgDGIGgDAGEQNAGCM/wdDW5lHu7yVsAAAAABJRU5ErkJggg==", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One can also retrieve more precise information regarding the tokenizer.\n", + "This can be useful to know how text is parsed before being given to the\n", + "neural network:" + ], + "id": "b804391a-979a-4a74-a5f7-d8e27550e20e" + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "{0: '',\n", + " 8097: 'lorem ipsum dolor',\n", + " 8172: '',\n", + " 8949: '',\n", + " 15121: 'lorem>',\n", + " 17369: 'ame',\n", + " 18928: '',\n", + " 21355: '',\n", + " 33381: 'or>',\n", + " 35841: 'ipsum dolor',\n", + " 37380: '',\n", + " 44394: '',\n", + " 45738: 'sit>',\n", + " 45871: 'ipsu',\n", + " 48778: 'psu',\n", + " 48931: 'orem>',\n", + " 49786: '',\n", + " 57345: 'it>',\n", + " 57990: 'olor>',\n", + " 60515: 'lor>',\n", + " 60522: 'ore',\n", + " 62809: 'sum>',\n", + " 65472: 'met',\n", + " 65559: '',\n", + " 67778: 'olor',\n", + " 67985: 'orem',\n", + " 68529: 'psum',\n", + " 69320: '',\n", + " 72158: 'rem',\n", + " 73818: 'ipsum>',\n", + " 74637: 'dolor sit',\n", + " 76593: 'lor',\n", + " 77594: '',\n", + " 87627: '',\n", + " 92771: 'psum>',\n", + " 92809: '', '', 'H '],\n", + " ['', '', 'e '],\n", + " ['', '', 'l '],\n", + " ['', '', 'l '],\n", + " ['', '', 'o '],\n", + " [''],\n", + " ['', '', 'w '],\n", + " ['', '', 'o '],\n", + " ['', '', 'r '],\n", + " ['', '', 'l '],\n", + " ['', '', 'd ']],\n", + " [tensor([40876, 0, 51965]),\n", + " tensor([51907, 0, 77296]),\n", + " tensor([74312, 0, 26137]),\n", + " tensor([74312, 0, 26137]),\n", + " tensor([ 9853, 0, 53786]),\n", + " tensor([0]),\n", + " tensor([29925, 0, 74978]),\n", + " tensor([ 9853, 0, 53786]),\n", + " tensor([ 8646, 0, 13223]),\n", + " tensor([74312, 0, 26137]),\n", + " tensor([ 89472, 0, 104945])],\n", + " [{40876: '', 0: '', 51965: 'H '},\n", + " {51907: '', 0: '', 77296: 'e '},\n", + " 
{74312: '', 0: '', 26137: 'l '},\n", + " {74312: '', 0: '', 26137: 'l '},\n", + " {9853: '', 0: '', 53786: 'o '},\n", + " {0: ''},\n", + " {29925: '', 0: '', 74978: 'w '},\n", + " {9853: '', 0: '', 53786: 'o '},\n", + " {8646: '', 0: '', 13223: 'r '},\n", + " {74312: '', 0: '', 26137: 'l '},\n", + " {89472: '', 0: '', 104945: 'd '}],\n", + " [{'': 40876, '': 0, 'H ': 51965},\n", + " {'': 51907, '': 0, 'e ': 77296},\n", + " {'': 74312, '': 0, 'l ': 26137},\n", + " {'': 74312, '': 0, 'l ': 26137},\n", + " {'': 9853, '': 0, 'o ': 53786},\n", + " {'': 0},\n", + " {'': 29925, '': 0, 'w ': 74978},\n", + " {'': 9853, '': 0, 'o ': 53786},\n", + " {'': 8646, '': 0, 'r ': 13223},\n", + " {'': 74312, '': 0, 'l ': 26137},\n", + " {'': 89472, '': 0, 'd ': 104945}])" + ] + } + } + ], + "source": [ + "tokenizer.tokenize(\"Hello world\")" + ], + "id": "0b4964f3" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "However, there is a more straightforward way to do: creating directly\n", + "the `NGramTokenizer` instance:" + ], + "id": "fd5b6899-7831-40a6-9841-bbc1b0804956" + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = NGramTokenizer(\n", + " **parameters,\n", + " training_text=training_text\n", + " )" + ], + "id": "8a6ee96b" + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "output_type": "display_data", + "metadata": {}, + "data": { + "text/plain": [ + "([['', '', 'H '],\n", + " ['', '', 'e '],\n", + " ['', '', 'l '],\n", + " ['', '', 'l '],\n", + " ['', '', 'o '],\n", + " [''],\n", + " ['', '', 'w '],\n", + " ['', '', 'o '],\n", + " ['', '', 'r '],\n", + " ['', '', 'l '],\n", + " ['', '', 'd ']],\n", + " [tensor([40876, 0, 51965]),\n", + " tensor([51907, 0, 77296]),\n", + " tensor([74312, 0, 26137]),\n", + " tensor([74312, 0, 26137]),\n", + " tensor([ 9853, 0, 53786]),\n", + " tensor([0]),\n", + " tensor([29925, 0, 74978]),\n", + " tensor([ 
9853, 0, 53786]),\n", + " tensor([ 8646, 0, 13223]),\n", + " tensor([74312, 0, 26137]),\n", + " tensor([ 89472, 0, 104945])],\n", + " [{40876: '', 0: '', 51965: 'H '},\n", + " {51907: '', 0: '', 77296: 'e '},\n", + " {74312: '', 0: '', 26137: 'l '},\n", + " {74312: '', 0: '', 26137: 'l '},\n", + " {9853: '', 0: '', 53786: 'o '},\n", + " {0: ''},\n", + " {29925: '', 0: '', 74978: 'w '},\n", + " {9853: '', 0: '', 53786: 'o '},\n", + " {8646: '', 0: '', 13223: 'r '},\n", + " {74312: '', 0: '', 26137: 'l '},\n", + " {89472: '', 0: '', 104945: 'd '}],\n", + " [{'': 40876, '': 0, 'H ': 51965},\n", + " {'': 51907, '': 0, 'e ': 77296},\n", + " {'': 74312, '': 0, 'l ': 26137},\n", + " {'': 74312, '': 0, 'l ': 26137},\n", + " {'': 9853, '': 0, 'o ': 53786},\n", + " {'': 0},\n", + " {'': 29925, '': 0, 'w ': 74978},\n", + " {'': 9853, '': 0, 'o ': 53786},\n", + " {'': 8646, '': 0, 'r ': 13223},\n", + " {'': 74312, '': 0, 'l ': 26137},\n", + " {'': 89472, '': 0, 'd ': 104945}])" + ] + } + } + ], + "source": [ + "tokenizer.tokenize(\"Hello world\")" + ], + "id": "776636e6" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Why creating a `NGramTokenizer` separately ? Because model constructor\n", + "is now independent from training data:" + ], + "id": "6b0fd6c0-9740-4a32-9bb2-4a3cfe174ea8" + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "2025-03-05 16:27:41 - torchFastText.model.pytorch_model - num_rows is different from the number of tokens in the tokenizer. Using provided num_rows.\n", + "2025-03-05 16:27:42 - torchFastText.torchFastText - No scheduler parameters provided. Using default parameters (suited for ReduceLROnPlateau)." 
+ ] + } + ], + "source": [ + "model = torchFastText.build_from_tokenizer(\n", + " tokenizer, \n", + " embedding_dim=parameters[\"embedding_dim\"], \n", + " categorical_embedding_dims=parameters[\"categorical_embedding_dims\"], \n", + " sparse=parameters[\"sparse\"], \n", + " lr=parameters_train[\"lr\"], \n", + " num_classes=NUM_CLASSES, \n", + " num_categorical_features=NUM_CAT_VAR, \n", + " categorical_vocabulary_sizes=CAT_VOCAB_SIZE\n", + ")" + ], + "id": "ee5dbe0b" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Warning**:\n", + "\n", + "If the PyTorch model building did not use the training data, please keep\n", + "in mind that its architecture (that you customize here) should match the\n", + "vocabulary size of the categorical variables and the total number of\n", + "class, otherwise the model will raise an error during training.\n", + "\n", + "# Train a torchFastText model directly\n", + "\n", + "If no advanced customization or PyTorch tuning is necessary, there is a\n", + "direct way of training model." + ], + "id": "f53080e9-9d78-479f-a446-2feb4a92b1de" + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "model.train(\n", + " X_train,\n", + " y_train,\n", + " X_test,\n", + " y_test,\n", + " num_epochs=parameters_train['num_epochs'],\n", + " batch_size=parameters_train['batch_size'],\n", + " patience_scheduler=parameters_train['patience'],\n", + " patience_train=parameters_train['patience'],\n", + " lr=parameters_train['lr'],\n", + " verbose = True\n", + ")" + ], + "id": "ce5dc4a1" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Load a trained model from a Lightning checkpoint\n", + "\n", + "/! 
TOCOMPLETE" + ], + "id": "919b67ed-4a65-4c26-92a9-771a4be3cd15" + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "model.load_from_checkpoint(model.best_model_path) # or any other checkpoint path (string)" + ], + "id": "f560047b" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Predicting from new labels" + ], + "id": "e521a23b-77c4-4b0c-9940-a17b19b8111d" + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "text = [\"coiffeur, boulangerie, pâtisserie\"] # one text description\n", + "X= np.array([[text[0], 0, 0, 0, 0, 0, 0]]) # our new entry\n", + "TOP_K = 5\n", + "\n", + "pred, conf = model.predict(X, top_k=TOP_K)\n", + "pred_naf = encoder.inverse_transform(pred.reshape(-1))\n", + "subset = naf2008.set_index(\"code\").loc[np.flip(pred_naf)]\n", + "\n", + "for i in range(TOP_K-1, -1, -1):\n", + " print(f\"Prediction: {pred_naf[i]}, confidence: {conf[0, i]}, description: {subset['libelle'][pred_naf[i]]}\")" + ], + "id": "dbbad77d" + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Explainability" + ], + "id": "f84e6bff-8fa7-4896-b60a-005ae5f1d3eb" + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "from torchFastText.explainability.visualisation import (\n", + " visualize_letter_scores,\n", + " visualize_word_scores,\n", + ")\n", + "\n", + "pred, conf, all_scores, all_scores_letters = model.predict_and_explain(X)\n", + "visualize_word_scores(all_scores, text, pred_naf.reshape(1, -1))\n", + "visualize_letter_scores(all_scores_letters, text, pred_naf.reshape(1, -1))" + ], + "id": "58c46021" + } + ], + "nbformat": 4, + "nbformat_minor": 5, + "metadata": { + "kernelspec": { + "name": "python3", + "display_name": "Python 3 (ipykernel)", + "language": "python", + "path": "/opt/conda/share/jupyter/kernels/python3" + }, + "language_info": { + "name": "python", 
+ "codemirror_mode": { + "name": "ipython", + "version": "3" + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.7" } - ], - "source": [ - "from torchFastText.explainability.visualisation import (\n", - " visualize_letter_scores,\n", - " visualize_word_scores,\n", - ")\n", - "\n", - "pred, conf, all_scores, all_scores_letters = model.predict_and_explain(X)\n", - "visualize_word_scores(all_scores, text, pred_naf.reshape(1, -1))\n", - "visualize_letter_scores(all_scores_letters, text, pred_naf.reshape(1, -1))" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.7" } - }, - "nbformat": 4, - "nbformat_minor": 2 -} +} \ No newline at end of file diff --git a/notebooks/example.qmd b/notebooks/example.qmd new file mode 100644 index 0000000..666c33b --- /dev/null +++ b/notebooks/example.qmd @@ -0,0 +1,499 @@ +--- +title: "Exemple d'utilisation de la librairie `TorchFastText`" +--- + + + +_Warning_ + +_`TorchFastText` library is still under active development. 
Have a regular look at [https://github.com/inseefrlab/torch-fastText](https://github.com/inseefrlab/torch-fastText) for the latest information._ + +To install the package, you can run the following snippet: + +```{python} +#| output: false +#| eval: false + +# Stable version +pip install torchFastText +# Development version +# pip install !https://github.com/InseeFrLab/torch-fastText.git +``` + +# Load and preprocess data + +In this guide, we illustrate the main package functionalities using the following `DataFrame`: + +```{python} +import pandas as pd +df = pd.read_parquet("https://minio.lab.sspcloud.fr/projet-ape/extractions/20241027_sirene4.parquet") +df = df.sample(10000) +``` + +Our goal will be to build a multilabel classifier for the `code` variable using `libelle` as the feature. + +## Enriching our test dataset + +Unlike `Fasttext`, this package offers the possibility of having several feature columns of different types (string for the text column and additional variables in numeric form, for example). To illustrate that, we propose the following enrichment of the example dataset: + + +```{python} +import pandas as pd +import numpy as np +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import LabelEncoder + +def categorize_surface( + df: pd.DataFrame, surface_feature_name: int, like_sirene_3: bool = True +) -> pd.DataFrame: + """ + Categorize the surface of the activity. + + Args: + df (pd.DataFrame): DataFrame to categorize. + surface_feature_name (str): Name of the surface feature. + like_sirene_3 (bool): If True, categorize like Sirene 3. + + Returns: + pd.DataFrame: DataFrame with a new column "surf_cat". 
+ """ + df_copy = df.copy() + df_copy[surface_feature_name] = df_copy[surface_feature_name].replace("nan", np.nan) + df_copy[surface_feature_name] = df_copy[surface_feature_name].astype(float) + # Check surface feature exists + if surface_feature_name not in df.columns: + raise ValueError(f"Surface feature {surface_feature_name} not found in DataFrame.") + # Check surface feature is a float variable + if not (pd.api.types.is_float_dtype(df_copy[surface_feature_name])): + raise ValueError(f"Surface feature {surface_feature_name} must be a float variable.") + + if like_sirene_3: + # Categorize the surface + df_copy["surf_cat"] = pd.cut( + df_copy[surface_feature_name], + bins=[0, 120, 400, 2500, np.inf], + labels=["1", "2", "3", "4"], + ).astype(str) + else: + # Log transform the surface + df_copy["surf_log"] = np.log(df[surface_feature_name]) + + # Categorize the surface + df_copy["surf_cat"] = pd.cut( + df_copy.surf_log, + bins=[0, 3, 4, 5, 12], + labels=["1", "2", "3", "4"], + ).astype(str) + + df_copy[surface_feature_name] = df_copy["surf_cat"].replace("nan", "0") + df_copy[surface_feature_name] = df_copy[surface_feature_name].astype(int) + df_copy = df_copy.drop(columns=["surf_log", "surf_cat"], errors="ignore") + return df_copy + + +def clean_and_tokenize_df( + df, + categorical_features=["EVT", "CJ", "NAT", "TYP", "CRT"], + text_feature="libelle_processed", + label_col="apet_finale", +): + df.fillna("nan", inplace=True) + + df = df.rename( + columns={ + "evenement_type": "EVT", + "cj": "CJ", + "activ_nat_et": "NAT", + "liasse_type": "TYP", + "activ_surf_et": "SRF", + "activ_perm_et": "CRT", + } + ) + + les = [] + for col in categorical_features: + le = LabelEncoder() + df[col] = le.fit_transform(df[col]) + les.append(le) + + df = categorize_surface(df, "SRF", like_sirene_3=True) + df = df[[text_feature, "EVT", "CJ", "NAT", "TYP", "SRF", "CRT", label_col]] + + return df, les + + +def stratified_split_rare_labels(X, y, test_size=0.2, min_train_samples=1): + # 
Get unique labels and their frequencies + unique_labels, label_counts = np.unique(y, return_counts=True) + + # Separate rare and common labels + rare_labels = unique_labels[label_counts == 1] + + # Create initial mask for rare labels to go into training set + rare_label_mask = np.isin(y, rare_labels) + + # Separate data into rare and common label datasets + X_rare = X[rare_label_mask] + y_rare = y[rare_label_mask] + X_common = X[~rare_label_mask] + y_common = y[~rare_label_mask] + + # Split common labels stratified + X_common_train, X_common_test, y_common_train, y_common_test = train_test_split( + X_common, y_common, test_size=test_size, stratify=y_common + ) + + # Combine rare labels with common labels split + X_train = np.concatenate([X_rare, X_common_train]) + y_train = np.concatenate([y_rare, y_common_train]) + X_test = X_common_test + y_test = y_common_test + + return X_train, X_test, y_train, y_test + +def add_libelles( + df: pd.DataFrame, + df_naf: pd.DataFrame, + y: str, + text_feature: str, + textual_features: list, + categorical_features: list, +): + missing_codes = set(df_naf["code"]) + fake_obs = df_naf[df_naf["code"].isin(missing_codes)] + fake_obs[y] = fake_obs["code"] + fake_obs[text_feature] = fake_obs[[text_feature]].apply( + lambda row: " ".join(f"[{col}] {val}" for col, val in row.items() if val != ""), axis=1 + ) + df = pd.concat([df, fake_obs[[col for col in fake_obs.columns if col in df.columns]]]) + + if textual_features is not None: + for feature in textual_features: + df[feature] = df[feature].fillna(value="") + if categorical_features is not None: + for feature in categorical_features: + df[feature] = df[feature].fillna(value="NaN") + + print(f"\t*** {len(missing_codes)} codes have been added in the database...\n") + return df +``` + +```{python} +categorical_features = ["evenement_type", "cj", "activ_nat_et", "liasse_type", "activ_surf_et", "activ_perm_et"] +text_feature = "libelle" +y = "apet_finale" +textual_features = None + +naf2008 
= pd.read_csv("https://minio.lab.sspcloud.fr/projet-ape/data/naf2008.csv", sep=";") +df = add_libelles(df, naf2008, y, text_feature, textual_features, categorical_features) +``` + + +## Preprocessing + +To reduce noise in text fields, we recommend pre-processing before training a model with our package. We assume this preprocessing is handled by the package user : this gives him the opportunity to control data cleansing. + +Here's an example of the type of preprocessing that can be carried out before moving on to the modeling phase + +```{python} +from torchFastText.preprocess import clean_text_feature +df["libelle_processed"] = clean_text_feature(df["libelle"]) +``` + +Right now, the model requires the label (variable y) to be a numerical variable. If the label variable is a text variable, we recommend using Scikit Learn's [LabelEncoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.LabelEncoder.html) to convert into a numeric variable. Using that function will give user the possibility to get back labels from the encoder after running predictions. + +```{python} +encoder = LabelEncoder() +df["apet_finale"] = encoder.fit_transform(df["apet_finale"]) +``` + +The function `clean_and_tokenize_df` requires special `DataFrame` formatting: + +- First column contains the processed text (str) +- Next ones contain the "encoded" categorical (discrete) variables in int format + + +```{python} +df, _ = clean_and_tokenize_df(df, text_feature="libelle_processed") +X = df[["libelle_processed", "EVT", "CJ", "NAT", "TYP", "CRT", "SRF"]].values +y = df["apet_finale"].values +``` + +## Splitting in train-test sets + +As usual in a learning approach, you need to break down your data into learning and test/validation samples to obtain robust performance statistics. + +This work is the responsibility of the package's users. 
Here's an example of how to do it, using the [`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html) function in `Scikit`. + +```{python} +from sklearn.model_selection import train_test_split +X_train, X_test, y_train, y_test = train_test_split(X, y) +``` + +# Build the torch-fastText model (without training it) + +There are several ways to define and train a `torchFastText` model in this package. + +We first show how to initialize the model and then build it afterwards. + +The `torchFastText` constructor accepts the following parameters: + +| Parameter | Meaning | Example Value | +|---------------------------------|---------------------------------------------------------------------|--------------| +| `num_tokens` | Number of rows in the embedding matrix (size of the vocabulary) | 100000 | +| `embedding_dim` | Dimension of the embedding (number of columns in the matrix) | 50 | +| `sparse` | Use sparse embedding for fast computation (PyTorch) | False | +| `categorical_embedding_dims` | Dimension of the embedding for categorical features | 10 | +| `min_count` | Minimum occurrences of a word in the corpus to be included | 1 | +| `min_n` | Minimum length of character n-grams | 3 | +| `max_n` | Maximum length of character n-grams | 6 | +| `len_word_ngrams` | Length of word n-grams | 3 | + + +```{python} +from torchFastText import torchFastText + +parameters = { + "num_tokens": 100000, + "embedding_dim": 50, + "sparse": False, + "categorical_embedding_dims": 10, + "min_count": 1, + "min_n": 3, + "max_n": 6, + "len_word_ngrams": 3, +} + +parameters_train = { + "lr": 0.004, + "num_epochs": 1, + "batch_size": 256, + "patience": 3 +} + +model = torchFastText(**parameters) +``` + +`model` is then a special `torchFastText` object: + +```{python} +type(model) +``` + +Like any `PyTorch` model, it can be saved as JSON for later use: + +```{python} +model.to_json('torchFastText_config.json') +# model = 
torchFastText.from_json('torchFastText_config.json') +``` + +We can apply `build` to finally train our model. These are the parameters accepted by the `build` method + +| Parameter | Meaning | Example Value | +|---------------------------------|---------------------------------------------------------------------|--------------| +| `lr` | Learning rate | 0.004 | +| `num_epochs` | Number of training epochs | 1 | +| `batch_size` | Batch size for training | 256 | +| `patience` | Early stopping patience (number of epochs without improvement) | 3 | + + +We build the model using the training data. +We have now access to the tokenizer, the PyTorch model as well as a PyTorch Lightning module ready to be trained. +Note that Lightning is high-level framework for PyTorch that simplifies the process of training, validating, and deploying machine learning models. + + +```{python} +model.build(X_train, y_train, lightning=True, lr=parameters_train.get("lr")) +``` + +One can retrieve different objects from `model` instance: + +* `model.pytorch_model` +* `model.tokenizer` +* `model.lightning_module` + + +```{python} +model.pytorch_model +``` + +```{python} +model.tokenizer +``` + +```{python} +model.lightning_module +``` + +One can also retrieve more precise information regarding the tokenizer. This can be useful to know how text is parsed before being given to the neural network: + + +```{python} +from pprint import pprint +sentence = ["lorem ipsum dolor sit amet"] +pprint(model.tokenizer.tokenize(sentence)[2][0]) +``` + + +Saving parameters to JSON can also be done after building, but the model needs to be rebuilt after loading. + +```{python} +model.to_json('torchFastText_config.json') +``` + + +## Alternative way to build torchFastText + +The training data is only useful to initialize the tokenizer, but X_train and y_train are not needed to initialize the PyTorch model, provided we give the right parameters to construct layer. 
 + +To highlight this, we provide a lower-level process to build the model where one can first build the tokenizer, and then build the model with custom architecture parameters. + +The tokenizer can be loaded **from the same JSON file** as the model parameters, or initialized using the right arguments. + + +```{python} +del model +``` + +Let's decompose our features into two groups: + +* We have our textual feature stored in the first column of the features matrix +* All other columns are categorical variables + +```{python} +training_text = X_train[:, 0].tolist() +categorical_variables = X_train[:, 1:] +``` + +We need to create a few variables that will be useful afterwards. + +```{python} +CAT_VOCAB_SIZE = (np.max(categorical_variables, axis=0) + 1).astype(int).tolist() +NUM_CLASSES = len(np.unique(y_train)) +NUM_CAT_VAR = categorical_variables.shape[1] +``` + +Now let's come to the nitty gritty. There are several ways to create an instance of the tokenizer. + +First, we can create the tokenizer from: + +* the model definition in the JSON file created beforehand +* the textual data in the training dataset + +```{python} +from torchFastText.datasets import NGramTokenizer +tokenizer = NGramTokenizer.from_json('torchFastText_config.json', training_text) +``` + +```{python} +tokenizer.tokenize("Hello world") +``` + +However, there is a more straightforward way: directly creating the `NGramTokenizer` instance: + + +```{python} +tokenizer = NGramTokenizer( + **parameters, + training_text=training_text + ) +``` + +```{python} +tokenizer.tokenize("Hello world") +``` + +Why create an `NGramTokenizer` separately? 
Because the model constructor is now independent of the training data: + +```{python} +#| echo: false +#| eval: false +# TODO : allow to do that +#torchFastText.build_from_tokenizer( + #tokenizer, + #**parameters, + #**parameters_build +# ) +``` + +```{python} +model = torchFastText.build_from_tokenizer( + tokenizer, + embedding_dim=parameters["embedding_dim"], + categorical_embedding_dims=parameters["categorical_embedding_dims"], + sparse=parameters["sparse"], + lr=parameters_train["lr"], + num_classes=NUM_CLASSES, + num_categorical_features=NUM_CAT_VAR, + categorical_vocabulary_sizes=CAT_VOCAB_SIZE +) +``` + +__Warning__: + +If the PyTorch model was built without using the training data, please keep in mind that its architecture (which you customize here) should match the vocabulary sizes of the categorical variables and the total number of classes, otherwise the model will raise an error during training. + + +# Train a torchFastText model directly + +If no advanced customization or PyTorch tuning is necessary, there is a direct way of training the model. 
+ + +```{python} +#| eval: false +model.train( + X_train, + y_train, + X_test, + y_test, + num_epochs=parameters_train['num_epochs'], + batch_size=parameters_train['batch_size'], + patience_scheduler=parameters_train['patience'], + patience_train=parameters_train['patience'], + lr=parameters_train['lr'], + verbose = True +) +``` + +# Load a trained model from a Lightning checkpoint + +/!\ TOCOMPLETE + + +```{python} +#| eval: false +model.load_from_checkpoint(model.best_model_path) # or any other checkpoint path (string) +``` + +# Predicting from new labels + + +```{python} +#| eval: false +text = ["coiffeur, boulangerie, pâtisserie"] # one text description +X= np.array([[text[0], 0, 0, 0, 0, 0, 0]]) # our new entry +TOP_K = 5 + +pred, conf = model.predict(X, top_k=TOP_K) +pred_naf = encoder.inverse_transform(pred.reshape(-1)) +subset = naf2008.set_index("code").loc[np.flip(pred_naf)] + +for i in range(TOP_K-1, -1, -1): + print(f"Prediction: {pred_naf[i]}, confidence: {conf[0, i]}, description: {subset['libelle'][pred_naf[i]]}") + +``` + +# Explainability + + +```{python} +#| eval: false +from torchFastText.explainability.visualisation import ( + visualize_letter_scores, + visualize_word_scores, +) + +pred, conf, all_scores, all_scores_letters = model.predict_and_explain(X) +visualize_word_scores(all_scores, text, pred_naf.reshape(1, -1)) +visualize_letter_scores(all_scores_letters, text, pred_naf.reshape(1, -1)) +``` \ No newline at end of file diff --git a/notebooks/torchFastText_config.json b/notebooks/torchFastText_config.json index cd9edef..ff40f49 100644 --- a/notebooks/torchFastText_config.json +++ b/notebooks/torchFastText_config.json @@ -1,27 +1,22 @@ { - "num_buckets": 100000, "embedding_dim": 50, "sparse": false, + "num_tokens": 100000, "min_count": 1, "min_n": 3, "max_n": 6, "len_word_ngrams": 3, - "num_classes": 732, + "num_classes": 646, + "num_rows": 107992, "categorical_vocabulary_sizes": [ - 21, - 26, + 24, + 40, 8, - 12, + 13, 3, 4 ], - 
"categorical_embedding_dims": [ - 10, - 10, - 10, - 10, - 10, - 10 - ], - "num_categorical_features": 6 + "categorical_embedding_dims": 10, + "num_categorical_features": 6, + "direct_bagging": true } \ No newline at end of file