Tpot NN Colab test

This commit is contained in:
Marius Ciepluch 2024-01-16 13:51:20 +01:00 committed by GitHub
commit 66e9e67473
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

633
TPOT_NN.ipynb Normal file
View File

@ -0,0 +1,633 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "A100"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"301a7a8edb37404da369a8607c8e7f39": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_550f9e30494a4cedb8265819b10eda9f",
"IPY_MODEL_305334ab788d4c939ef090481cd37014",
"IPY_MODEL_5e6d831eb4da408db14779d0bfa1c3ce"
],
"layout": "IPY_MODEL_59def110c5334bc7865801cc2c2b1201"
}
},
"550f9e30494a4cedb8265819b10eda9f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_30407bbfe0534d839965fda6ff2ad7a6",
"placeholder": "",
"style": "IPY_MODEL_7a64c0004a5b4f63a5e3e123a1bedb42",
"value": "Optimization Progress: 100%"
}
},
"305334ab788d4c939ef090481cd37014": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_09b1baa49b8e4261aa12044843037665",
"max": 110,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b895890e16f04226ae3379d0a2b061b2",
"value": 110
}
},
"5e6d831eb4da408db14779d0bfa1c3ce": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_72f3f1a3683e4b7d93549d9bf998f1c7",
"placeholder": "",
"style": "IPY_MODEL_b6633832327f415c85a7d0fc13a1aca3",
"value": " 110/110 [11:46<00:00, 5.36s/pipeline]"
}
},
"59def110c5334bc7865801cc2c2b1201": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": "hidden",
"width": null
}
},
"30407bbfe0534d839965fda6ff2ad7a6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7a64c0004a5b4f63a5e3e123a1bedb42": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"09b1baa49b8e4261aa12044843037665": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b895890e16f04226ae3379d0a2b061b2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"72f3f1a3683e4b7d93549d9bf998f1c7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b6633832327f415c85a7d0fc13a1aca3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "code",
"source": [
"!pip install deap update_checker tqdm stopit xgboost\n",
"!pip install dask[delayed] dask[dataframe] dask-ml fsspec>=0.3.3 distributed>=2.10.0\n",
"!pip install scikit-mdr skrebate\n",
"!pip install tpot"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4deY2TpmK4OV",
"outputId": "13db84e2-85a0-4280-e984-a3969cf32afe"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting deap\n",
" Downloading deap-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (135 kB)\n",
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/135.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.7/135.4 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.4/135.4 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting update_checker\n",
" Downloading update_checker-0.18.0-py3-none-any.whl (7.0 kB)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n",
"Collecting stopit\n",
" Downloading stopit-1.1.2.tar.gz (18 kB)\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: xgboost in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from deap) (1.23.5)\n",
"Requirement already satisfied: requests>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from update_checker) (2.31.0)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from xgboost) (1.11.4)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (2023.11.17)\n",
"Building wheels for collected packages: stopit\n",
" Building wheel for stopit (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for stopit: filename=stopit-1.1.2-py3-none-any.whl size=11937 sha256=4915a58d356a870dbd83a619555155796895ae1868a1f566245606afe7205dcb\n",
" Stored in directory: /root/.cache/pip/wheels/af/f9/87/bf5b3d565c2a007b4dae9d8142dccc85a9f164e517062dd519\n",
"Successfully built stopit\n",
"Installing collected packages: stopit, deap, update_checker\n",
"Successfully installed deap-1.4.1 stopit-1.1.2 update_checker-0.18.0\n",
"Collecting scikit-mdr\n",
" Downloading scikit_MDR-0.4.5-py3-none-any.whl (15 kB)\n",
"Collecting skrebate\n",
" Downloading skrebate-0.62.tar.gz (19 kB)\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (1.23.5)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (1.11.4)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (1.2.2)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (3.7.1)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (1.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (4.47.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (23.2)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (9.4.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (3.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (2.8.2)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->scikit-mdr) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->scikit-mdr) (3.2.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->scikit-mdr) (1.16.0)\n",
"Building wheels for collected packages: skrebate\n",
" Building wheel for skrebate (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for skrebate: filename=skrebate-0.62-py3-none-any.whl size=29253 sha256=fabaec25d816d5c70e391d46c1a80a11dd7324cce03c4f89989c8007ebfe3fe3\n",
" Stored in directory: /root/.cache/pip/wheels/dd/67/40/683074a684607162bd0e34dcf7ccdfcab5861c3b2a83286f3a\n",
"Successfully built skrebate\n",
"Installing collected packages: skrebate, scikit-mdr\n",
"Successfully installed scikit-mdr-0.4.5 skrebate-0.62\n",
"Collecting tpot\n",
" Downloading TPOT-0.12.1-py3-none-any.whl (87 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.4/87.4 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.16.3 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.23.5)\n",
"Requirement already satisfied: scipy>=1.3.1 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.11.4)\n",
"Requirement already satisfied: scikit-learn>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.2.2)\n",
"Requirement already satisfied: deap>=1.2 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.4.1)\n",
"Requirement already satisfied: update-checker>=0.16 in /usr/local/lib/python3.10/dist-packages (from tpot) (0.18.0)\n",
"Requirement already satisfied: tqdm>=4.36.1 in /usr/local/lib/python3.10/dist-packages (from tpot) (4.66.1)\n",
"Requirement already satisfied: stopit>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.1.2)\n",
"Requirement already satisfied: pandas>=0.24.2 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.5.3)\n",
"Requirement already satisfied: joblib>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.3.2)\n",
"Requirement already satisfied: xgboost>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tpot) (2.0.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.2->tpot) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.2->tpot) (2023.3.post1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.22.0->tpot) (3.2.0)\n",
"Requirement already satisfied: requests>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from update-checker>=0.16->tpot) (2.31.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas>=0.24.2->tpot) (1.16.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (2023.11.17)\n",
"Installing collected packages: tpot\n",
"Successfully installed tpot-0.12.1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import torch\n",
"print(torch.__version__)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5UH_q7gJL3YU",
"outputId": "13ebf85c-bc0a-4c4c-9466-e7194d1d6bc8"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.1.0+cu121\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 448,
"referenced_widgets": [
"301a7a8edb37404da369a8607c8e7f39",
"550f9e30494a4cedb8265819b10eda9f",
"305334ab788d4c939ef090481cd37014",
"5e6d831eb4da408db14779d0bfa1c3ce",
"59def110c5334bc7865801cc2c2b1201",
"30407bbfe0534d839965fda6ff2ad7a6",
"7a64c0004a5b4f63a5e3e123a1bedb42",
"09b1baa49b8e4261aa12044843037665",
"b895890e16f04226ae3379d0a2b061b2",
"72f3f1a3683e4b7d93549d9bf998f1c7",
"b6633832327f415c85a7d0fc13a1aca3"
]
},
"id": "luBYuKTlKPlX",
"outputId": "b13f1749-e3b4-4152-d48f-cb857d408727"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Optimization Progress: 0%| | 0/110 [00:00<?, ?pipeline/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "301a7a8edb37404da369a8607c8e7f39"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Generation 1 - Current best internal CV score: 1.0\n",
"\n",
"Generation 2 - Current best internal CV score: 1.0\n",
"\n",
"Generation 3 - Current best internal CV score: 1.0\n",
"\n",
"Generation 4 - Current best internal CV score: 1.0\n",
"\n",
"Generation 5 - Current best internal CV score: 1.0\n",
"\n",
"Generation 6 - Current best internal CV score: 1.0\n",
"\n",
"Generation 7 - Current best internal CV score: 1.0\n",
"\n",
"Generation 8 - Current best internal CV score: 1.0\n",
"\n",
"Generation 9 - Current best internal CV score: 1.0\n",
"\n",
"Generation 10 - Current best internal CV score: 1.0\n",
"\n",
"Best pipeline: PytorchLRClassifier(MaxAbsScaler(SelectPercentile(input_matrix, percentile=57)), batch_size=4, learning_rate=1.0, num_epochs=15, weight_decay=0)\n",
"1.0\n"
]
}
],
"source": [
"from tpot import TPOTClassifier\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X, y = make_blobs(n_samples=100, centers=2, n_features=3, random_state=42)\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.75, test_size=0.25)\n",
"\n",
"clf = TPOTClassifier(config_dict='TPOT NN', template='Selector-Transformer-PytorchLRClassifier',\n",
" verbosity=2, population_size=10, generations=10)\n",
"clf.fit(X_train, y_train)\n",
"print(clf.score(X_test, y_test))\n",
"clf.export('tpot_nn_demo_pipeline.py')"
]
},
{
"cell_type": "code",
"source": [
"from deap import creator\n",
"from sklearn.model_selection import cross_val_score\n",
"import numpy as np\n",
"\n",
"print(clf.score(X_test, y_test))\n",
"# print part of the pipeline dictionary\n",
"print(dict(list(clf.evaluated_individuals_.items())[0:2]))\n",
"# print a pipeline and its values\n",
"pipeline_str = list(clf.evaluated_individuals_.keys())[0]\n",
"print(pipeline_str)\n",
"print(clf.evaluated_individuals_[pipeline_str])\n",
"# convert pipeline string to scikit-learn pipeline object\n",
"optimized_pipeline = creator.Individual.from_string(pipeline_str, clf._pset) # deap object\n",
"fitted_pipeline = clf._toolbox.compile(expr=optimized_pipeline) # scikit-learn pipeline object\n",
"# print scikit-learn pipeline object\n",
"print(fitted_pipeline)\n",
"\n",
"# Fix random state when the operator allows (optional) just for getting consistent CV score\n",
"for step in fitted_pipeline.steps:\n",
" if 'random_state' in step[1].get_params():\n",
" step[1].set_params(random_state=42)\n",
"\n",
"# CV scores from scikit-learn\n",
"scores = cross_val_score(fitted_pipeline, X_train, y_train, cv=5, scoring='accuracy', verbose=0)\n",
"print(np.mean(scores))\n",
"print(clf.evaluated_individuals_[pipeline_str])\n",
""
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mq0lkWNuO1WU",
"outputId": "f2615618-f61f-4f75-91da-c6fd50745777"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1.0\n",
"{'PytorchLRClassifier(MaxAbsScaler(SelectPercentile(input_matrix, SelectPercentile__percentile=57)), PytorchLRClassifier__batch_size=4, PytorchLRClassifier__learning_rate=1.0, PytorchLRClassifier__num_epochs=15, PytorchLRClassifier__weight_decay=0)': {'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}, 'PytorchLRClassifier(OneHotEncoder(SelectFwe(input_matrix, SelectFwe__alpha=0.039), OneHotEncoder__minimum_fraction=0.2, OneHotEncoder__sparse=False, OneHotEncoder__threshold=10), PytorchLRClassifier__batch_size=8, PytorchLRClassifier__learning_rate=0.01, PytorchLRClassifier__num_epochs=15, PytorchLRClassifier__weight_decay=0.001)': {'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}}\n",
"PytorchLRClassifier(MaxAbsScaler(SelectPercentile(input_matrix, SelectPercentile__percentile=57)), PytorchLRClassifier__batch_size=4, PytorchLRClassifier__learning_rate=1.0, PytorchLRClassifier__num_epochs=15, PytorchLRClassifier__weight_decay=0)\n",
"{'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}\n",
"Pipeline(steps=[('selectpercentile', SelectPercentile(percentile=57)),\n",
" ('maxabsscaler', MaxAbsScaler()),\n",
" ('pytorchlrclassifier',\n",
" PytorchLRClassifier(batch_size=4, learning_rate=1.0,\n",
" num_epochs=15, weight_decay=0))])\n",
"1.0\n",
"{'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}\n"
]
}
]
}
]
}