log2ml/0-seminar-wismar/TPOT_NN.ipynb
2024-05-16 18:23:13 +02:00

633 lines
30 KiB
Plaintext
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "A100"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
"301a7a8edb37404da369a8607c8e7f39": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
"IPY_MODEL_550f9e30494a4cedb8265819b10eda9f",
"IPY_MODEL_305334ab788d4c939ef090481cd37014",
"IPY_MODEL_5e6d831eb4da408db14779d0bfa1c3ce"
],
"layout": "IPY_MODEL_59def110c5334bc7865801cc2c2b1201"
}
},
"550f9e30494a4cedb8265819b10eda9f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_30407bbfe0534d839965fda6ff2ad7a6",
"placeholder": "",
"style": "IPY_MODEL_7a64c0004a5b4f63a5e3e123a1bedb42",
"value": "Optimization Progress: 100%"
}
},
"305334ab788d4c939ef090481cd37014": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_09b1baa49b8e4261aa12044843037665",
"max": 110,
"min": 0,
"orientation": "horizontal",
"style": "IPY_MODEL_b895890e16f04226ae3379d0a2b061b2",
"value": 110
}
},
"5e6d831eb4da408db14779d0bfa1c3ce": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
"layout": "IPY_MODEL_72f3f1a3683e4b7d93549d9bf998f1c7",
"placeholder": "",
"style": "IPY_MODEL_b6633832327f415c85a7d0fc13a1aca3",
"value": " 110/110 [11:46<00:00, 5.36s/pipeline]"
}
},
"59def110c5334bc7865801cc2c2b1201": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": "hidden",
"width": null
}
},
"30407bbfe0534d839965fda6ff2ad7a6": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"7a64c0004a5b4f63a5e3e123a1bedb42": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
"09b1baa49b8e4261aa12044843037665": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b895890e16f04226ae3379d0a2b061b2": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
"72f3f1a3683e4b7d93549d9bf998f1c7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
"b6633832327f415c85a7d0fc13a1aca3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "code",
"source": [
"# Quote every requirement containing '>=' or '[...]': unquoted, the shell reads\n",
"# 'fsspec>=0.3.3' as 'fsspec' plus a stdout redirection to a file named '=0.3.3',\n",
"# which silently drops the version constraint and hides pip's output.\n",
"# %pip (rather than !pip) installs into the environment of the running kernel.\n",
"%pip install deap update_checker tqdm stopit xgboost\n",
"%pip install 'dask[delayed]' 'dask[dataframe]' dask-ml 'fsspec>=0.3.3' 'distributed>=2.10.0'\n",
"%pip install scikit-mdr skrebate\n",
"# Pinned to the TPOT release this notebook was run with; the cells below use its\n",
"# config_dict/template options and private attributes (_pset, _toolbox).\n",
"%pip install 'tpot==0.12.1'"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4deY2TpmK4OV",
"outputId": "13db84e2-85a0-4280-e984-a3969cf32afe"
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Collecting deap\n",
" Downloading deap-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (135 kB)\n",
"\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/135.4 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m71.7/135.4 kB\u001b[0m \u001b[31m2.0 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m135.4/135.4 kB\u001b[0m \u001b[31m2.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hCollecting update_checker\n",
" Downloading update_checker-0.18.0-py3-none-any.whl (7.0 kB)\n",
"Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (4.66.1)\n",
"Collecting stopit\n",
" Downloading stopit-1.1.2.tar.gz (18 kB)\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: xgboost in /usr/local/lib/python3.10/dist-packages (2.0.3)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from deap) (1.23.5)\n",
"Requirement already satisfied: requests>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from update_checker) (2.31.0)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from xgboost) (1.11.4)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update_checker) (2023.11.17)\n",
"Building wheels for collected packages: stopit\n",
" Building wheel for stopit (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for stopit: filename=stopit-1.1.2-py3-none-any.whl size=11937 sha256=4915a58d356a870dbd83a619555155796895ae1868a1f566245606afe7205dcb\n",
" Stored in directory: /root/.cache/pip/wheels/af/f9/87/bf5b3d565c2a007b4dae9d8142dccc85a9f164e517062dd519\n",
"Successfully built stopit\n",
"Installing collected packages: stopit, deap, update_checker\n",
"Successfully installed deap-1.4.1 stopit-1.1.2 update_checker-0.18.0\n",
"Collecting scikit-mdr\n",
" Downloading scikit_MDR-0.4.5-py3-none-any.whl (15 kB)\n",
"Collecting skrebate\n",
" Downloading skrebate-0.62.tar.gz (19 kB)\n",
" Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (1.23.5)\n",
"Requirement already satisfied: scipy in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (1.11.4)\n",
"Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (1.2.2)\n",
"Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from scikit-mdr) (3.7.1)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (1.2.0)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (4.47.0)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (1.4.5)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (23.2)\n",
"Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (9.4.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (3.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /usr/local/lib/python3.10/dist-packages (from matplotlib->scikit-mdr) (2.8.2)\n",
"Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->scikit-mdr) (1.3.2)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->scikit-mdr) (3.2.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.7->matplotlib->scikit-mdr) (1.16.0)\n",
"Building wheels for collected packages: skrebate\n",
" Building wheel for skrebate (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
" Created wheel for skrebate: filename=skrebate-0.62-py3-none-any.whl size=29253 sha256=fabaec25d816d5c70e391d46c1a80a11dd7324cce03c4f89989c8007ebfe3fe3\n",
" Stored in directory: /root/.cache/pip/wheels/dd/67/40/683074a684607162bd0e34dcf7ccdfcab5861c3b2a83286f3a\n",
"Successfully built skrebate\n",
"Installing collected packages: skrebate, scikit-mdr\n",
"Successfully installed scikit-mdr-0.4.5 skrebate-0.62\n",
"Collecting tpot\n",
" Downloading TPOT-0.12.1-py3-none-any.whl (87 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m87.4/87.4 kB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hRequirement already satisfied: numpy>=1.16.3 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.23.5)\n",
"Requirement already satisfied: scipy>=1.3.1 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.11.4)\n",
"Requirement already satisfied: scikit-learn>=0.22.0 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.2.2)\n",
"Requirement already satisfied: deap>=1.2 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.4.1)\n",
"Requirement already satisfied: update-checker>=0.16 in /usr/local/lib/python3.10/dist-packages (from tpot) (0.18.0)\n",
"Requirement already satisfied: tqdm>=4.36.1 in /usr/local/lib/python3.10/dist-packages (from tpot) (4.66.1)\n",
"Requirement already satisfied: stopit>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.1.2)\n",
"Requirement already satisfied: pandas>=0.24.2 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.5.3)\n",
"Requirement already satisfied: joblib>=0.13.2 in /usr/local/lib/python3.10/dist-packages (from tpot) (1.3.2)\n",
"Requirement already satisfied: xgboost>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from tpot) (2.0.3)\n",
"Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.2->tpot) (2.8.2)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas>=0.24.2->tpot) (2023.3.post1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn>=0.22.0->tpot) (3.2.0)\n",
"Requirement already satisfied: requests>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from update-checker>=0.16->tpot) (2.31.0)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas>=0.24.2->tpot) (1.16.0)\n",
"Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (3.3.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (3.6)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (2.0.7)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.3.0->update-checker>=0.16->tpot) (2023.11.17)\n",
"Installing collected packages: tpot\n",
"Successfully installed tpot-0.12.1\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# The 'TPOT NN' config searches over PyTorch-based estimators (e.g. PytorchLRClassifier),\n",
"# so report which torch build this runtime provides.\n",
"import torch\n",
"\n",
"torch_version = torch.__version__\n",
"print(torch_version)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "5UH_q7gJL3YU",
"outputId": "13ebf85c-bc0a-4c4c-9466-e7194d1d6bc8"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"2.1.0+cu121\n"
]
}
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 448,
"referenced_widgets": [
"301a7a8edb37404da369a8607c8e7f39",
"550f9e30494a4cedb8265819b10eda9f",
"305334ab788d4c939ef090481cd37014",
"5e6d831eb4da408db14779d0bfa1c3ce",
"59def110c5334bc7865801cc2c2b1201",
"30407bbfe0534d839965fda6ff2ad7a6",
"7a64c0004a5b4f63a5e3e123a1bedb42",
"09b1baa49b8e4261aa12044843037665",
"b895890e16f04226ae3379d0a2b061b2",
"72f3f1a3683e4b7d93549d9bf998f1c7",
"b6633832327f415c85a7d0fc13a1aca3"
]
},
"id": "luBYuKTlKPlX",
"outputId": "b13f1749-e3b4-4152-d48f-cb857d408727"
},
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"Optimization Progress: 0%| | 0/110 [00:00<?, ?pipeline/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "301a7a8edb37404da369a8607c8e7f39"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"\n",
"Generation 1 - Current best internal CV score: 1.0\n",
"\n",
"Generation 2 - Current best internal CV score: 1.0\n",
"\n",
"Generation 3 - Current best internal CV score: 1.0\n",
"\n",
"Generation 4 - Current best internal CV score: 1.0\n",
"\n",
"Generation 5 - Current best internal CV score: 1.0\n",
"\n",
"Generation 6 - Current best internal CV score: 1.0\n",
"\n",
"Generation 7 - Current best internal CV score: 1.0\n",
"\n",
"Generation 8 - Current best internal CV score: 1.0\n",
"\n",
"Generation 9 - Current best internal CV score: 1.0\n",
"\n",
"Generation 10 - Current best internal CV score: 1.0\n",
"\n",
"Best pipeline: PytorchLRClassifier(MaxAbsScaler(SelectPercentile(input_matrix, percentile=57)), batch_size=4, learning_rate=1.0, num_epochs=15, weight_decay=0)\n",
"1.0\n"
]
}
],
"source": [
"from tpot import TPOTClassifier\n",
"from sklearn.datasets import make_blobs\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# One seed for the data, the train/test split and TPOT's evolutionary search, so that\n",
"# Restart & Run All reproduces the same split and (modulo nondeterministic GPU training)\n",
"# the same search.\n",
"SEED = 42\n",
"\n",
"# Toy binary problem: 100 samples drawn around 2 blob centers in 3 dimensions.\n",
"X, y = make_blobs(n_samples=100, centers=2, n_features=3, random_state=SEED)\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    X, y, train_size=0.75, test_size=0.25, random_state=SEED)\n",
"\n",
"# 'TPOT NN' adds TPOT's PyTorch estimators to the search space; the template fixes the\n",
"# pipeline shape to feature selector -> transformer -> PytorchLRClassifier.\n",
"clf = TPOTClassifier(config_dict='TPOT NN', template='Selector-Transformer-PytorchLRClassifier',\n",
"                     verbosity=2, population_size=10, generations=10, random_state=SEED)\n",
"clf.fit(X_train, y_train)\n",
"print(clf.score(X_test, y_test))\n",
"# Write the best pipeline found to a standalone Python script.\n",
"clf.export('tpot_nn_demo_pipeline.py')"
]
},
{
"cell_type": "code",
"source": [
"from deap import creator\n",
"from sklearn.model_selection import cross_val_score\n",
"import numpy as np\n",
"\n",
"# Peek at the first two entries of TPOT's evaluation log (pipeline string -> stats).\n",
"print(dict(list(clf.evaluated_individuals_.items())[:2]))\n",
"\n",
"# Look up the pipeline TPOT actually selected. The first key of evaluated_individuals_\n",
"# is merely the first pipeline evaluated, not necessarily the best one.\n",
"pipeline_str = str(clf._optimized_pipeline)\n",
"print(pipeline_str)\n",
"print(clf.evaluated_individuals_[pipeline_str])\n",
"\n",
"# Rebuild a scikit-learn Pipeline from the string (relies on TPOT's private _pset/_toolbox).\n",
"optimized_pipeline = creator.Individual.from_string(pipeline_str, clf._pset)  # deap individual\n",
"sklearn_pipeline = clf._toolbox.compile(expr=optimized_pipeline)  # freshly built, not fitted\n",
"print(sklearn_pipeline)\n",
"\n",
"# Optional: fix random_state on every step that supports it so the CV score is repeatable.\n",
"for _, step_estimator in sklearn_pipeline.steps:\n",
"    if 'random_state' in step_estimator.get_params():\n",
"        step_estimator.set_params(random_state=42)\n",
"\n",
"# Re-score with scikit-learn on the training data; cv=5 mirrors TPOTClassifier's default\n",
"# internal CV, so the two numbers below should be comparable.\n",
"scores = cross_val_score(sklearn_pipeline, X_train, y_train, cv=5, scoring='accuracy')\n",
"print('scikit-learn CV accuracy:', np.mean(scores))\n",
"print('TPOT internal CV score:  ', clf.evaluated_individuals_[pipeline_str]['internal_cv_score'])"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mq0lkWNuO1WU",
"outputId": "f2615618-f61f-4f75-91da-c6fd50745777"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"1.0\n",
"{'PytorchLRClassifier(MaxAbsScaler(SelectPercentile(input_matrix, SelectPercentile__percentile=57)), PytorchLRClassifier__batch_size=4, PytorchLRClassifier__learning_rate=1.0, PytorchLRClassifier__num_epochs=15, PytorchLRClassifier__weight_decay=0)': {'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}, 'PytorchLRClassifier(OneHotEncoder(SelectFwe(input_matrix, SelectFwe__alpha=0.039), OneHotEncoder__minimum_fraction=0.2, OneHotEncoder__sparse=False, OneHotEncoder__threshold=10), PytorchLRClassifier__batch_size=8, PytorchLRClassifier__learning_rate=0.01, PytorchLRClassifier__num_epochs=15, PytorchLRClassifier__weight_decay=0.001)': {'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}}\n",
"PytorchLRClassifier(MaxAbsScaler(SelectPercentile(input_matrix, SelectPercentile__percentile=57)), PytorchLRClassifier__batch_size=4, PytorchLRClassifier__learning_rate=1.0, PytorchLRClassifier__num_epochs=15, PytorchLRClassifier__weight_decay=0)\n",
"{'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}\n",
"Pipeline(steps=[('selectpercentile', SelectPercentile(percentile=57)),\n",
" ('maxabsscaler', MaxAbsScaler()),\n",
" ('pytorchlrclassifier',\n",
" PytorchLRClassifier(batch_size=4, learning_rate=1.0,\n",
" num_epochs=15, weight_decay=0))])\n",
"1.0\n",
"{'generation': 0, 'mutation_count': 0, 'crossover_count': 0, 'predecessor': ('ROOT',), 'operator_count': 3, 'internal_cv_score': 1.0}\n"
]
}
]
}
]
}