diff --git a/notebooks/example.ipynb b/notebooks/example.ipynb index 9e58733..c173d83 100644 --- a/notebooks/example.ipynb +++ b/notebooks/example.ipynb @@ -2,6 +2,7 @@ "cells": [ { "cell_type": "markdown", + "id": "a01b1526-51df-4bf9-9fd4-11ef22ffcc79", "metadata": {}, "source": [ "# Exemple d’utilisation de la librairie `TorchFastText`\n", @@ -13,47 +14,56 @@ "latest information.*\n", "\n", "To install package, you can run the following snippet" - ], - "id": "a01b1526-51df-4bf9-9fd4-11ef22ffcc79" + ] }, { "cell_type": "code", "execution_count": 1, + "id": "a00a2856", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "SyntaxError", + "evalue": "invalid syntax (2499593094.py, line 2)", + "output_type": "error", + "traceback": [ + " \u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[1]\u001b[39m\u001b[32m, line 2\u001b[39m\n\u001b[31m \u001b[39m\u001b[31mpip install torchFastText\u001b[39m\n ^\n\u001b[31mSyntaxError\u001b[39m\u001b[31m:\u001b[39m invalid syntax\n" + ] + } + ], "source": [ "# Stable version\n", "pip install torchFastText \n", "# Development version\n", "# pip install !https://github.com/InseeFrLab/torch-fastText.git" - ], - "id": "a00a2856" + ] }, { "cell_type": "markdown", + "id": "b292ea76-57a1-4d4e-9bde-dcc9656dc447", "metadata": {}, "source": [ "# Load and preprocess data\n", "\n", "In that guide, we propose to illustrate main package functionalities\n", "using that `DataFrame`:" - ], - "id": "b292ea76-57a1-4d4e-9bde-dcc9656dc447" + ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, + "id": "37c042fe", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "df = pd.read_parquet(\"https://minio.lab.sspcloud.fr/projet-ape/extractions/20241027_sirene4.parquet\")\n", "df = df.sample(10000)" - ], - "id": "37c042fe" + ] }, { "cell_type": "markdown", + "id": "c399b4b0-a9cb-450e-9a5e-480e0e657b8e", "metadata": {}, "source": [ "Our goal will be to build multilabel classification for the `code`\n", 
@@ -65,12 +75,12 @@ "feature columns of different types (string for the text column and\n", "additional variables in numeric form, for example). To illustrate that,\n", "we propose the following enrichment of the example dataset:" - ], - "id": "c399b4b0-a9cb-450e-9a5e-480e0e657b8e" + ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, + "id": "92402df7", "metadata": {}, "outputs": [], "source": [ @@ -212,19 +222,20 @@ "\n", " print(f\"\\t*** {len(missing_codes)} codes have been added in the database...\\n\")\n", " return df" - ], - "id": "92402df7" + ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, + "id": "1fd02895", "metadata": {}, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ - " *** 732 codes have been added in the database...\n" + "\t*** 732 codes have been added in the database...\n", + "\n" ] } ], @@ -236,11 +247,11 @@ "\n", "naf2008 = pd.read_csv(\"https://minio.lab.sspcloud.fr/projet-ape/data/naf2008.csv\", sep=\";\")\n", "df = add_libelles(df, naf2008, y, text_feature, textual_features, categorical_features)" - ], - "id": "1fd02895" + ] }, { "cell_type": "markdown", + "id": "67f4160d-0c98-4700-80f4-1ba454e6a2df", "metadata": {}, "source": [ "## Preprocessing\n", @@ -252,22 +263,22 @@ "\n", "Here’s an example of the type of preprocessing that can be carried out\n", "before moving on to the modeling phase" - ], - "id": "67f4160d-0c98-4700-80f4-1ba454e6a2df" + ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 9, + "id": "61b0252e", "metadata": {}, "outputs": [], "source": [ "from torchFastText.preprocess import clean_text_feature\n", "df[\"libelle_processed\"] = clean_text_feature(df[\"libelle\"])" - ], - "id": "61b0252e" + ] }, { "cell_type": "markdown", + "id": "acde2929-fe92-4107-8066-a5c8ac5d6428", "metadata": {}, "source": [ "Right now, the model requires the label (variable y) to be a numerical\n", @@ -277,22 +288,22 @@ "to 
convert into a numeric variable. Using that function will give user\n", "the possibility to get back labels from the encoder after running\n", "predictions." - ], - "id": "acde2929-fe92-4107-8066-a5c8ac5d6428" + ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, + "id": "8c02a833", "metadata": {}, "outputs": [], "source": [ "encoder = LabelEncoder()\n", "df[\"apet_finale\"] = encoder.fit_transform(df[\"apet_finale\"])" - ], - "id": "8c02a833" + ] }, { "cell_type": "markdown", + "id": "25593e1a-1661-49e3-9734-272ec4745de1", "metadata": {}, "source": [ "The function `clean_and_tokenize_df` requires special `DataFrame`\n", @@ -301,20 +312,20 @@ "- First column contains the processed text (str)\n", "- Next ones contain the “encoded” categorical (discrete) variables in\n", " int format" - ], - "id": "25593e1a-1661-49e3-9734-272ec4745de1" + ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, + "id": "5fb5b0c7", "metadata": {}, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ - "/tmp/ipykernel_90631/2075507147.py:60: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. Value 'nan' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n", - " df.fillna(\"nan\", inplace=True)" + "/tmp/ipykernel_26526/2075507147.py:60: FutureWarning: Setting an item of incompatible dtype is deprecated and will raise an error in a future version of pandas. 
Value 'nan' has dtype incompatible with float64, please explicitly cast to a compatible dtype first.\n", + " df.fillna(\"nan\", inplace=True)\n" ] } ], @@ -322,11 +333,11 @@ "df, _ = clean_and_tokenize_df(df, text_feature=\"libelle_processed\")\n", "X = df[[\"libelle_processed\", \"EVT\", \"CJ\", \"NAT\", \"TYP\", \"CRT\", \"SRF\"]].values\n", "y = df[\"apet_finale\"].values" - ], - "id": "5fb5b0c7" + ] }, { "cell_type": "markdown", + "id": "e70de831-dbc9-49be-b0c4-d70dd6479d03", "metadata": {}, "source": [ "## Splitting in train-test sets\n", @@ -335,26 +346,35 @@ "learning and test/validation samples to obtain robust performance\n", "statistics.\n", "\n", - "This work is the responsibility of the package’s users. Here’s an\n", - "example of how to do it, using the\n", - "[`train_test_split`](https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html)\n", - "function in `Scikit`." - ], - "id": "e70de831-dbc9-49be-b0c4-d70dd6479d03" + "This work is the responsibility of the package’s users. Please make sure that `np.max(y_train) == len(np.unique(y_train))-1` (i.e. your labels are well encoded, in a consecutive manner, starting from 0), and that all the possible labels appear at least once in the training set.\n", + "\n", + "We provide the function `stratified_split_rare_labels` to match these requirements here." 
+ ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 19, + "id": "b593fd75", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n" + ] + } + ], "source": [ "from sklearn.model_selection import train_test_split \n", - "X_train, X_test, y_train, y_test = train_test_split(X, y)" - ], - "id": "b593fd75" + "X_train, X_test, y_train, y_test = stratified_split_rare_labels(X, y)\n", + "\n", + "print(np.max(y_train) == len(np.unique(y_train))-1)" + ] }, { "cell_type": "markdown", + "id": "8729c5f4-9038-4437-929b-fc500dc0db7a", "metadata": {}, "source": [ "# Build the torch-fastText model (without training it)\n", @@ -376,12 +396,12 @@ "| `min_n` | Minimum length of character n-grams | 3 |\n", "| `max_n` | Maximum length of character n-grams | 6 |\n", "| `len_word_ngrams` | Length of word n-grams | 3 |" - ], - "id": "8729c5f4-9038-4437-929b-fc500dc0db7a" + ] }, { "cell_type": "code", "execution_count": 9, + "id": "5879ca88", "metadata": {}, "outputs": [], "source": [ @@ -406,59 +426,59 @@ "}\n", "\n", "model = torchFastText(**parameters)" - ], - "id": "5879ca88" + ] }, { "cell_type": "markdown", + "id": "05f9d26b-f08f-41be-93e4-b55a2c86690c", "metadata": {}, "source": [ "`model` is then a special `torchFastText` object:" - ], - "id": "05f9d26b-f08f-41be-93e4-b55a2c86690c" + ] }, { "cell_type": "code", "execution_count": 10, + "id": "ebf5608b", "metadata": {}, "outputs": [ { - "output_type": "display_data", - "metadata": {}, "data": { "text/plain": [ "torchFastText.torchFastText.torchFastText" ] - } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "type(model)" - ], - "id": "ebf5608b" + ] }, { "cell_type": "markdown", + "id": "dcbe8289-f506-48f9-b854-96f25974368f", "metadata": {}, "source": [ "As any `PyTorch` model, it accepts being save as a JSON for later on\n", "use:" - ], - "id": "dcbe8289-f506-48f9-b854-96f25974368f" + ] }, { "cell_type": "code", 
"execution_count": 11, + "id": "6c3b2b85", "metadata": {}, "outputs": [], "source": [ "model.to_json('torchFastText_config.json')\n", "# model = torchFastText.from_json('torchFastText_config.json')" - ], - "id": "6c3b2b85" + ] }, { "cell_type": "markdown", + "id": "5f8b017f-66a1-413d-85e8-1981adf64823", "metadata": {}, "source": [ "We can apply `build` to finally train our model. These are the\n", @@ -476,17 +496,17 @@ "to be trained. Note that Lightning is high-level framework for PyTorch\n", "that simplifies the process of training, validating, and deploying\n", "machine learning models." - ], - "id": "5f8b017f-66a1-413d-85e8-1981adf64823" + ] }, { "cell_type": "code", "execution_count": 12, + "id": "e2e43d0e", "metadata": {}, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "2025-03-05 16:27:41 - torchFastText.model.pytorch_model - num_rows is different from the number of tokens in the tokenizer. Using provided num_rows.\n", "2025-03-05 16:27:41 - torchFastText.torchFastText - No scheduler parameters provided. Using default parameters (suited for ReduceLROnPlateau)." 
@@ -495,11 +515,11 @@ ], "source": [ "model.build(X_train, y_train, lightning=True, lr=parameters_train.get(\"lr\"))" - ], - "id": "e2e43d0e" + ] }, { "cell_type": "markdown", + "id": "b5a7d5fa-596a-470b-892e-e8fafdb8221a", "metadata": {}, "source": [ "One can retrieve different objects from `model` instance:\n", @@ -507,17 +527,15 @@ "- `model.pytorch_model`\n", "- `model.tokenizer`\n", "- `model.lightning_module`" - ], - "id": "b5a7d5fa-596a-470b-892e-e8fafdb8221a" + ] }, { "cell_type": "code", "execution_count": 13, + "id": "091024e6", "metadata": {}, "outputs": [ { - "output_type": "display_data", - "metadata": {}, "data": { "text/plain": [ "FastTextModel(\n", @@ -531,42 +549,42 @@ " (fc): Linear(in_features=60, out_features=646, bias=True)\n", ")" ] - } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "model.pytorch_model" - ], - "id": "091024e6" + ] }, { "cell_type": "code", "execution_count": 14, + "id": "d983b113", "metadata": {}, "outputs": [ { - "output_type": "display_data", - "metadata": {}, "data": { "text/plain": [ "" ] - } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "model.tokenizer" - ], - "id": "d983b113" + ] }, { "cell_type": "code", "execution_count": 15, + "id": "9b23f1ba", "metadata": {}, "outputs": [ { - "output_type": "display_data", - "metadata": {}, "data": { "text/plain": [ "FastTextModule(\n", @@ -584,32 +602,34 @@ " (accuracy_fn): MulticlassAccuracy()\n", ")" ] - } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "model.lightning_module" - ], - "id": "9b23f1ba" + ] }, { "cell_type": "markdown", + "id": "b804391a-979a-4a74-a5f7-d8e27550e20e", "metadata": {}, "source": [ "One can also retrieve more precise information regarding the tokenizer.\n", "This can be useful to know how text is parsed before being given to the\n", "neural network:" - ], - "id": "b804391a-979a-4a74-a5f7-d8e27550e20e" + ] }, { "cell_type": "code", "execution_count": 16, + "id": "00c077b0", 
"metadata": {}, "outputs": [ { - "output_type": "stream", "name": "stdout", + "output_type": "stream", "text": [ "{0: '',\n", " 8097: 'lorem ipsum dolor',\n", @@ -685,30 +705,30 @@ "from pprint import pprint \n", "sentence = [\"lorem ipsum dolor sit amet\"]\n", "pprint(model.tokenizer.tokenize(sentence)[2][0])" - ], - "id": "00c077b0" + ] }, { "cell_type": "markdown", + "id": "c1ed6afa-6b7d-4d51-a4f0-3e8845a38704", "metadata": {}, "source": [ "Saving parameters to JSON can also be done after building, but the model\n", "needs to be rebuilt after loading." - ], - "id": "c1ed6afa-6b7d-4d51-a4f0-3e8845a38704" + ] }, { "cell_type": "code", "execution_count": 17, + "id": "bab40010", "metadata": {}, "outputs": [], "source": [ "model.to_json('torchFastText_config.json')" - ], - "id": "bab40010" + ] }, { "cell_type": "markdown", + "id": "017f8d12-0be8-45df-a0e4-80919c89db2d", "metadata": {}, "source": [ "## Alternative way to build torchFastText\n", @@ -723,21 +743,21 @@ "\n", "The tokenizer can be loaded **from the same JSON file** as the model\n", "parameters, or initialized using the right arguments." 
- ], - "id": "017f8d12-0be8-45df-a0e4-80919c89db2d" + ] }, { "cell_type": "code", "execution_count": 18, + "id": "3b13330d", "metadata": {}, "outputs": [], "source": [ "del model" - ], - "id": "3b13330d" + ] }, { "cell_type": "markdown", + "id": "c6d7a335-bd31-455c-9184-48d2f2d60fbd", "metadata": {}, "source": [ "Let’s decompose our features in two group:\n", @@ -745,42 +765,42 @@ "- We have our textual feature stored in the first column of the\n", " features matrix\n", "- All other columns are categorical variables" - ], - "id": "c6d7a335-bd31-455c-9184-48d2f2d60fbd" + ] }, { "cell_type": "code", "execution_count": 19, + "id": "5f75f055", "metadata": {}, "outputs": [], "source": [ "training_text = X_train[:, 0].tolist()\n", "categorical_variables = X_train[:, 1:]" - ], - "id": "5f75f055" + ] }, { "cell_type": "markdown", + "id": "adc8da37-5b6f-4c8a-8198-6c4080ffc7be", "metadata": {}, "source": [ "We need to create a few variables that will be useful afterwards" - ], - "id": "adc8da37-5b6f-4c8a-8198-6c4080ffc7be" + ] }, { "cell_type": "code", "execution_count": 20, + "id": "931103ec", "metadata": {}, "outputs": [], "source": [ "CAT_VOCAB_SIZE = (np.max(categorical_variables, axis=0) + 1).astype(int).tolist()\n", "NUM_CLASSES = len(np.unique(y_train))\n", "NUM_CAT_VAR = categorical_variables.shape[1]" - ], - "id": "931103ec" + ] }, { "cell_type": "markdown", + "id": "0d3ad544-d0f6-4d46-b322-af979c48bb43", "metadata": {}, "source": [ "Now let’s come to the nitty gritty. 
There are several ways to create an\n", @@ -790,28 +810,26 @@ "\n", "- model definition in the JSON file created beforehand\n", "- textual data in training dataset" - ], - "id": "0d3ad544-d0f6-4d46-b322-af979c48bb43" + ] }, { "cell_type": "code", "execution_count": 21, + "id": "0357c85f", "metadata": {}, "outputs": [], "source": [ "from torchFastText.datasets import NGramTokenizer\n", "tokenizer = NGramTokenizer.from_json('torchFastText_config.json', training_text)" - ], - "id": "0357c85f" + ] }, { "cell_type": "code", "execution_count": 22, + "id": "0b4964f3", "metadata": {}, "outputs": [ { - "output_type": "display_data", - "metadata": {}, "data": { "text/plain": [ "([['', '', 'H '],\n", @@ -859,26 +877,28 @@ " {'': 74312, '': 0, 'l ': 26137},\n", " {'': 89472, '': 0, 'd ': 104945}])" ] - } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "tokenizer.tokenize(\"Hello world\")" - ], - "id": "0b4964f3" + ] }, { "cell_type": "markdown", + "id": "fd5b6899-7831-40a6-9841-bbc1b0804956", "metadata": {}, "source": [ "However, there is a more straightforward way to do: creating directly\n", "the `NGramTokenizer` instance:" - ], - "id": "fd5b6899-7831-40a6-9841-bbc1b0804956" + ] }, { "cell_type": "code", "execution_count": 23, + "id": "8a6ee96b", "metadata": {}, "outputs": [], "source": [ @@ -886,17 +906,15 @@ " **parameters,\n", " training_text=training_text\n", " )" - ], - "id": "8a6ee96b" + ] }, { "cell_type": "code", "execution_count": 24, + "id": "776636e6", "metadata": {}, "outputs": [ { - "output_type": "display_data", - "metadata": {}, "data": { "text/plain": [ "([['', '', 'H '],\n", @@ -944,31 +962,33 @@ " {'': 74312, '': 0, 'l ': 26137},\n", " {'': 89472, '': 0, 'd ': 104945}])" ] - } + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "tokenizer.tokenize(\"Hello world\")" - ], - "id": "776636e6" + ] }, { "cell_type": "markdown", + "id": "6b0fd6c0-9740-4a32-9bb2-4a3cfe174ea8", "metadata": {}, "source": [ "Why creating a 
`NGramTokenizer` separately ? Because model constructor\n", "is now independent from training data:" - ], - "id": "6b0fd6c0-9740-4a32-9bb2-4a3cfe174ea8" + ] }, { "cell_type": "code", "execution_count": 26, + "id": "ee5dbe0b", "metadata": {}, "outputs": [ { - "output_type": "stream", "name": "stderr", + "output_type": "stream", "text": [ "2025-03-05 16:27:41 - torchFastText.model.pytorch_model - num_rows is different from the number of tokens in the tokenizer. Using provided num_rows.\n", "2025-03-05 16:27:42 - torchFastText.torchFastText - No scheduler parameters provided. Using default parameters (suited for ReduceLROnPlateau)." @@ -986,11 +1006,11 @@ " num_categorical_features=NUM_CAT_VAR, \n", " categorical_vocabulary_sizes=CAT_VOCAB_SIZE\n", ")" - ], - "id": "ee5dbe0b" + ] }, { "cell_type": "markdown", + "id": "f53080e9-9d78-479f-a446-2feb4a92b1de", "metadata": {}, "source": [ "**Warning**:\n", @@ -1004,12 +1024,12 @@ "\n", "If no advanced customization or PyTorch tuning is necessary, there is a\n", "direct way of training model." - ], - "id": "f53080e9-9d78-479f-a446-2feb4a92b1de" + ] }, { "cell_type": "code", "execution_count": 27, + "id": "ce5dc4a1", "metadata": {}, "outputs": [], "source": [ @@ -1025,40 +1045,40 @@ " lr=parameters_train['lr'],\n", " verbose = True\n", ")" - ], - "id": "ce5dc4a1" + ] }, { "cell_type": "markdown", + "id": "919b67ed-4a65-4c26-92a9-771a4be3cd15", "metadata": {}, "source": [ "# Load a trained model from a Lightning checkpoint\n", "\n", "/! 
TOCOMPLETE" - ], - "id": "919b67ed-4a65-4c26-92a9-771a4be3cd15" + ] }, { "cell_type": "code", "execution_count": 28, + "id": "f560047b", "metadata": {}, "outputs": [], "source": [ "model.load_from_checkpoint(model.best_model_path) # or any other checkpoint path (string)" - ], - "id": "f560047b" + ] }, { "cell_type": "markdown", + "id": "e521a23b-77c4-4b0c-9940-a17b19b8111d", "metadata": {}, "source": [ "# Predicting from new labels" - ], - "id": "e521a23b-77c4-4b0c-9940-a17b19b8111d" + ] }, { "cell_type": "code", "execution_count": 29, + "id": "dbbad77d", "metadata": {}, "outputs": [], "source": [ @@ -1072,20 +1092,20 @@ "\n", "for i in range(TOP_K-1, -1, -1):\n", " print(f\"Prediction: {pred_naf[i]}, confidence: {conf[0, i]}, description: {subset['libelle'][pred_naf[i]]}\")" - ], - "id": "dbbad77d" + ] }, { "cell_type": "markdown", + "id": "f84e6bff-8fa7-4896-b60a-005ae5f1d3eb", "metadata": {}, "source": [ "# Explainability" - ], - "id": "f84e6bff-8fa7-4896-b60a-005ae5f1d3eb" + ] }, { "cell_type": "code", "execution_count": 30, + "id": "58c46021", "metadata": {}, "outputs": [], "source": [ @@ -1097,30 +1117,28 @@ "pred, conf, all_scores, all_scores_letters = model.predict_and_explain(X)\n", "visualize_word_scores(all_scores, text, pred_naf.reshape(1, -1))\n", "visualize_letter_scores(all_scores_letters, text, pred_naf.reshape(1, -1))" - ], - "id": "58c46021" + ] } ], - "nbformat": 4, - "nbformat_minor": 5, "metadata": { "kernelspec": { - "name": "python3", - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", - "path": "/opt/conda/share/jupyter/kernels/python3" + "name": "python3" }, "language_info": { - "name": "python", "codemirror_mode": { "name": "ipython", - "version": "3" + "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", + "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.9" } - } -} \ No newline at end of file + }, + "nbformat": 
4, + "nbformat_minor": 5 +}