log2ml/2-3-2-open-source-deep-learning-models-for-encoding-taks/Comparison_of_Encoder_Models.ipynb

3728 lines
892 KiB
Plaintext
Raw Permalink Normal View History

{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"machine_shape": "hm",
"gpuType": "A100"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU",
"widgets": {
"application/vnd.jupyter.widget-state+json": {
2024-07-29 12:01:01 +00:00
"bd6b607607d34e3c96de67f9a103f19d": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
2024-07-29 12:01:01 +00:00
"IPY_MODEL_db9863f966f24e6fad90dc2e6acae785",
"IPY_MODEL_50a29d3705bb49d986f548374fef81bc",
"IPY_MODEL_d2ca4d5e2863449ba87ceb117e3df7da"
],
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_3031182ea08a43e2a75d12d67ceef101"
}
},
2024-07-29 12:01:01 +00:00
"db9863f966f24e6fad90dc2e6acae785": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_46aa4f53040c485387ae03dbc195b34b",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_30bd5d604a7041e9ab193e34be502137",
"value": "100%"
}
},
2024-07-29 12:01:01 +00:00
"50a29d3705bb49d986f548374fef81bc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_ed7ee81cf876476ebcd6c6d29b954de4",
"max": 1425741623,
"min": 0,
"orientation": "horizontal",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_3528f43e1f594fa6a65ae43a926e9d95",
"value": 1425741623
}
},
2024-07-29 12:01:01 +00:00
"d2ca4d5e2863449ba87ceb117e3df7da": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_27281ae3d5d649829478a230b68bca47",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_f2764b194d4e478f97ae5e09f4afb86c",
"value": "1.43G/1.43G[00:16<00:00,87.7MiB/s]"
}
},
2024-07-29 12:01:01 +00:00
"3031182ea08a43e2a75d12d67ceef101": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"46aa4f53040c485387ae03dbc195b34b": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"30bd5d604a7041e9ab193e34be502137": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"ed7ee81cf876476ebcd6c6d29b954de4": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"3528f43e1f594fa6a65ae43a926e9d95": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"27281ae3d5d649829478a230b68bca47": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"f2764b194d4e478f97ae5e09f4afb86c": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"b5271fc8f29a469696c63aa94f4c66c8": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
2024-07-29 12:01:01 +00:00
"IPY_MODEL_6892e1bb32de4655bd093e50aad979b9",
"IPY_MODEL_56ff1bf94fac41e2a4bfe6c642db64b9",
"IPY_MODEL_ee15e4b938df499da5b1d2838552962f"
],
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_b5dd37fa45c24680b93c21c1969490f2"
}
},
2024-07-29 12:01:01 +00:00
"6892e1bb32de4655bd093e50aad979b9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_df2385efc48a411c8e3584c954f23a26",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_63bb98eab8c74dfabea98b85ff397b77",
"value": "100%"
}
},
2024-07-29 12:01:01 +00:00
"56ff1bf94fac41e2a4bfe6c642db64b9": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_cdbb5e7a249c436b809e125f496eb82e",
"max": 594036,
"min": 0,
"orientation": "horizontal",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_3b4d2cf14f8c496ab740a2485155442b",
"value": 594036
}
},
2024-07-29 12:01:01 +00:00
"ee15e4b938df499da5b1d2838552962f": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_59dd159940064a309eb1cad67d5faa7a",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_f84bb79f52664a91b1a6ee637281c285",
"value": "594k/594k[00:00<00:00,14.0MiB/s]"
}
},
2024-07-29 12:01:01 +00:00
"b5dd37fa45c24680b93c21c1969490f2": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"df2385efc48a411c8e3584c954f23a26": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"63bb98eab8c74dfabea98b85ff397b77": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"cdbb5e7a249c436b809e125f496eb82e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"3b4d2cf14f8c496ab740a2485155442b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"59dd159940064a309eb1cad67d5faa7a": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"f84bb79f52664a91b1a6ee637281c285": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"cf8262210b7c42bc8db80e1a102acf19": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
2024-07-29 12:01:01 +00:00
"IPY_MODEL_509c0ac5a7624e2bb8618b978c8a80dc",
"IPY_MODEL_56c17b10c28c4fa9a6db05aa60a71f5e",
"IPY_MODEL_4b85c54d069e4a8d87520dbdad4c2c89"
],
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_6537feb3ffd24d2ea90b27a7a0a1aeb0"
}
},
2024-07-29 12:01:01 +00:00
"509c0ac5a7624e2bb8618b978c8a80dc": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_4e2585503cca48f8a0428cc62d02e5f1",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_b6ea52fd1054458e9458865f8f40ff84",
"value": "config.json:100%"
}
},
2024-07-29 12:01:01 +00:00
"56c17b10c28c4fa9a6db05aa60a71f5e": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_d70cc590fb28457b9ca7f74d8aa425f1",
"max": 694,
"min": 0,
"orientation": "horizontal",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_a9116211ec3d407d9278f744d41507e6",
"value": 694
}
},
2024-07-29 12:01:01 +00:00
"4b85c54d069e4a8d87520dbdad4c2c89": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_d706ec4068874b6eb52918afad80982e",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_9c7b2917eb984de88aa6dbdec1bfa9cb",
"value": "694/694[00:00<00:00,52.6kB/s]"
}
},
2024-07-29 12:01:01 +00:00
"6537feb3ffd24d2ea90b27a7a0a1aeb0": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"4e2585503cca48f8a0428cc62d02e5f1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"b6ea52fd1054458e9458865f8f40ff84": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"d70cc590fb28457b9ca7f74d8aa425f1": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"a9116211ec3d407d9278f744d41507e6": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"d706ec4068874b6eb52918afad80982e": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"9c7b2917eb984de88aa6dbdec1bfa9cb": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"c69abb9efb6b411fa4ee31b0c1910bae": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HBoxModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HBoxModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HBoxView",
"box_style": "",
"children": [
2024-07-29 12:01:01 +00:00
"IPY_MODEL_1170d27bc06e48bc985700e1703887b3",
"IPY_MODEL_6264397e6794478fa54e161f4434cc8b",
"IPY_MODEL_e5acc8060d7c4d25b6c13bb486eb5c39"
],
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_6ff11b9092db4dab98a68505955a612c"
}
},
2024-07-29 12:01:01 +00:00
"1170d27bc06e48bc985700e1703887b3": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_b9f3902a7a844da0ab7d9bf5c2869db7",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_86fa6d3c75e94c109ca0d44ddc7cb6e4",
"value": "pytorch_model.bin:100%"
}
},
2024-07-29 12:01:01 +00:00
"6264397e6794478fa54e161f4434cc8b": {
"model_module": "@jupyter-widgets/controls",
"model_name": "FloatProgressModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "FloatProgressModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "ProgressView",
"bar_style": "success",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_4d7f1a3f6c044343a096951ed2a58603",
"max": 597257159,
"min": 0,
"orientation": "horizontal",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_ef324ed1c14d4e9aac3312ed637f5789",
"value": 597257159
}
},
2024-07-29 12:01:01 +00:00
"e5acc8060d7c4d25b6c13bb486eb5c39": {
"model_module": "@jupyter-widgets/controls",
"model_name": "HTMLModel",
"model_module_version": "1.5.0",
"state": {
"_dom_classes": [],
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "HTMLModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/controls",
"_view_module_version": "1.5.0",
"_view_name": "HTMLView",
"description": "",
"description_tooltip": null,
2024-07-29 12:01:01 +00:00
"layout": "IPY_MODEL_d19fd3e721b84b37b2f4da2bf03008de",
"placeholder": "",
2024-07-29 12:01:01 +00:00
"style": "IPY_MODEL_5d3d80342c864f419e64a323051e3509",
"value": "597M/597M[00:02<00:00,299MB/s]"
}
},
2024-07-29 12:01:01 +00:00
"6ff11b9092db4dab98a68505955a612c": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"b9f3902a7a844da0ab7d9bf5c2869db7": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"86fa6d3c75e94c109ca0d44ddc7cb6e4": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"4d7f1a3f6c044343a096951ed2a58603": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"ef324ed1c14d4e9aac3312ed637f5789": {
"model_module": "@jupyter-widgets/controls",
"model_name": "ProgressStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "ProgressStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"bar_color": null,
"description_width": ""
}
},
2024-07-29 12:01:01 +00:00
"d19fd3e721b84b37b2f4da2bf03008de": {
"model_module": "@jupyter-widgets/base",
"model_name": "LayoutModel",
"model_module_version": "1.2.0",
"state": {
"_model_module": "@jupyter-widgets/base",
"_model_module_version": "1.2.0",
"_model_name": "LayoutModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "LayoutView",
"align_content": null,
"align_items": null,
"align_self": null,
"border": null,
"bottom": null,
"display": null,
"flex": null,
"flex_flow": null,
"grid_area": null,
"grid_auto_columns": null,
"grid_auto_flow": null,
"grid_auto_rows": null,
"grid_column": null,
"grid_gap": null,
"grid_row": null,
"grid_template_areas": null,
"grid_template_columns": null,
"grid_template_rows": null,
"height": null,
"justify_content": null,
"justify_items": null,
"left": null,
"margin": null,
"max_height": null,
"max_width": null,
"min_height": null,
"min_width": null,
"object_fit": null,
"object_position": null,
"order": null,
"overflow": null,
"overflow_x": null,
"overflow_y": null,
"padding": null,
"right": null,
"top": null,
"visibility": null,
"width": null
}
},
2024-07-29 12:01:01 +00:00
"5d3d80342c864f419e64a323051e3509": {
"model_module": "@jupyter-widgets/controls",
"model_name": "DescriptionStyleModel",
"model_module_version": "1.5.0",
"state": {
"_model_module": "@jupyter-widgets/controls",
"_model_module_version": "1.5.0",
"_model_name": "DescriptionStyleModel",
"_view_count": null,
"_view_module": "@jupyter-widgets/base",
"_view_module_version": "1.2.0",
"_view_name": "StyleView",
"description_width": ""
}
}
}
}
},
"cells": [
{
"cell_type": "markdown",
"source": [
"# Installations\n",
"\n",
"The installation is mostly automated.\n",
"\n",
"A file in the same directory named \"thesis_ro\" will be required, which should contain your GitHub read-only token.\n",
"\n",
"The file has one line:\n",
"\n",
"\n",
"`echo \"GITHUB_PERSONAL_ACCESS_TOKEN=ghp_...\" > thesis_ro`"
],
"metadata": {
"id": "4-p_B7FGB9mg"
}
},
{
"cell_type": "markdown",
"source": [],
"metadata": {
"id": "QZT7W5H2hezF"
}
},
{
"cell_type": "markdown",
"source": [
">[Installations](#scrollTo=4-p_B7FGB9mg)\n",
"\n",
">[Setup data](#scrollTo=b7jMEfm3CAp1)\n",
"\n",
">[Train custom Tokenizer with complex dataset](#scrollTo=dAYGIiUlCQM9)\n",
"\n",
">[Download dataset to use](#scrollTo=RtdpdugyCl9v)\n",
"\n",
">[ETL dataset](#scrollTo=wk7RQHMuCxjr)\n",
"\n",
">[Max token length](#scrollTo=QdLFssPYC-oG)\n",
"\n",
">[Load LinFormer with max. token length](#scrollTo=ARRqrF8XDG6_)\n",
"\n",
">[Vectorize the text column in that DataFrame](#scrollTo=AzxrG_RmDLr1)\n",
"\n",
">[Convert the vectors to a NumPy array for statistical analysis](#scrollTo=3Gxep-miEAVm)\n",
"\n",
">[UMAP dimensionality check - scattered or clustered?](#scrollTo=zZVCSpvsEJIq)\n",
"\n",
">[Silhouette score via KMeans clustering](#scrollTo=56wtMVhhEVzf)\n",
"\n",
">[Distribution of the Cosine similarity](#scrollTo=8yhTWsSnEl30)\n",
"\n",
">[Comparison model: LongFormer](#scrollTo=FmwKV-BK-rKX)\n",
"\n",
">[Initialize and use LongFormer with the same pre-trained Tokenizer](#scrollTo=J6REcr7hE1xz)\n",
"\n",
">[Approaching Quantile Transoformer for feature scaling for LinFormer model vectors](#scrollTo=SxhOQqiHhPju)\n",
"\n",
">[Quantile Transformer vs. Standard Scaler for feature scaling](#scrollTo=kgdy9N8lhEaF)\n",
"\n"
],
"metadata": {
"colab_type": "toc",
"id": "V-wAtykHhidb"
}
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5QU9CtLHsJr_"
},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import subprocess\n",
"\n",
"IN_COLAB = 'google.colab' in sys.modules\n",
"\n",
"# Dependencies are only bootstrapped on Colab; nothing is installed for any\n",
"# other kernel.\n",
"if IN_COLAB:\n",
"    # NOTE(review): this runs an install script fetched from the main branch at\n",
"    # execution time; pin the URL to a commit if reproducibility matters.\n",
"    subprocess.run('''\n",
"        source <(curl -s https://raw.githubusercontent.com/norandom/log2ml/main/dependencies/install.sh)\n",
"        ''',\n",
"        shell=True, check=True, executable='/bin/bash')\n",
"\n"
]
},
{
"cell_type": "markdown",
"source": [
"# Setup data\n",
"\n",
"The data setup is automated, based on GitHub release files. Kaggle or HuggingFace releases are for later, once the work has been completed."
],
"metadata": {
"id": "b7jMEfm3CAp1"
}
},
{
"cell_type": "code",
"source": [
"from dotenv import load_dotenv\n",
"import os\n",
"\n",
"load_dotenv(\"thesis_ro\", verbose=True)  # take environment variables from the file\n",
"token = os.getenv('GITHUB_PERSONAL_ACCESS_TOKEN')\n",
"\n",
"# os.getenv returns None when the variable is not set (e.g. the thesis_ro file\n",
"# is missing), so test truthiness instead of calling len() on a possible None.\n",
"if token:\n",
"    print(\"ok\")\n",
"else:\n",
"    print(\"no token\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B4wCRu4DAC0Z",
"outputId": "e5383477-f981-4a9e-d448-6707b4f73737"
},
"execution_count": 2,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"ok\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from github import Github\n",
"import requests\n",
"from tqdm.notebook import tqdm\n",
"\n",
"\n",
"def get_specific_file_from_tagged_release(token, repo_name, tag_name, filename):\n",
"    \"\"\"Return the API URL of asset `filename` in the release tagged `tag_name`.\n",
"\n",
"    Walks all releases of `repo_name`; prints a hint and returns None when no\n",
"    release with that tag carries an asset of that name.\n",
"    \"\"\"\n",
"    repo = Github(token).get_repo(repo_name)\n",
"\n",
"    for release in repo.get_releases():\n",
"        if release.tag_name != tag_name:\n",
"            continue\n",
"        for asset in release.get_assets():\n",
"            if asset.name == filename:\n",
"                return asset.url\n",
"\n",
"    print(\"File not found. Try get_specific_file_from_latest_release() instead.\")\n",
"    return None\n",
"\n",
"def get_specific_file_from_latest_release(token, repo_name, filename):\n",
"    \"\"\"Return the API URL of asset `filename` in the latest release, or None.\"\"\"\n",
"    latest = Github(token).get_repo(repo_name).get_latest_release()\n",
"\n",
"    # asset.url is the API endpoint of the asset; downloading from it needs headers\n",
"    matching_urls = (asset.url for asset in latest.get_assets() if asset.name == filename)\n",
"    return next(matching_urls, None)\n",
"\n",
"def download_file(url, token, save_path):\n",
"    \"\"\"Stream a GitHub release asset to `save_path`, showing a progress bar.\n",
"\n",
"    Args:\n",
"        url: API URL of the release asset (the `asset.url` of the lookup helpers).\n",
"        token: GitHub token sent in the Authorization header.\n",
"        save_path: Local path the downloaded bytes are written to.\n",
"\n",
"    Raises:\n",
"        requests.HTTPError: If the server answers with an error status.\n",
"    \"\"\"\n",
"    headers = {'Authorization': f'token {token}', 'Accept': 'application/octet-stream'}\n",
"    block_size = 1024 * 1024  # 1 MiB per chunk; 1 KiB chunks mean millions of iterations for GB-sized assets\n",
"\n",
"    # One request is enough: requests follows the redirect to the asset storage\n",
"    # on its own. Re-requesting the redirected URL with the same headers (as\n",
"    # done previously) opened the download twice and sent the GitHub token to\n",
"    # the redirect target.\n",
"    with requests.get(url, headers=headers, stream=True) as response:\n",
"        response.raise_for_status()\n",
"        total_size_in_bytes = int(response.headers.get('content-length', 0))\n",
"\n",
"        # Both context managers are closed even if the transfer fails midway\n",
"        with open(save_path, 'wb') as file:\n",
"            with tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) as progress_bar:\n",
"                for data in response.iter_content(block_size):\n",
"                    progress_bar.update(len(data))\n",
"                    file.write(data)\n",
"                bytes_received = progress_bar.n\n",
"\n",
"    # A known content length that does not match the received bytes means a truncated download\n",
"    if total_size_in_bytes != 0 and bytes_received != total_size_in_bytes:\n",
"        print(\"ERROR, something went wrong\")\n",
"    else:\n",
"        print(f\"File downloaded successfully and saved as {save_path}\")\n",
"\n",
"# Download parameters: token, repository and the release asset to fetch\n",
"github_token = token\n",
"repository_name = \"norandom/log2ml\"\n",
"file_name = \"lab_logs_normal_activity_may_11_2024.csv\"\n",
"\n",
"# Resolve the asset's API URL from the \"foundations\" release\n",
"# (alternative: get_specific_file_from_latest_release(github_token, repository_name, file_name))\n",
"download_url = get_specific_file_from_tagged_release(github_token, repository_name, \"foundations\", file_name)\n",
"print(download_url)\n",
"\n",
"if not download_url:\n",
"    print(\"File not found.\")\n",
"else:\n",
"    # The asset is stored locally under its release file name\n",
"    local_file_path = file_name\n",
"    download_file(download_url, github_token, local_file_path)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85,
"referenced_widgets": [
"bd6b607607d34e3c96de67f9a103f19d",
"db9863f966f24e6fad90dc2e6acae785",
"50a29d3705bb49d986f548374fef81bc",
"d2ca4d5e2863449ba87ceb117e3df7da",
"3031182ea08a43e2a75d12d67ceef101",
"46aa4f53040c485387ae03dbc195b34b",
"30bd5d604a7041e9ab193e34be502137",
"ed7ee81cf876476ebcd6c6d29b954de4",
"3528f43e1f594fa6a65ae43a926e9d95",
"27281ae3d5d649829478a230b68bca47",
"f2764b194d4e478f97ae5e09f4afb86c"
]
},
"id": "ORNebW15Bb4B",
"outputId": "58583b66-046e-4b57-c892-f877d5fa55ad"
},
"execution_count": 3,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"https://api.github.com/repos/norandom/log2ml/releases/assets/168248751\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" 0%| | 0.00/1.43G [00:00<?, ?iB/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "bd6b607607d34e3c96de67f9a103f19d"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"File downloaded successfully and saved as lab_logs_normal_activity_may_11_2024.csv\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Train custom Tokenizer with complex dataset\n",
"\n",
"The dataset contains security-agent telemetry from MS Sysmon.\n",
"Therefore a custom Tokenizer is needed, given that the content isn't 100% natural language.\n",
"\n",
"BPE stands for Byte Pair Encoding Tokenization\n",
"\n",
"\"It's used by a lot of Transformer models, including GPT, GPT-2, RoBERTa, BART, and DeBERTa.\" [1]\n",
"\n",
"Given that LongFormer and Linformer are Transformer Network architectures, BPE is a compatible choice.\n",
"\n",
"Parsing the data as a CSV or colon-separated values isn't preferable, because the implementation goal is a generically reusable log- / trace-data vectorizer. Approaches like bag-of-words or one-hot-encoding may become too specific to the data.\n",
"\n",
"[1] https://huggingface.co/learn/nlp-course/en/chapter6/5"
],
"metadata": {
"id": "dAYGIiUlCQM9"
}
},
{
"cell_type": "code",
"source": [
"from tokenizers import Tokenizer\n",
"from tokenizers.models import BPE\n",
"from tokenizers.trainers import BpeTrainer\n",
"from tokenizers.pre_tokenizers import Whitespace\n",
"\n",
"# Structural characters of the log format, registered as special tokens\n",
"LOG_DELIMITER_TOKENS = [\"=\", \":\", \",\", \"\\\"\", \"\\'\", \"(\", \")\", \"[\", \"]\", \"{\", \"}\"]\n",
"\n",
"# BPE model with the Whitespace pre-tokenizer in front of it\n",
"tokenizer = Tokenizer(BPE())\n",
"tokenizer.pre_tokenizer = Whitespace()\n",
"\n",
"# Vocabulary budget and minimum pair frequency used while training\n",
"trainer = BpeTrainer(\n",
"    special_tokens=LOG_DELIMITER_TOKENS,\n",
"    vocab_size=30000,\n",
"    min_frequency=2,\n",
"    show_progress=True,\n",
")\n",
"\n",
"# Train on the downloaded log export and persist the result for later cells\n",
"tokenizer.train(trainer=trainer, files=[\"lab_logs_normal_activity_may_11_2024.csv\"])\n",
"tokenizer.save(\"log_tokenizer.json\")"
],
"metadata": {
"id": "9A2nD4deq8lO"
},
"execution_count": 7,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"## Test with a message\n",
"\n",
"By avoiding timestamps and other variant series data, we can avoid noisy vectors."
],
"metadata": {
"id": "pszMjefAiWmS"
}
},
{
"cell_type": "code",
"source": [
"# Raw string: the Windows paths below contain backslashes (\\W, \\S, \\s, ...) that\n",
"# are invalid escape sequences in a normal string literal and trigger a warning.\n",
"log_text = r\"\"\"\n",
"Registry value set (rule: RegistryEvent),\"Registry value set:\n",
"RuleName: InvDB-Ver\n",
"EventType: SetValue\n",
"UtcTime: 2024-07-28 15:09:33.716\n",
"ProcessGuid: {18e8265a-5ee4-66a6-5400-000000004400}\n",
"ProcessId: 4032\n",
"Image: C:\\Windows\\System32\\svchost.exe\n",
"TargetObject: \\REGISTRY\\A\\{90cbbb87-bac4-4fa3-1d8b-b1a042a75259}\\Root\\InventoryApplicationFile\\ie4uinit.exe|874b2700383dd346\\BinProductVersion\n",
"Details: 11.0.19041.4648\"\n",
"\"\"\"\n",
"\n",
"# Tokenize the sample message with the freshly trained tokenizer\n",
"output = tokenizer.encode(log_text)\n",
"\n",
"# Print the tokens\n",
"print(output.tokens)\n",
"\n",
"# Print the vocabulary IDs of the tokens\n",
"print(output.ids)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "LTQyGEuZMhdS",
"outputId": "8e4f07ba-5f5e-4e42-afc0-4634afd7df7c"
},
"execution_count": 33,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"['2024', '-', '07', '-', '28T', '15', ':', '09', ':', '33', '.', '716Z', ',', 'win10', ',', 'fe80', ':', ':', 'c', '1af', ':', '35', 'de', ':', '600', '6', ':', 'd4c', 'f', ',', 'information', ',', '13', ',', 'Registry', 'value', 'set', '(', 'rule', ':', 'RegistryEvent', ')', ',', '\"', 'Registry', 'value', 'set', ':', 'RuleName', ':', 'InvDB', '-', 'Ver', 'EventType', ':', 'SetValue', 'UtcTime', ':', '2024', '-', '07', '-', '28', '15', ':', '09', ':', '33', '.', '716', 'ProcessGuid', ':', '{', '18e8265a', '-', '5e', 'e4', '-', '6', '6a', '6', '-', '5400', '-', '00000000', '4400', '}', 'ProcessId', ':', '4032', 'Image', ':', 'C', ':', '\\\\', 'Windows', '\\\\', 'System32', '\\\\', 'svchost', '.', 'exe', 'TargetObject', ':', '\\\\', 'REGISTRY', '\\\\', 'A', '\\\\', '{', '90', 'cbb', 'b87', '-', 'bac', '4', '-', '4fa', '3', '-', '1d', '8b', '-', 'b1a', '042', 'a7', '525', '9', '}', '\\\\', 'Root', '\\\\', 'InventoryApplicationFile', '\\\\', 'ie4uinit', '.', 'exe', '|', '87', '4b', '2700', '38', '3d', 'd3', '46', '\\\\', 'BinProductVersion', 'Details', ':', '11', '.', '0', '.', '19041', '.', '4648', '\"']\n",
"[789, 18, 282, 18, 9118, 429, 1, 451, 1, 733, 19, 20558, 2, 229, 2, 426, 1, 1, 68, 29307, 1, 292, 352, 1, 1150, 27, 1, 22018, 71, 2, 437, 2, 591, 2, 1196, 795, 987, 5, 865, 1, 2644, 6, 2, 3, 1196, 795, 987, 1, 1324, 1, 4115, 18, 5731, 4108, 1, 2763, 1112, 1, 789, 18, 282, 18, 316, 429, 1, 451, 1, 733, 19, 26125, 1450, 1, 9, 444, 18, 1120, 1747, 18, 27, 898, 27, 18, 22096, 18, 379, 14576, 10, 1089, 1, 27452, 1041, 1, 38, 1, 62, 176, 62, 307, 62, 915, 19, 260, 4107, 1, 62, 3071, 62, 36, 62, 9, 322, 10902, 10887, 18, 10941, 25, 18, 6459, 24, 18, 360, 2379, 18, 11383, 14573, 766, 7844, 30, 10, 62, 2675, 62, 3349, 62, 20723, 19, 260, 92, 799, 489, 2805, 256, 2276, 1360, 308, 62, 5983, 2001, 1, 318, 19, 21, 19, 326, 19, 19526, 3]\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Download dataset to use\n",
"\n",
"The analysis dataset has less complex data.\n",
"Variant series data has been removed."
],
"metadata": {
"id": "RtdpdugyCl9v"
}
},
{
"cell_type": "code",
"source": [
"# Release asset holding the blind-test logs\n",
"file_name = \"lab_logs_blindtest_activity_june_25_2024.csv\"\n",
"\n",
"# Resolve the asset's API URL from the \"foundations\" release\n",
"# (alternative: get_specific_file_from_latest_release(github_token, repository_name, file_name))\n",
"download_url = get_specific_file_from_tagged_release(github_token, repository_name, \"foundations\", file_name)\n",
"print(download_url)\n",
"\n",
"if not download_url:\n",
"    print(\"File not found.\")\n",
"else:\n",
"    # The asset is stored locally under its release file name\n",
"    local_file_path = file_name\n",
"    download_file(download_url, github_token, local_file_path)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 85,
"referenced_widgets": [
"b5271fc8f29a469696c63aa94f4c66c8",
"6892e1bb32de4655bd093e50aad979b9",
"56ff1bf94fac41e2a4bfe6c642db64b9",
"ee15e4b938df499da5b1d2838552962f",
"b5dd37fa45c24680b93c21c1969490f2",
"df2385efc48a411c8e3584c954f23a26",
"63bb98eab8c74dfabea98b85ff397b77",
"cdbb5e7a249c436b809e125f496eb82e",
"3b4d2cf14f8c496ab740a2485155442b",
"59dd159940064a309eb1cad67d5faa7a",
"f84bb79f52664a91b1a6ee637281c285"
]
},
"id": "mjrM6gVOGfHq",
"outputId": "ed833527-a554-48eb-f2a5-6d166e4eb88e"
},
"execution_count": 9,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"https://api.github.com/repos/norandom/log2ml/releases/assets/175872904\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
" 0%| | 0.00/594k [00:00<?, ?iB/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
"model_id": "b5271fc8f29a469696c63aa94f4c66c8"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"File downloaded successfully and saved as lab_logs_blindtest_activity_june_25_2024.csv\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# ETL dataset\n",
"\n",
"ETL stands for Extract, Transform, Load. Polars is used as an ETL framework. It allows DataFrame operations, similar to Pandas."
],
"metadata": {
"id": "wk7RQHMuCxjr"
}
},
{
"cell_type": "code",
"source": [
"import polars as pl\n",
"\n",
"csv_file_path = 'lab_logs_blindtest_activity_june_25_2024.csv'\n",
"\n",
"# Load the CSV file into a DataFrame\n",
"df = pl.read_csv(csv_file_path)\n",
"\n",
"# Show the DataFrame to confirm it's loaded correctly\n",
"print(df)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4ZzWtIgFxBJS",
"outputId": "c757f28a-94c4-4d2b-cfc8-33e4c8b3889f"
},
"execution_count": 10,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"shape: (1_027, 5)\n",
"┌───────┬─────────────┬─────────────────┬───────────────────────────────────┬─────────────────────┐\n",
"│ index ┆ log.level ┆ winlog.event_id ┆ winlog.task ┆ filtered_message │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ i64 ┆ str ┆ i64 ┆ str ┆ str │\n",
"╞═══════╪═════════════╪═════════════════╪═══════════════════════════════════╪═════════════════════╡\n",
"│ 0 ┆ information ┆ 10 ┆ Process accessed (rule: ProcessA… ┆ Process accessed: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ So… │\n",
"│ 1 ┆ information ┆ 10 ┆ Process accessed (rule: ProcessA… ┆ Process accessed: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ So… │\n",
"│ 2 ┆ information ┆ 1 ┆ Process Create (rule: ProcessCre… ┆ Process Create: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ Proc… │\n",
"│ 3 ┆ information ┆ 13 ┆ Registry value set (rule: Regist… ┆ Registry value set: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: Ta… │\n",
"│ … ┆ … ┆ … ┆ … ┆ … │\n",
"│ 1023 ┆ information ┆ 10 ┆ Process accessed (rule: ProcessA… ┆ Process accessed: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ So… │\n",
"│ 1024 ┆ information ┆ 1 ┆ Process Create (rule: ProcessCre… ┆ Process Create: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ Proc… │\n",
"│ 1025 ┆ information ┆ 22 ┆ Dns query (rule: DnsQuery) ┆ Dns query: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ ProcessId… │\n",
"│ 1026 ┆ information ┆ 1 ┆ Process Create (rule: ProcessCre… ┆ Process Create: │\n",
"│ ┆ ┆ ┆ ┆ RuleName: - │\n",
"│ ┆ ┆ ┆ ┆ Proc… │\n",
"└───────┴─────────────┴─────────────────┴───────────────────────────────────┴─────────────────────┘\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import re\n",
"\n",
"# Extract relevant information using regular expressions\n",
"def extract_info(text):\n",
"    \"\"\"Extract the first `Image` and `TargetFilename` .exe paths from a log message.\n",
"\n",
"    Returns a dict with the keys \"image\" and \"target_filename\" (empty string when\n",
"    the field is not found) and \"text\" (the unmodified input).\n",
"    \"\"\"\n",
"    def first_exe_path(field):\n",
"        # Lazy match up to the first \".exe\" after \"<field>: \", ignoring case\n",
"        found = re.search(field + r\": (.*?\\.exe)\", text, re.IGNORECASE)\n",
"        return found.group(1) if found else \"\"\n",
"\n",
"    return {\n",
"        \"image\": first_exe_path(\"Image\"),\n",
"        \"target_filename\": first_exe_path(\"TargetFilename\"),\n",
"        \"text\": text,\n",
"    }\n",
"\n",
"# Run extract_info on every message (kept as a Python object column), then fan\n",
"# the resulting dict out into one string column per key and drop the helper column\n",
"extracted_fields = (\"image\", \"target_filename\", \"text\")\n",
"\n",
"better_columns_df = df.with_columns(\n",
"    pl.col(\"filtered_message\").map_elements(extract_info, return_dtype=pl.Object).alias(\"extracted_info\")\n",
").with_columns(\n",
"    [\n",
"        pl.col(\"extracted_info\").map_elements(lambda info, key=field: info[key], return_dtype=pl.Utf8).alias(field)\n",
"        for field in extracted_fields\n",
"    ]\n",
").drop(\"extracted_info\")\n",
"\n",
"print(better_columns_df)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HoknJOrTZC7V",
"outputId": "6525ea8b-f1f8-4ac7-9f7f-9e13ea6856d7"
},
"execution_count": 11,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"shape: (1_027, 8)\n",
"┌───────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐\n",
"│ index ┆ log.level ┆ winlog.eve ┆ winlog.tas ┆ filtered_m ┆ image ┆ target_fil ┆ text │\n",
"│ --- ┆ --- ┆ nt_id ┆ k ┆ essage ┆ --- ┆ ename ┆ --- │\n",
"│ i64 ┆ str ┆ --- ┆ --- ┆ --- ┆ str ┆ --- ┆ str │\n",
"│ ┆ ┆ i64 ┆ str ┆ str ┆ ┆ str ┆ │\n",
"╞═══════╪════════════╪════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡\n",
"│ 0 ┆ informatio ┆ 10 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ accessed ┆ accessed: ┆ \\system32\\ ┆ ┆ accessed: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ svchost.ex ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ - ┆ e ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ So… ┆ ┆ ┆ So… │\n",
"│ 1 ┆ informatio ┆ 10 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ accessed ┆ accessed: ┆ \\system32\\ ┆ ┆ accessed: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ svchost.ex ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ - ┆ e ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ So… ┆ ┆ ┆ So… │\n",
"│ 2 ┆ informatio ┆ 1 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ Create ┆ Create: ┆ \\servicing ┆ ┆ Create: │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ RuleName: ┆ \\TrustedIn ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ cessCre… ┆ - ┆ st… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ Proc… ┆ ┆ ┆ Proc… │\n",
"│ 3 ┆ informatio ┆ 13 ┆ Registry ┆ Registry ┆ C:\\Windows ┆ ┆ Registry │\n",
"│ ┆ n ┆ ┆ value set ┆ value set: ┆ \\servicing ┆ ┆ value set: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ \\TrustedIn ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ Regist… ┆ Ta… ┆ st… ┆ ┆ Ta… │\n",
"│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n",
"│ 1023 ┆ informatio ┆ 10 ┆ Process ┆ Process ┆ C:\\Program ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ accessed ┆ accessed: ┆ Files (x86 ┆ ┆ accessed: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ )\\Microsof ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ - ┆ t… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ So… ┆ ┆ ┆ So… │\n",
"│ 1024 ┆ informatio ┆ 1 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ Create ┆ Create: ┆ \\System32\\ ┆ ┆ Create: │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ RuleName: ┆ taskhostw. ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ cessCre… ┆ - ┆ ex… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ Proc… ┆ ┆ ┆ Proc… │\n",
"│ 1025 ┆ informatio ┆ 22 ┆ Dns query ┆ Dns query: ┆ ┆ ┆ Dns query: │\n",
"│ ┆ n ┆ ┆ (rule: ┆ RuleName: ┆ ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ DnsQuery) ┆ - ┆ ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ ProcessId… ┆ ┆ ┆ ProcessId… │\n",
"│ 1026 ┆ informatio ┆ 1 ┆ Process ┆ Process ┆ C:\\Program ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ Create ┆ Create: ┆ Files\\RUXI ┆ ┆ Create: │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ RuleName: ┆ M\\PLUGSche ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ cessCre… ┆ - ┆ d… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ Proc… ┆ ┆ ┆ Proc… │\n",
"└───────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┘\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Max token length\n",
"\n",
"Selecting models which can deal with the long Sysmon agent messages is a key design requirement for the ML pipeline. Therefore the maximum token length of the messages needs to be known before a model is configured."
],
"metadata": {
"id": "QdLFssPYC-oG"
}
},
{
"cell_type": "code",
"source": [
"from tokenizers import Tokenizer\n",
"\n",
"# Load the custom tokenizer\n",
"tokenizer = Tokenizer.from_file(\"log_tokenizer.json\")\n",
"\n",
"# Token count per message; the explicit return_dtype spares Polars from\n",
"# inferring the result type of the Python callback.\n",
"token_lengths = better_columns_df[\"filtered_message\"].map_elements(\n",
"    lambda x: len(tokenizer.encode(x).ids),\n",
"    return_dtype=pl.Int64,\n",
")\n",
"\n",
"# Longest message in tokens: the sequence length the encoder has to support\n",
"max_token_length = token_lengths.max()\n",
"print(f\"The maximum number of tokens in a single 'filtered_message' is: {max_token_length}\")\n",
"\n",
"# Get the total input token count\n",
"total_input_tokens = token_lengths.sum()\n",
"\n",
"print(f\"The total number of input tokens in 'filtered_message' column is: {total_input_tokens}\")\n",
"\n",
"# Cost calculation\n",
"input_cost_per_1m_tokens = 10.00  # $10.00 per 1M tokens for input\n",
"output_cost_per_1m_tokens = 30.00  # $30.00 per 1M tokens for output\n",
"\n",
"# Estimate output tokens (this is an assumption, adjust as needed)\n",
"estimated_output_tokens = total_input_tokens * 0.5  # Assuming output is 50% of input\n",
"\n",
"# Calculate costs\n",
"input_cost = (total_input_tokens / 1_000_000) * input_cost_per_1m_tokens\n",
"output_cost = (estimated_output_tokens / 1_000_000) * output_cost_per_1m_tokens\n",
"total_cost = input_cost + output_cost\n",
"\n",
"print(\"Reference: OpenAI GPT4-turbo July 2022\")\n",
"print(f\"Estimated input cost: ${input_cost:.2f}\")\n",
"print(f\"Estimated output cost: ${output_cost:.2f}\")\n",
"print(f\"Estimated total cost: ${total_cost:.2f}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g7lfhJmEckFK",
"outputId": "c6330ae4-8673-4814-fd45-ef2d25238fbc"
},
"execution_count": 15,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"The total number of input tokens in 'filtered_message' column is: 157897\n",
"Reference: OpenAI GPT4-turbo July 2022\n",
"Estimated input cost: $1.58\n",
"Estimated output cost: $2.37\n",
"Estimated total cost: $3.95\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Load LinFormer with max. token length"
],
"metadata": {
"id": "ARRqrF8XDG6_"
}
},
{
"cell_type": "code",
"source": [
"import torch\n",
"torch.cuda.empty_cache()"
],
"metadata": {
"id": "OIXpckDUPZ6C"
},
"execution_count": 16,
"outputs": []
},
{
"cell_type": "code",
"source": [
"# Linformer parameter 1:1 from the researchers\n",
"# then the input_size is determined based on the max no. of tokens\n",
"# the positional embedding flag is on by default\n",
"\n",
"from linformer_pytorch import LinformerLM\n",
"import torch\n",
"\n",
"linformer_model = LinformerLM(\n",
"    num_tokens=30000,  # Number of tokens in the LM\n",
"    input_size=700,  # Dimension 1 of the input\n",
"    channels=64,  # Dimension 2 of the input\n",
"    dim_d=None,  # Overwrites the inner dim of the attention heads. If None, sticks with the recommended channels // nhead, as in the \"Attention is all you need\" paper\n",
"    dim_k=128,  # The second dimension of the P_bar matrix from the paper\n",
"    dim_ff=128,  # Dimension in the feed forward network\n",
"    dropout_ff=0.15,  # Dropout for feed forward network\n",
"    nhead=4,  # Number of attention heads\n",
"    depth=2,  # How many times to run the model\n",
"    dropout=0.1,  # How much dropout to apply to P_bar after softmax\n",
"    activation=\"gelu\",  # What activation to use. Currently, only gelu and relu supported, and only on ff network.\n",
"    checkpoint_level=\"C0\",  # What checkpoint level to use. For more information, see below.\n",
"    parameter_sharing=\"layerwise\",  # What level of parameter sharing to use. For more information, see below.\n",
"    k_reduce_by_layer=0,  # Going down `depth`, how much to reduce `dim_k` by, for the `E` and `F` matrices. Will have a minimum value of 1.\n",
"    full_attention=False,  # Use full attention instead, for O(n^2) time and space complexity. Included here just for comparison\n",
"    include_ff=True,  # Whether or not to include the Feed Forward layer\n",
"    w_o_intermediate_dim=None,  # If not None, have 2 w_o matrices, such that instead of `dim*nhead,channels`, you have `dim*nhead,w_o_int`, and `w_o_int,channels`\n",
"    emb_dim=128,  # If you want the embedding dimension to be different than the channels for the Linformer\n",
"    causal=False,  # If you want this to be a causal Linformer, where the upper right of the P_bar matrix is masked out.\n",
"    method=\"learnable\",  # The method of how to perform the projection. Supported methods are 'convolution', 'learnable', and 'no_params'\n",
"    ff_intermediate=None,  # See the section below for more information\n",
"    # Same CUDA-or-CPU choice the vectorization cell makes for its input tensors;\n",
"    # an unconditional .cuda() raises on hosts without a GPU.\n",
"    ).to(torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"))"
],
"metadata": {
"id": "BeSmET5qvPCu"
},
"execution_count": 17,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# Vectorize the text column in that DataFrame"
],
"metadata": {
"id": "AzxrG_RmDLr1"
}
},
{
"cell_type": "code",
"source": [
"from tokenizers import Tokenizer\n",
"import torch\n",
"import numpy as np\n",
"import polars as pl\n",
"\n",
"# Load the custom tokenizer\n",
"tokenizer = Tokenizer.from_file(\"log_tokenizer.json\")\n",
"\n",
"# Use the GPU when one is available, otherwise fall back to the CPU\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Inference only: the model is configured with dropout, so leave training mode;\n",
"# otherwise the same message yields a different vector on every call.\n",
"linformer_model.eval()\n",
"\n",
"def vectorize_text(text):\n",
"    \"\"\"Encode one log message into a vector with the Linformer model.\n",
"\n",
"    The message is tokenized, truncated or padded to exactly MAX_LENGTH ids,\n",
"    passed through `linformer_model`, and the output is averaged over dim 1.\n",
"\n",
"    Args:\n",
"        text: The log message to encode.\n",
"\n",
"    Returns:\n",
"        A NumPy array holding the averaged model output (batch dimension of 1 kept).\n",
"    \"\"\"\n",
"    MAX_LENGTH = 700  # Must match the input_size the model was built with\n",
"\n",
"    # Tokenize using the custom tokenizer and take the token IDs\n",
"    input_ids = tokenizer.encode(text).ids\n",
"\n",
"    # Ensure the input_ids length is exactly MAX_LENGTH\n",
"    # NOTE(review): padding uses id 0, which presumably is the first special\n",
"    # token (\"=\") of the trained tokenizer, and the padded positions are part of\n",
"    # the mean below - confirm whether a dedicated padding token is wanted.\n",
"    input_ids = input_ids[:MAX_LENGTH] if len(input_ids) > MAX_LENGTH else input_ids + [0] * (MAX_LENGTH - len(input_ids))\n",
"\n",
"    # Batch of one sequence on the selected device\n",
"    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)\n",
"\n",
"    # No gradients are needed for encoding; skipping them avoids building an\n",
"    # autograd graph for every row.\n",
"    with torch.no_grad():\n",
"        outputs = linformer_model(input_ids)\n",
"\n",
"    # Average over dim 1, then move to the CPU and convert to NumPy\n",
"    vector = outputs.mean(dim=1)\n",
"    return vector.cpu().numpy()\n",
"\n",
"# Assuming `better_columns_df` is a Polars DataFrame with a column \"filtered_message\"\n",
"linformer_vector_df = better_columns_df.with_columns(\n",
"    pl.col(\"filtered_message\").map_elements(lambda x: vectorize_text(x).flatten(), return_dtype=pl.Object).alias(\"message_vector\")\n",
")\n",
"\n",
"print(linformer_vector_df)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "79IKUS-OXEiS",
"outputId": "cc314d37-884a-4852-c17e-ab7e65777ada"
},
"execution_count": 18,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"shape: (1_027, 9)\n",
"┌───────┬────────────┬────────────┬────────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ index ┆ log.level ┆ winlog.eve ┆ winlog.tas ┆ … ┆ image ┆ target_fi ┆ text ┆ message_v │\n",
"│ --- ┆ --- ┆ nt_id ┆ k ┆ ┆ --- ┆ lename ┆ --- ┆ ector │\n",
"│ i64 ┆ str ┆ --- ┆ --- ┆ ┆ str ┆ --- ┆ str ┆ --- │\n",
"│ ┆ ┆ i64 ┆ str ┆ ┆ ┆ str ┆ ┆ object │\n",
"╞═══════╪════════════╪════════════╪════════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
"│ 0 ┆ informatio ┆ 10 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [ 0.18137 │\n",
"│ ┆ n ┆ ┆ accessed ┆ ┆ s\\system3 ┆ ┆ accessed: ┆ 093 -0.25 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ 2\\svchost ┆ ┆ RuleName: ┆ 33698 │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ ┆ .exe ┆ ┆ - ┆ -0.0416… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ So… ┆ │\n",
"│ 1 ┆ informatio ┆ 10 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [ 0.18577 │\n",
"│ ┆ n ┆ ┆ accessed ┆ ┆ s\\system3 ┆ ┆ accessed: ┆ 585 -0.26 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ 2\\svchost ┆ ┆ RuleName: ┆ 15168 │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ ┆ .exe ┆ ┆ - ┆ -0.0566… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ So… ┆ │\n",
"│ 2 ┆ informatio ┆ 1 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [ 0.32916 │\n",
"│ ┆ n ┆ ┆ Create ┆ ┆ s\\servici ┆ ┆ Create: ┆ 135 -0.28 │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ ┆ ng\\Truste ┆ ┆ RuleName: ┆ 799376 │\n",
"│ ┆ ┆ ┆ cessCre… ┆ ┆ dInst… ┆ ┆ - ┆ -0.0459… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Proc… ┆ │\n",
"│ 3 ┆ informatio ┆ 13 ┆ Registry ┆ … ┆ C:\\Window ┆ ┆ Registry ┆ [ 0.33915 │\n",
"│ ┆ n ┆ ┆ value set ┆ ┆ s\\servici ┆ ┆ value ┆ 532 -0.29 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ ng\\Truste ┆ ┆ set: ┆ 842803 │\n",
"│ ┆ ┆ ┆ Regist… ┆ ┆ dInst… ┆ ┆ RuleName: ┆ -0.0153… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Ta… ┆ │\n",
"│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n",
"│ 1023 ┆ informatio ┆ 10 ┆ Process ┆ … ┆ C:\\Progra ┆ ┆ Process ┆ [ 0.07160 │\n",
"│ ┆ n ┆ ┆ accessed ┆ ┆ m Files ┆ ┆ accessed: ┆ 681 -0.26 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ (x86)\\Mic ┆ ┆ RuleName: ┆ 001436 │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ ┆ rosoft… ┆ ┆ - ┆ -0.0120… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ So… ┆ │\n",
"│ 1024 ┆ informatio ┆ 1 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [ │\n",
"│ ┆ n ┆ ┆ Create ┆ ┆ s\\System3 ┆ ┆ Create: ┆ 0.3071262 │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ ┆ 2\\taskhos ┆ ┆ RuleName: ┆ -0.313339 │\n",
"│ ┆ ┆ ┆ cessCre… ┆ ┆ tw.ex… ┆ ┆ - ┆ 6 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Proc… ┆ -0.0181… │\n",
"│ 1025 ┆ informatio ┆ 22 ┆ Dns query ┆ … ┆ ┆ ┆ Dns ┆ [ │\n",
"│ ┆ n ┆ ┆ (rule: ┆ ┆ ┆ ┆ query: ┆ 0.3430285 │\n",
"│ ┆ ┆ ┆ DnsQuery) ┆ ┆ ┆ ┆ RuleName: ┆ -0.288157 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ - ┆ 05 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ProcessId ┆ -0.0193… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … ┆ │\n",
"│ 1026 ┆ informatio ┆ 1 ┆ Process ┆ … ┆ C:\\Progra ┆ ┆ Process ┆ [ 0.29110 │\n",
"│ ┆ n ┆ ┆ Create ┆ ┆ m Files\\R ┆ ┆ Create: ┆ 384 -0.29 │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ ┆ UXIM\\PLUG ┆ ┆ RuleName: ┆ 604813 │\n",
"│ ┆ ┆ ┆ cessCre… ┆ ┆ Sched… ┆ ┆ - ┆ -0.0109… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Proc… ┆ │\n",
"└───────┴────────────┴────────────┴────────────┴───┴───────────┴───────────┴───────────┴───────────┘\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Ensure memory_profiler is installed\n",
"!pip install -q memory-profiler\n",
"\n",
"import timeit\n",
"from memory_profiler import memory_usage\n",
"import polars as pl\n",
"import numpy as np\n",
"import torch\n",
"\n",
"# Function to process the DataFrame and measure memory usage\n",
"def process_and_measure(df):\n",
"    \"\"\"Vectorize every row of ``df`` and report the peak GPU allocation.\n",
"\n",
"    Args:\n",
"        df: Polars DataFrame with a \"filtered_message\" column.\n",
"\n",
"    Returns:\n",
"        Peak GPU memory allocated during the pass, in MiB, or 0 when CUDA is\n",
"        not available. The vectorized frame itself is discarded.\n",
"    \"\"\"\n",
"    has_cuda = torch.cuda.is_available()\n",
"\n",
"    # Start from a clean peak counter so the reading covers this pass only.\n",
"    if has_cuda:\n",
"        torch.cuda.reset_peak_memory_stats()\n",
"\n",
"    message_vectors = pl.col(\"filtered_message\").map_elements(\n",
"        lambda x: vectorize_text(x).flatten(),\n",
"        return_dtype=pl.Object,\n",
"    )\n",
"    # Only the work matters here; the resulting frame is not kept.\n",
"    df.with_columns(message_vectors.alias(\"message_vector\"))\n",
"\n",
"    if not has_cuda:\n",
"        return 0\n",
"    # max_memory_allocated() is in bytes; convert to MiB.\n",
"    return torch.cuda.max_memory_allocated() / (1024 ** 2)\n",
"\n",
"# Number of benchmark repetitions\n",
"n_runs = 10\n",
"\n",
"# Per-run measurements\n",
"runtimes = []\n",
"cpu_memory_usages = []\n",
"gpu_memory_usages = []\n",
"\n",
"for _ in range(n_runs):\n",
"    # Timed pass. process_and_measure returns the peak GPU allocation of that\n",
"    # same pass, so it is captured here instead of re-running the whole workload\n",
"    # a third time just to read the GPU counter.\n",
"    runtime = timeit.timeit(\n",
"        lambda: gpu_memory_usages.append(process_and_measure(better_columns_df)),\n",
"        number=1,\n",
"    )\n",
"    runtimes.append(runtime)\n",
"\n",
"    # Separate pass under memory_profiler so its sampling does not distort the\n",
"    # timing above; record the spread between the highest and lowest sample.\n",
"    cpu_mem_usage = memory_usage((process_and_measure, (better_columns_df,)))\n",
"    cpu_memory_usages.append(max(cpu_mem_usage) - min(cpu_mem_usage))\n",
"\n",
"# Calculate average runtime, CPU memory, and GPU memory usage\n",
"average_runtime = np.mean(runtimes)\n",
"average_cpu_memory_usage = np.mean(cpu_memory_usages)\n",
"average_gpu_memory_usage = np.mean(gpu_memory_usages)\n",
"\n",
"print(f\"Average runtime over {n_runs} runs: {average_runtime:.6f} seconds\")\n",
"print(f\"Average CPU memory usage over {n_runs} runs: {average_cpu_memory_usage:.2f} MiB\")\n",
"print(f\"Average GPU memory usage over {n_runs} runs: {average_gpu_memory_usage:.2f} MiB\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "g8zGXPPsRAoh",
"outputId": "cdaf812f-935b-4dc5-db31-5aa2593ea24c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Average runtime over 10 runs: 7.183239 seconds\n",
"Average CPU memory usage over 10 runs: 116.41 MiB\n",
"Average GPU memory usage over 10 runs: 131.21 MiB\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Convert the vectors to a NumPy array for statistical analysis"
],
"metadata": {
"id": "3Gxep-miEAVm"
}
},
{
"cell_type": "code",
"source": [
"# linformer_vector_df holds one embedding per row in its 'message_vector' column\n",
"linformer_vectors = np.stack(linformer_vector_df['message_vector'].to_numpy()) # Stack the vectors into a 2D array"
],
"metadata": {
"id": "AInOx4hm7mR1"
},
"execution_count": 19,
"outputs": []
},
{
"cell_type": "markdown",
"source": [
"# UMAP dimensionality check - scattered or clustered?"
],
"metadata": {
"id": "zZVCSpvsEJIq"
}
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import umap.umap_ as umap  # the UMAP class is taken from the umap.umap_ submodule\n",
"\n",
"# Project the embeddings down to two dimensions with a fixed seed.\n",
"umap_reducer = umap.UMAP(n_components=2, random_state=42)\n",
"umap_results = umap_reducer.fit_transform(linformer_vectors)\n",
"\n",
"# Scatter the two UMAP components on an explicit figure/axes pair.\n",
"fig, ax = plt.subplots(figsize=(10, 8))\n",
"ax.scatter(umap_results[:, 0], umap_results[:, 1], s=5)\n",
"ax.set_title('UMAP Visualization of Vectors')\n",
"ax.set_xlabel('UMAP component 1')\n",
"ax.set_ylabel('UMAP component 2')\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 774
},
"id": "I55hh_VVm5Gs",
"outputId": "26ff87cc-301f-4993-a06e-ce05f1c49560"
},
"execution_count": 20,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/umap/umap_.py:1945: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.\n",
" warn(f\"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1kAAAK9CAYAAADWo6YTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACYJUlEQVR4nOzdeXiU5b3/8c8zGciiJGSiyJKELICgqOyLVCAQxVZrtbihqOBai4rYTXusW/uTtpZKxLjUBVqttscFbW2P1UAqRfZFWytKVkLYNDuSBWbm+f0RZpxnliwwJJnk/bquXId5trknxUM++d739zZM0zQFAAAAAAgLW2cPAAAAAAC6E0IWAAAAAIQRIQsAAAAAwoiQBQAAAABhRMgCAAAAgDAiZAEAAABAGBGyAAAAACCMCFkAAAAAEEaELAAAAAAII0IWAOC4GYahhx56qLOHETCOFStWyDAMlZaWdug4Out92+uxxx5TRkaGoqKiNGrUqM4eDgB0G4QsADhODz30kAzDUEVFRdDzI0eO1PTp072vS0tLZRiGDMPQL37xi6D3XHvttTIMQyeffHLI950wYYIMw9DTTz8d9LznB33PV0xMjIYNG6Y77rhDBw4cCPncN998U4Zh6Pnnnw95zfvvvy/DMPTEE0+EvKYnePTRR/XWW2919jCOyXvvvacf//jHmjJlipYvX65HH3004JojR47olFNO0Te+8Y2QzzFNUykpKRozZkxYx7d371499NBD+uijj8L6XADoCIQsAOgkMTExevXVVwOOHzp0SG+//bZiYmJC3ltQUKDNmzcrLS1Nf/zjH1t8n0ceeUQvvfSSnnzySZ177rl6+umnNXnyZNXX1we9/qKLLlJCQoJeeeWVkM985ZVXFBUVpauvvlqS1NDQoPvvv7/FcXSG6667Tg0NDRo8ePAJeX6okHWi3zccVq9eLZvNphdeeEHXX3+9vvWtbwVc06tXL11xxRVat26ddu3aFfQ5a9asUXl5uebOnRvW8e3du1cPP/wwIQtARCJkAUAn+da3vqVPP/1UH3/8seX422+/rcOHD+v8888Pee/LL7+sfv36acmSJVq3bl2L09K++c1vau7cubr55pu1YsUK3X333SopKdHbb78d9Pro6Ghdfvnl+uCDD7R3796A842NjVq5cqXOP/989evXT1JzYLTb7W341B0rKipKMTExMgyjR7xve3zxxReKjY1V7969W7zu2muvlWmaQX8hIDUHbpvN5g3cXd2hQ4c6ewgAegBCFgB0ksmTJys9PT2gYvTHP/5RF154oRwOR8h7X3nlFV1++eW6+OKLW606+ZsxY4YkqaSkJOQ1c+fOldvt1p/+9KeAc3/7299UW1ura6+91nvMfy3UwYMHdffddystLU3R0dHq16+fzj//fG3bts17TVpamubNmxfw/OnTp1umVx4+fFgPPPCAxo4dq4SEBJ100kk677zzlJ+f3+pn9V8b5ZnaGezLdyy/+c1vdO655yopKUmxsbEaO3asXn/9dcuzDcPQoUOH9Pvf/z7gGaHWZD311FM688wzFR0drYEDB2rBggWqqakJ+PwjR47Up59+qqysLMXFxWnQoEH69a9/3ernlSSn06mf//znyszMVHR0tNLS0vTTn/5UTU1NlrEvX75chw4d8o59xYoVQZ83ZcoUpaWlBf07duTIEb3++uvKysrSwIEDJUmfffaZLr/8cjkcDsXExGjcuHH6y1/+EnBvTU2NFi1a5P07kpycrOuvv14VFRX65z//qfHjx0uS5s+fH3SMr732msaOHavY2Fidcsopmjt3rvbs2WN5j3nz5unkk09WUVGRvvWtb6lPnz7ev7cFBQWaPXu2+vfvr5iYGCUnJ+vqq69WbW1tm77PANASQhYAdKI5c+boT3/6k0zTlCRVVFTovffe0zXXXBPyno0bN6qwsFBz5sxR79699d3vfrfVKYO+ioqKJElJSUkhr5k6da
qSk5OD/mD9yiuvKC4uTpdeemnI+7/3ve/p6aef1uzZs/XUU0/phz/8oWJjY7Vjx442j9Ojrq5Ozz//vKZPn65f/epXeuihh/Tll19q1qxZ7Z5K9t3vflcvvfSS5evuu++WJG9VTpJycnI0evRoPfLII3r00Udlt9t1xRVX6G9/+5v3mpdeeknR0dE677zzvM+67bbbQr73Qw89pAULFmjgwIFasmSJZs+erWeffVYXXHCBjhw5Yrm2urpaF154oc455xwtWbJEw4cP109+8hP93//9X6uf8eabb9YDDzygMWPG6PHHH9e0adO0ePFiS6XppZde0nnnnafo6Gjv2KdOnRr0eYZh6JprrtF//vMf/fe//7Wce/fdd1VVVeUNLv/97381adIk7dixQ/fee6+WLFmik046SZdeeqlWrlzpve+rr77Seeedp2XLlumCCy5QTk6Ovve97+mzzz5TeXm5RowYoUceeUSSdOuttwaMccWKFbryyisVFRWlxYsX65ZbbtGbb76pb3zjGwGh1el0atasWerXr59+85vfaPbs2Tp8+LBmzZqlDRs26M4771Rubq5uvfVWFRcXB9wPAMfEBAAclwcffNCUZH755ZdBz5955pnmtGnTvK9LSkpMSeZjjz1mfvLJJ6Yk81//+pdpmqaZm5trnnzyyeahQ4fMG264wTzppJMCnnfHHXeYKSkpptvtNk3TNN977z1Tkrl9+3bLdcuXLzclmXl5eeaXX35p7t692/zTn/5kJiUlmbGxsWZ5eXmLn+tHP/qRKcn8/PPPvcdqa2vNmJgYc86cOZZrJZkPPvig93VCQoK5YMGCFp8/ePBg84Ybbgg4Pm3aNMv3y+l0mk1NTZZrqqurzdNOO8288cYbWxyH53tQUlISdAxffvmlmZqaap511lnmV1995T1eX19vue7w4cPmyJEjzRkzZliOn3TSSUE/g//7fvHFF2bv3r3NCy64wHS5XN7rnnzySVOS+eKLL1o+vyTzD3/4g/dYU1OT2b9/f3P27NlBP4fHRx99ZEoyb775ZsvxH/7wh6Ykc/Xq1d5jof5+BfPf//7XlGTed999luNXX321GRMTY9bW1pqmaZozZ840zzrrLLOxsdF7jdvtNs8991xz6NCh3mMPPPCAKcl88803A97L8/d68+bNpiRz+fLllvOHDx82+/XrZ44cOdJsaGjwHn/nnXdMSeYDDzxg+YySzHvvvdfyjO3bt5uSzNdee61Nnx8A2otKFgB0ojPPPFNnn322d73LK6+8ou985zuKi4sLer3T6dSf//xnXXXVVd71PjNmzFC/fv1CVrOys7N16qmnKiUlRVdffbVOPvlkrVy5UoMGDWpxbJ5GBr7VrDfeeEONjY2WqYLB9O3bVxs3bgy6pqu9oqKivOuG3G63qqqq5HQ6NW7cOMv0w/ZyuVyaM2eODh48qJUrV+qkk07ynouNjfX+ubq6WrW1tTrvvPOO+f3y8vJ0+PBh3X333bLZvv6n95ZbblF8fLylQiZJJ598sqWRRO/evTVhwgQVFxe3+D5///vfJUn33HOP5fgPfvADSQp4n7Y644wzNHr0aMv00UOHDukvf/mLLr74YsXHx6uqqkqrV6/WlVdeqYMHD6qiokIVFRWqrKzUrFmzVFBQ4J3O98Ybb+icc87RZZddFvBera1j27Jli7744gt9//vftzSHueiiizR8+PCgn/H222+3vE5ISJAk/eMf/wjZAAYAjgchCwA6QEs/OF5zzTV67bXXVFhYqHXr1rU4VfC9997Tl19+qQkTJqiwsFCFhYUqKSlRVlaWXn31Vbnd7oB7cnNz9f777ys/P1+ffvqpiouLNWvWrFbHfPbZZ2vkyJGWhgevvPKKTjnllFbv//Wvf61PPvlEKSkpmjBhgh566KFWA0JLfv/73+vss89WTEyMkpKSdOqpp3rXhh2r+++/X6tXr9Yrr7yizMxMy7l33nlHkyZNUkxMjBwOh0499VQ9/f
TTx/x+ns58p59+uuV47969lZGREdC5Lzk5OeDvTGJioqqrq1t9H5vNpiFDhliO9+/fX3379g3ZIbAtrr32W
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Silhouette score via KMeans clustering"
],
"metadata": {
"id": "56wtMVhhEVzf"
}
},
{
"cell_type": "code",
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.metrics import silhouette_score\n",
"\n",
"# The number of clusters is an arbitrary choice; this run uses 2\n",
"kmeans = KMeans(n_clusters=2, random_state=42)\n",
"cluster_labels = kmeans.fit_predict(linformer_vectors)\n",
"\n",
"# Calculate the silhouette score\n",
"sil_score = silhouette_score(linformer_vectors, cluster_labels)\n",
"print(f\"Silhouette Score: {sil_score}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "xGm3Fm5bnFae",
"outputId": "85c1ccff-f959-468a-a1a4-277eba2e089f"
},
"execution_count": 21,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Silhouette Score: 0.7105976939201355\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Distribution of the Cosine similarity"
],
"metadata": {
"id": "8yhTWsSnEl30"
}
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"# Stack the per-row embeddings from the 'message_vector' column into one 2D array\n",
"vectors = np.stack(linformer_vector_df['message_vector'].to_numpy())\n",
"\n",
"# Pairwise cosine similarity, computed from the array built in this cell so the\n",
"# result does not depend on a variable left over from an earlier cell\n",
"cosine_sim_matrix = cosine_similarity(vectors)\n",
"\n",
"# Keep each unordered pair once: the upper triangle, excluding the diagonal\n",
"# (k=1), which holds each vector's similarity with itself\n",
"upper_triangle_indices = np.triu_indices_from(cosine_sim_matrix, k=1)\n",
"cosine_sim_values = cosine_sim_matrix[upper_triangle_indices]\n",
"\n",
"# Plotting the KDE of cosine similarities\n",
"plt.figure(figsize=(10, 6))\n",
"sns.kdeplot(cosine_sim_values, color='blue', fill=True)\n",
"plt.title('KDE Plot of Cosine Similarities')\n",
"plt.xlabel('Cosine Similarity')\n",
"plt.ylabel('Density')\n",
"plt.grid(True)\n",
"plt.show()\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"id": "NPbnuVKHqqtf",
"outputId": "60c00803-a7f0-4843-c348-d3fa57db9388"
},
"execution_count": 22,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAIjCAYAAADWYVDIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABsuklEQVR4nO3dd3hUZf7+8XsSUklBWkIJHaSDopQFkY7AIiBWLICy6oruT1lXZVUEFVF3dbEg6opgAd2vCtiBCFIsgFJE0UU6KL2lkwzJ+f3x7KSQQjKZmTMzeb+ua64hZ86c80nyMJl7nnIclmVZAgAAAABIkkLsLgAAAAAA/AkhCQAAAAAKISQBAAAAQCGEJAAAAAAohJAEAAAAAIUQkgAAAACgEEISAAAAABRCSAIAAACAQghJAAAAAFAIIQkAAlyTJk00btw4u8soYvv27Ro0aJDi4+PlcDi0ePFiu0sqpk+fPurTp4/dZZRp3LhxatKkiUePefb3vWfPHjkcDs2bN8+j55k6daocDke59p03b54cDof27Nnj0RoAwF2EJABVluuN2ffff19ke0pKirp27arIyEgtWbJEUsEbPtctOjpajRo10vDhwzV37lxlZ2cXO/64ceOKPKfwLTIy8pz1Fd4/JCRE9evX16BBg7Ry5UqPfP8HDhzQ1KlTtXnzZo8cr7CxY8fqxx9/1PTp0/XWW2/poosuKnP/1NRUTZs2TZ06dVJMTIyioqLUvn173X///Tpw4IDH67Pb0aNH9f/+3/9T69atFRUVpbp166pr1666//77lZ6ebnd5XvPEE0/4ZWAGgLNVs7sAAPAnqampGjRokLZs2aJFixbpsssuK/L47NmzFRMTo+zsbP3+++9aunSpbr75Zs2cOVOffPKJkpKSiuwfERGh1157rdh5QkNDy1XPwIEDddNNN8myLO3evVsvvfSS+vXrp08//VRDhgxx/xuVCUnTpk1TkyZN1Llz50odq7CsrCx9++23evDBB3XnnXeec/9du3ZpwIAB2rdvn6666irdeuutCg8P15YtWzRnzhwtWrRIv/76q8fqc1m2bJnHj1keJ06c0EUXXaTU1FTdfPPNat26tY4fP64tW7Zo9uzZ+vOf/6yYmBhJ0r///W/l5eV59Py++r4feughPfDAA0W2PfHEE7ryyis1cuTIIttvvPFGXXvttYqIiPBJbQBwLoQkAPiftLQ0DR48WJs3b9bChQtLDCFXXnmlateunf/1lClTNH/+fN1000266qqrtHbt2iL7V6tWTTfccIPbNbVq1arI80eNGqWOHTtq5syZlQ5J3nL06FFJUo0aNc6575kzZ3TFFVfo8OHDWrlypXr16lXk8enTp+upp57yRpkKDw/3ynHPZc6cOdq3b5++/vpr/eEPfyjyWGpqapG6wsLCPH5+b3/fGRkZql69uqpVq6Zq1cr3NiM0NLTcHxwAgC8w3A4AJKWnp+uyyy7Txo0b9cEHH2jYsGHlfu7111+vCRMmaN26dUpOTvZilVKHDh1Uu3Zt7d69u8z9du3apauuuko1a9ZUdHS0unfvrk8//TT/8ZUrV+riiy+WJI0fPz5/WN+55qVs2rRJQ4YMUVxcnGJiYtS/f/8iwXDq1Klq3LixJOlvf/ubHA5HmXNqPvjgA/3www968MEHiwUkSYqLi9P06dOLbHvvvffUpUsXRUVFqXbt2rrhhhv0+++/F9nn0KFDGj9+vBo2bKiIiAjVq1dPI0aMKDLn5ey5OStXrpTD4dD//d//afr06WrYsKEiIyPVv39/7dixo1ht69at02WXXab4+HhFR0fr0ksv1ddff13Wj0+StHPnToWGhqp79+4lfr+Fh2KePSfJNX/on//8p2bNmqVmzZopOjpagwYN0v79+2VZlh577DE1bNhQUVFRGjFihE6cOFHkHOWZi7VlyxaNGzdOzZo1U2RkpBITE3XzzTfr+PHjRfZzDUP9+eefNWbMGJ
133nn5v8ez5yQ5HA5lZGTojTfeyG9vrrl0pc1J+vzzz3XJJZeoevXqio2N1bBhw7R169Yi+5Tndw0AFUVPEoAqLyMjQ0OGDNF3332n999/X3/84x8rfIwbb7xRr776qpYtW6aBAwcWeezYsWPF9g8PD1dcXFyFz3Py5EmdPHlSLVq0KHWfw4cP6w9/+IMyMzP1l7/8RbVq1dIbb7yhyy+/XO+//75GjRqlNm3a6NFHH9WUKVN066236pJLLpGkYj0bhW3dulWXXHKJ4uLidN999yksLEyvvPKK+vTpo1WrVqlbt2664oorVKNGDd1zzz267rrrNHTo0PyhYyX56KOPJJmfX3nMmzdP48eP18UXX6wZM2bo8OHDeu655/T1119r06ZN+b1Xo0eP1tatW3XXXXepSZMmOnLkiJKTk7Vv375zLoTw5JNPKiQkRPfee69SUlL09NNP6/rrr9e6devy91mxYoWGDBmiLl266JFHHlFISIjmzp2rfv36ac2aNeratWupx2/cuLFyc3P11ltvaezYseX6vs82f/585eTk6K677tKJEyf09NNP6+qrr1a/fv20cuVK3X///dqxY4deeOEF3XvvvXr99dcrdPzk5GTt2rVL48ePV2JiorZu3apXX31VW7du1dq1a4styHDVVVepZcuWeuKJJ2RZVonHfOuttzRhwgR17dpVt956qySpefPmpdbg+vkMHjxYTz31lDIzMzV79mz16tVLmzZtyv89VuZ3DQClsgCgipo7d64lyWrcuLEVFhZmLV68uNR9H3nkEUuSdfTo0RIfP3nypCXJGjVqVP62sWPHWpJKvA0ePPic9UmybrnlFuvo0aPWkSNHrHXr1ln9+/e3JFnPPPNM/n6NGze2xo4dm//13XffbUmy1qxZk78tLS3Natq0qdWkSRMrNzfXsizL+u677yxJ1ty5c89Zi2VZ1siRI63w8HBr586d+dsOHDhgxcbGWr17987ftnv3bkuS9Y9//OOcx7zgggus+Pj4cp0/JyfHqlu3rtW+fXsrKysrf/snn3xiSbKmTJliWVbB7+Jc57/00kutSy+9NP/rL7/80pJktWnTxsrOzs7f/txzz1mSrB9//NGyLMvKy8uzWrZsaQ0ePNjKy8vL3y8zM9Nq2rSpNXDgwDLPe+jQIatOnTqWJKt169bW7bffbi1YsMA6depUsX3Hjh1rNW7cOP9r18+2Tp06RfafPHmyJcnq1KmT5XQ687dfd911Vnh4uHX69OlSv2/XMQu3g8zMzGK1vPPOO5Yka/Xq1fnbXP8vrrvuumL7ux4rrHr16kXaqovr/+Lu3bstyzLttUaNGtaf/vSnIvsdOnTIio+Pz99e3t81AFQUw+0AVHmHDx9WZGRksUUXKsLVW5KWllZke2RkpJKTk4vdnnzyyXIdd86cOapTp47q1q2rbt266euvv9akSZN09913l/qczz77TF27di0yfC0mJka33nqr9uzZo59//rnC319ubq6WLVumkSNHqlmzZvnb69WrpzFjxuirr75SampqhY+bmpqq2NjYcu37/fff68iRI7rjjjuKDEkbNmyYWrdunT+cMCoqSuHh4Vq5cqVOnjxZ4ZrGjx9fZN6Oq5dt165dkqTNmzdr+/btGjNmjI4fP65jx47p2LFjysjIUP/+/bV69eoyF1tISEjQDz/8oNtvv10nT57Uyy+/rDFjxqhu3bp67LHHSu2JKeyqq65SfHx8/tfdunWTJN1www1F5gF169ZNOTk5xYYjnktUVFT+v0+fPq1jx47lDw/cuHFjsf1vv/32Ch3/XJKTk3Xq1Cldd911+T/fY8eOKTQ0VN26ddOXX36ZX2dlftcAUBqG2wGo8l555RVNmjRJl112mdasWaPzzz+/wsdwLdt89hv+0NBQDRgwwO3aRowYoTvvvFMOh0OxsbFq166dqlevXuZz9u7dm/+mubA2bdrkP96+ffsK1XH06FFlZmaW+LNp06aN8vLytH
//frVr165Cx42Li8sPH+eyd+9eSSqxhtatW+urr76SZFYUfOqpp/TXv/5VCQkJ6t69u/74xz/qpptuUmJi4
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Comparison model: LongFormer"
],
"metadata": {
"id": "FmwKV-BK-rKX"
}
},
{
"cell_type": "code",
"source": [
"# starts with the same data\n",
"print(better_columns_df)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "l01czPQT-kSu",
"outputId": "be7724d9-5632-41e9-e2dd-6e06430c843e"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"shape: (1_027, 8)\n",
"┌───────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┬────────────┐\n",
"│ index ┆ log.level ┆ winlog.eve ┆ winlog.tas ┆ filtered_m ┆ image ┆ target_fil ┆ text │\n",
"│ --- ┆ --- ┆ nt_id ┆ k ┆ essage ┆ --- ┆ ename ┆ --- │\n",
"│ i64 ┆ str ┆ --- ┆ --- ┆ --- ┆ str ┆ --- ┆ str │\n",
"│ ┆ ┆ i64 ┆ str ┆ str ┆ ┆ str ┆ │\n",
"╞═══════╪════════════╪════════════╪════════════╪════════════╪════════════╪════════════╪════════════╡\n",
"│ 0 ┆ informatio ┆ 10 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ accessed ┆ accessed: ┆ \\system32\\ ┆ ┆ accessed: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ svchost.ex ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ - ┆ e ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ So… ┆ ┆ ┆ So… │\n",
"│ 1 ┆ informatio ┆ 10 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ accessed ┆ accessed: ┆ \\system32\\ ┆ ┆ accessed: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ svchost.ex ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ - ┆ e ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ So… ┆ ┆ ┆ So… │\n",
"│ 2 ┆ informatio ┆ 1 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ Create ┆ Create: ┆ \\servicing ┆ ┆ Create: │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ RuleName: ┆ \\TrustedIn ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ cessCre… ┆ - ┆ st… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ Proc… ┆ ┆ ┆ Proc… │\n",
"│ 3 ┆ informatio ┆ 13 ┆ Registry ┆ Registry ┆ C:\\Windows ┆ ┆ Registry │\n",
"│ ┆ n ┆ ┆ value set ┆ value set: ┆ \\servicing ┆ ┆ value set: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ \\TrustedIn ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ Regist… ┆ Ta… ┆ st… ┆ ┆ Ta… │\n",
"│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n",
"│ 1023 ┆ informatio ┆ 10 ┆ Process ┆ Process ┆ C:\\Program ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ accessed ┆ accessed: ┆ Files (x86 ┆ ┆ accessed: │\n",
"│ ┆ ┆ ┆ (rule: ┆ RuleName: ┆ )\\Microsof ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ - ┆ t… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ So… ┆ ┆ ┆ So… │\n",
"│ 1024 ┆ informatio ┆ 1 ┆ Process ┆ Process ┆ C:\\Windows ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ Create ┆ Create: ┆ \\System32\\ ┆ ┆ Create: │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ RuleName: ┆ taskhostw. ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ cessCre… ┆ - ┆ ex… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ Proc… ┆ ┆ ┆ Proc… │\n",
"│ 1025 ┆ informatio ┆ 22 ┆ Dns query ┆ Dns query: ┆ ┆ ┆ Dns query: │\n",
"│ ┆ n ┆ ┆ (rule: ┆ RuleName: ┆ ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ DnsQuery) ┆ - ┆ ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ ProcessId… ┆ ┆ ┆ ProcessId… │\n",
"│ 1026 ┆ informatio ┆ 1 ┆ Process ┆ Process ┆ C:\\Program ┆ ┆ Process │\n",
"│ ┆ n ┆ ┆ Create ┆ Create: ┆ Files\\RUXI ┆ ┆ Create: │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ RuleName: ┆ M\\PLUGSche ┆ ┆ RuleName: │\n",
"│ ┆ ┆ ┆ cessCre… ┆ - ┆ d… ┆ ┆ - │\n",
"│ ┆ ┆ ┆ ┆ Proc… ┆ ┆ ┆ Proc… │\n",
"└───────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┴────────────┘\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
2024-07-29 12:01:01 +00:00
"# Initialize and use LongFormer with the same pre-trained Tokenizer"
],
"metadata": {
"id": "J6REcr7hE1xz"
}
},
{
"cell_type": "code",
"source": [
"from tokenizers import Tokenizer\n",
"import torch\n",
"import numpy as np\n",
"import polars as pl\n",
"from transformers import LongformerModel\n",
"\n",
"# Load the custom tokenizer\n",
"tokenizer = Tokenizer.from_file(\"log_tokenizer.json\")\n",
"\n",
"# Load LongFormer model\n",
"model_name = 'allenai/longformer-base-4096'\n",
"longformer_model = LongformerModel.from_pretrained(model_name)\n",
"\n",
"# Define the device (GPU if available, otherwise CPU)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"longformer_model.to(device)\n",
"\n",
"def vectorize_text(text):\n",
"    \"\"\"Embed one log message as a mean-pooled Longformer hidden-state vector.\n",
"\n",
"    The message is tokenized with the notebook-level ``tokenizer``, cut or\n",
"    right-padded with id 0 to exactly MAX_LENGTH ids, and run through\n",
"    ``longformer_model`` on ``device`` as a batch of one.\n",
"\n",
"    Args:\n",
"        text: Message string to embed.\n",
"\n",
"    Returns:\n",
"        1-D NumPy array: ``last_hidden_state`` averaged over dim 1, flattened.\n",
"    \"\"\"\n",
"    MAX_LENGTH = 700  # fixed number of token ids fed to the model\n",
"\n",
"    token_ids = tokenizer.encode(text).ids\n",
"\n",
"    # Truncate long inputs and right-pad short ones with id 0. A negative repeat\n",
"    # count yields an empty list, so over-long inputs receive no padding.\n",
"    padded_ids = token_ids[:MAX_LENGTH] + [0] * (MAX_LENGTH - len(token_ids))\n",
"\n",
"    # Wrap in an outer list so the tensor is a batch holding this one sequence.\n",
"    input_ids = torch.tensor([padded_ids], dtype=torch.long).to(device)\n",
"\n",
"    # 1 wherever the id is non-zero, 0 elsewhere.\n",
"    # NOTE(review): a real token with id 0 would be masked as well -- confirm\n",
"    # that id 0 is only ever padding in log_tokenizer.json.\n",
"    attention_mask = (input_ids != 0).long()\n",
"\n",
"    # Inference only, so autograd is switched off for the forward pass.\n",
"    with torch.no_grad():\n",
"        hidden_states = longformer_model(input_ids, attention_mask=attention_mask).last_hidden_state\n",
"        # The mean runs over every position of dim 1, padded ones included.\n",
"        pooled = hidden_states.mean(dim=1).cpu().numpy()\n",
"    return pooled.flatten()\n",
"\n",
"# Assuming `better_columns_df` is your Polars DataFrame with a column \"filtered_message\"\n",
"longformer_vector_df = better_columns_df.with_columns(\n",
" pl.col(\"filtered_message\").map_elements(vectorize_text, return_dtype=pl.Object).alias(\"message_vector\")\n",
")\n",
"\n",
"print(longformer_vector_df)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000,
"referenced_widgets": [
2024-07-29 12:01:01 +00:00
"cf8262210b7c42bc8db80e1a102acf19",
"509c0ac5a7624e2bb8618b978c8a80dc",
"56c17b10c28c4fa9a6db05aa60a71f5e",
"4b85c54d069e4a8d87520dbdad4c2c89",
"6537feb3ffd24d2ea90b27a7a0a1aeb0",
"4e2585503cca48f8a0428cc62d02e5f1",
"b6ea52fd1054458e9458865f8f40ff84",
"d70cc590fb28457b9ca7f74d8aa425f1",
"a9116211ec3d407d9278f744d41507e6",
"d706ec4068874b6eb52918afad80982e",
"9c7b2917eb984de88aa6dbdec1bfa9cb",
"c69abb9efb6b411fa4ee31b0c1910bae",
"1170d27bc06e48bc985700e1703887b3",
"6264397e6794478fa54e161f4434cc8b",
"e5acc8060d7c4d25b6c13bb486eb5c39",
"6ff11b9092db4dab98a68505955a612c",
"b9f3902a7a844da0ab7d9bf5c2869db7",
"86fa6d3c75e94c109ca0d44ddc7cb6e4",
"4d7f1a3f6c044343a096951ed2a58603",
"ef324ed1c14d4e9aac3312ed637f5789",
"d19fd3e721b84b37b2f4da2bf03008de",
"5d3d80342c864f419e64a323051e3509"
]
},
"id": "tpREQMax-yy3",
2024-07-29 12:01:01 +00:00
"outputId": "fa3b4724-87ea-497c-fa6f-60f753b9088e"
},
2024-07-29 12:01:01 +00:00
"execution_count": 29,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
" warnings.warn(\n",
"/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
"The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
"To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
"You will be able to reuse this secret in all of your notebooks.\n",
"Please note that authentication is recommended but still optional to access public models or datasets.\n",
" warnings.warn(\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"config.json: 0%| | 0.00/694 [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
2024-07-29 12:01:01 +00:00
"model_id": "cf8262210b7c42bc8db80e1a102acf19"
}
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"pytorch_model.bin: 0%| | 0.00/597M [00:00<?, ?B/s]"
],
"application/vnd.jupyter.widget-view+json": {
"version_major": 2,
"version_minor": 0,
2024-07-29 12:01:01 +00:00
"model_id": "c69abb9efb6b411fa4ee31b0c1910bae"
}
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stderr",
"text": [
"Input ids are automatically padded from 700 to 1024 to be a multiple of `config.attention_window`: 512\n"
]
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"shape: (1_027, 9)\n",
"┌───────┬────────────┬────────────┬────────────┬───┬───────────┬───────────┬───────────┬───────────┐\n",
"│ index ┆ log.level ┆ winlog.eve ┆ winlog.tas ┆ … ┆ image ┆ target_fi ┆ text ┆ message_v │\n",
"│ --- ┆ --- ┆ nt_id ┆ k ┆ ┆ --- ┆ lename ┆ --- ┆ ector │\n",
"│ i64 ┆ str ┆ --- ┆ --- ┆ ┆ str ┆ --- ┆ str ┆ --- │\n",
"│ ┆ ┆ i64 ┆ str ┆ ┆ ┆ str ┆ ┆ object │\n",
"╞═══════╪════════════╪════════════╪════════════╪═══╪═══════════╪═══════════╪═══════════╪═══════════╡\n",
2024-07-29 12:01:01 +00:00
"│ 0 ┆ informatio ┆ 10 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [ 1.66481 │\n",
"│ ┆ n ┆ ┆ accessed ┆ ┆ s\\system3 ┆ ┆ accessed: ┆ 603e-02 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ 2\\svchost ┆ ┆ RuleName: ┆ 5.9788990 │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ ┆ .exe ┆ ┆ - ┆ 8e-03… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ So… ┆ │\n",
2024-07-29 12:01:01 +00:00
"│ 1 ┆ informatio ┆ 10 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [ 2.54064 │\n",
"│ ┆ n ┆ ┆ accessed ┆ ┆ s\\system3 ┆ ┆ accessed: ┆ 687e-02 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ 2\\svchost ┆ ┆ RuleName: ┆ 9.9657382 │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ ┆ .exe ┆ ┆ - ┆ 8e-03… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ So… ┆ │\n",
2024-07-29 12:01:01 +00:00
"│ 2 ┆ informatio ┆ 1 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [-2.47876 │\n",
"│ ┆ n ┆ ┆ Create ┆ ┆ s\\servici ┆ ┆ Create: ┆ 402e-02 │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ ┆ ng\\Truste ┆ ┆ RuleName: ┆ 7.5210638 │\n",
"│ ┆ ┆ ┆ cessCre… ┆ ┆ dInst… ┆ ┆ - ┆ 3e-02… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Proc… ┆ │\n",
2024-07-29 12:01:01 +00:00
"│ 3 ┆ informatio ┆ 13 ┆ Registry ┆ … ┆ C:\\Window ┆ ┆ Registry ┆ [-2.11402 │\n",
"│ ┆ n ┆ ┆ value set ┆ ┆ s\\servici ┆ ┆ value ┆ 588e-02 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ ng\\Truste ┆ ┆ set: ┆ 7.1680687 │\n",
"│ ┆ ┆ ┆ Regist… ┆ ┆ dInst… ┆ ┆ RuleName: ┆ 4e-02… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Ta… ┆ │\n",
"│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │\n",
2024-07-29 12:01:01 +00:00
"│ 1023 ┆ informatio ┆ 10 ┆ Process ┆ … ┆ C:\\Progra ┆ ┆ Process ┆ [ 4.76494 │\n",
"│ ┆ n ┆ ┆ accessed ┆ ┆ m Files ┆ ┆ accessed: ┆ 990e-03 │\n",
"│ ┆ ┆ ┆ (rule: ┆ ┆ (x86)\\Mic ┆ ┆ RuleName: ┆ -3.913062 │\n",
"│ ┆ ┆ ┆ ProcessA… ┆ ┆ rosoft… ┆ ┆ - ┆ 44e-02… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ So… ┆ │\n",
2024-07-29 12:01:01 +00:00
"│ 1024 ┆ informatio ┆ 1 ┆ Process ┆ … ┆ C:\\Window ┆ ┆ Process ┆ [-2.70790 │\n",
"│ ┆ n ┆ ┆ Create ┆ ┆ s\\System3 ┆ ┆ Create: ┆ 458e-02 │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ ┆ 2\\taskhos ┆ ┆ RuleName: ┆ 7.4740640 │\n",
"│ ┆ ┆ ┆ cessCre… ┆ ┆ tw.ex… ┆ ┆ - ┆ 8e-02… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Proc… ┆ │\n",
2024-07-29 12:01:01 +00:00
"│ 1025 ┆ informatio ┆ 22 ┆ Dns query ┆ … ┆ ┆ ┆ Dns ┆ [-2.38244 │\n",
"│ ┆ n ┆ ┆ (rule: ┆ ┆ ┆ ┆ query: ┆ 906e-02 │\n",
"│ ┆ ┆ ┆ DnsQuery) ┆ ┆ ┆ ┆ RuleName: ┆ 7.2345621 │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ - ┆ 9e-02… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ ProcessId ┆ │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ … ┆ │\n",
2024-07-29 12:01:01 +00:00
"│ 1026 ┆ informatio ┆ 1 ┆ Process ┆ … ┆ C:\\Progra ┆ ┆ Process ┆ [-1.94575 │\n",
"│ ┆ n ┆ ┆ Create ┆ ┆ m Files\\R ┆ ┆ Create: ┆ 228e-02 │\n",
"│ ┆ ┆ ┆ (rule: Pro ┆ ┆ UXIM\\PLUG ┆ ┆ RuleName: ┆ 8.0403946 │\n",
"│ ┆ ┆ ┆ cessCre… ┆ ┆ Sched… ┆ ┆ - ┆ 3e-02… │\n",
"│ ┆ ┆ ┆ ┆ ┆ ┆ ┆ Proc… ┆ │\n",
"└───────┴────────────┴────────────┴────────────┴───┴───────────┴───────────┴───────────┴───────────┘\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"!pip install -q psutil\n",
"\n",
"import timeit\n",
"import psutil\n",
"import polars as pl\n",
"import torch\n",
"\n",
"# Function to process the DataFrame and measure memory usage\n",
"def process_and_measure(df):\n",
"    \"\"\"Run the vectorizer over ``df`` and return the peak GPU allocation in MiB.\n",
"\n",
"    Args:\n",
"        df: Polars DataFrame with a \"filtered_message\" column.\n",
"\n",
"    Returns:\n",
"        Peak GPU memory allocated during the pass, in MiB, or 0 when CUDA is\n",
"        not available. The vectorized frame itself is discarded.\n",
"    \"\"\"\n",
"    cuda_ready = torch.cuda.is_available()\n",
"    if cuda_ready:\n",
"        # Clear the peak counter so the reading reflects this pass alone.\n",
"        torch.cuda.reset_peak_memory_stats()\n",
"\n",
"    vector_expr = pl.col(\"filtered_message\").map_elements(\n",
"        lambda x: vectorize_text(x).flatten(),\n",
"        return_dtype=pl.Object,\n",
"    )\n",
"    # The frame is built only for its cost; nothing keeps a reference to it.\n",
"    df.with_columns(vector_expr.alias(\"message_vector\"))\n",
"\n",
"    # max_memory_allocated() is in bytes; divide by 1024**2 for MiB.\n",
"    return torch.cuda.max_memory_allocated() / (1024 ** 2) if cuda_ready else 0\n",
"\n",
"# Imported here so this cell does not rely on names left over from earlier cells.\n",
"# memory_profiler is pip-installed by the Linformer benchmark cell above.\n",
"import numpy as np\n",
"from memory_profiler import memory_usage\n",
"\n",
"# Number of benchmark repetitions\n",
"n_runs = 10\n",
"\n",
"# Per-run measurements\n",
"runtimes = []\n",
"cpu_memory_usages = []\n",
"gpu_memory_usages = []\n",
"\n",
"for _ in range(n_runs):\n",
"    # Timed pass. process_and_measure returns the peak GPU allocation of that\n",
"    # same pass, so it is captured here instead of re-running the workload just\n",
"    # to read the GPU counter.\n",
"    runtime = timeit.timeit(\n",
"        lambda: gpu_memory_usages.append(process_and_measure(better_columns_df)),\n",
"        number=1,\n",
"    )\n",
"    runtimes.append(runtime)\n",
"\n",
"    # Sample process memory while the workload runs and record the spread between\n",
"    # the highest and lowest sample, exactly as the Linformer benchmark does.\n",
"    # Comparing RSS before and after the run misses memory that is allocated and\n",
"    # released in between, which is why that approach reported 0.00 MiB.\n",
"    cpu_mem_usage = memory_usage((process_and_measure, (better_columns_df,)))\n",
"    cpu_memory_usages.append(max(cpu_mem_usage) - min(cpu_mem_usage))\n",
"\n",
"# Calculate average runtime, CPU memory, and GPU memory usage\n",
"average_runtime = np.mean(runtimes)\n",
"average_cpu_memory_usage = np.mean(cpu_memory_usages)\n",
"average_gpu_memory_usage = np.mean(gpu_memory_usages)\n",
"\n",
"print(f\"Average runtime over {n_runs} runs: {average_runtime:.6f} seconds\")\n",
"print(f\"Average CPU memory usage over {n_runs} runs: {average_cpu_memory_usage:.2f} MiB\")\n",
"print(f\"Average GPU memory usage over {n_runs} runs: {average_gpu_memory_usage:.2f} MiB\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "0IWHrarbUvJN",
"outputId": "baf1b651-cd96-494a-aa01-f2c3f6e677c4"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Average runtime over 10 runs: 32.509051 seconds\n",
"Average CPU memory usage over 10 runs: 0.00 MiB\n",
"Average GPU memory usage over 10 runs: 739.78 MiB\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# longformer_vector_df holds one embedding per row in its 'message_vector' column\n",
"longformer_vectors = np.stack(longformer_vector_df['message_vector'].to_numpy()) # Stack the vectors into a 2D array"
],
"metadata": {
"id": "LCnkFlrTB2qr"
},
2024-07-29 12:01:01 +00:00
"execution_count": 31,
"outputs": []
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import umap.umap_ as umap  # the UMAP class is taken from the umap.umap_ submodule\n",
"\n",
"# Project the embeddings down to two dimensions with a fixed seed.\n",
"umap_reducer = umap.UMAP(n_components=2, random_state=42)\n",
"umap_results = umap_reducer.fit_transform(longformer_vectors)\n",
"\n",
"# Scatter the two UMAP components on an explicit figure/axes pair.\n",
"fig, ax = plt.subplots(figsize=(10, 8))\n",
"ax.scatter(umap_results[:, 0], umap_results[:, 1], s=5)\n",
"ax.set_title('UMAP Visualization of Vectors')\n",
"ax.set_xlabel('UMAP component 1')\n",
"ax.set_ylabel('UMAP component 2')\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 754
},
"id": "I_O0HEDwAC9T",
"outputId": "8e1647bb-bf3a-4831-acaa-d4031a8b14b1"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stderr",
"text": [
"/usr/local/lib/python3.10/dist-packages/umap/umap_.py:1945: UserWarning: n_jobs value 1 overridden to 1 by setting random_state. Use no seed for parallelism.\n",
" warn(f\"n_jobs value {self.n_jobs} overridden to 1 by setting random_state. Use no seed for parallelism.\")\n"
]
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x800 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA1UAAAK9CAYAAADMn0adAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABr80lEQVR4nO3deVxU9f7H8fewg8ggboiiIOZuWq64U6SW1a20RTO1XSMrbfVWav0qTVs0M1s07FZm18q6t7rlRpqJ5lpZrijuaIaAKIvDnN8fXuY6Mih4gBng9Xw85vHwfM93znxmDiBvvt/zPRbDMAwBAAAAAC6Kl7sLAAAAAIDKjFAFAAAAACYQqgAAAADABEIVAAAAAJhAqAIAAAAAEwhVAAAAAGACoQoAAAAATCBUAQAAAIAJhCoAAAAAMIFQBQAoNYvFokmTJrm7jCJ1zJs3TxaLRampqRVah7tet7SmTZumpk2bytvbWx06dHB3OQBQZRCqAKCUJk2aJIvFomPHjrnc37ZtW/Xt29exnZqaKovFIovFohdeeMHlc26//XZZLBYFBwcX+7pdunSRxWLR7NmzXe4v/MW+8BEQEKDmzZvrwQcf1JEjR4o97hdffCGLxaI5c+YU22fJkiWyWCx64403iu1THbz00kv68ssv3V3GRVm8eLGeeOIJ9ejRQ4mJiXrppZeK9Dl9+rTq1Kmjnj17FnscwzAUGRmpyy+/vEzrO3TokCZNmqTNmzeX6XEBoCIQqgCgggQEBOiTTz4p0n7y5El99dVXCggIKPa5O3fu1Lp16xQVFaWPP/74vK/z/PPP68MPP9Sbb76p7t27a/bs2YqNjdWpU6dc9h84cKCsVqvmz59f7DHnz58vb29v3XbbbZKknJwcPfPMM+etwx3uuOMO5eTkqEmTJuVy/OJCVXm/bllYvny5vLy8NHfuXA0fPlzXXHNNkT6+vr66+eabtXr1au3du9flcVauXKkDBw5o2LBhZVrfoUOH9NxzzxGqAFRKhCoAqCDXXHON/vjjD/3yyy9O7V999ZXy8/N11VVXFfvcjz76SPXq1dOrr76q1atXn3ea2dVXX61hw4bpnnvu0bx58/TII49oz549+uqrr1z29/f31+DBg7VixQodOnSoyP7c3FwtWrRIV111lerVqyfpTED08fEpwbuuWN7e3goICJDFYqkWr1saR48eVWBgoPz8/M7b7/bbb5dhGC7/ACCdCdheXl6OgO3pTp486e4SAFQDhCoAqCCxsbGKjo4uMiL08ccfa8CAAQoLCyv2ufPnz9fgwYN17bXXXnBU6VxXXHGFJGnPnj3F9hk2bJjsdrsWLFhQZN8333yjzMxM3X777Y62c69lOnHihB555BFFRUXJ399f9erV01VXXaWNGzc6+kRFRWnkyJFFjt+3b1+n6ZL5+fmaMGGCOnbsKKvVqho1aqhXr15KSkq64Hs999qmwqmarh5n1/LKK6+oe/fuql27tgIDA9WxY0d99tlnTse2WCw6efKkPvjggyLHKO6aqrfeektt2rSRv7+/IiIilJCQoIyMjCLvv23btvrjjz8UFxenoKAgNWzYUFOnTr3g+5Ukm82m//u//1NMTIz8/f0VFRWlv//978rLy3OqPTExUSdPnnTUPm/ePJfH69Gjh6Kiolx+jZ0+fVqfffaZ4uLiFBERIUnatm2bBg8erLCwMAUEBKhTp07617/+VeS5GRkZGjt2rONrpFGjRho+fLiOHTumH374QZ07d5Yk3XnnnS5rXLhwoTp27KjAwEDVqVNHw4YN08GDB51eY+TIkQoODlZKSoquueYa1axZ0/F1u3PnTg0aNEjh4eEKCAhQo0aNdNtttykzM7NEnzMAnA+hCgAq0JAhQ7RgwQIZhiFJOnbsmBYvXqyhQ4cW+5y1a9dq165dGjJkiPz8/HTTTTddcArg2VJSUiRJtWvXLrZP79691ahRI5e/SM+fP19BQUG64YYbin
3+qFGjNHv2bA0aNEhvvfWWHnvsMQUGBmrr1q0lrrNQVlaW5syZo759++rll1/WpEmT9Oeff6p///6lnhp200036cMPP3R6PPLII5LkGHWTpBkzZuiyyy7T888/r5deekk+Pj66+eab9c033zj6fPjhh/L391evXr0cx7r//vuLfe1JkyYpISFBERERevXVVzVo0CC988476tevn06fPu3U9/jx4xowYIDat2+vV199VS1bttSTTz6p//znPxd8j/fcc48mTJigyy+/XK+//rr69OmjyZMnO40kffjhh+rVq5f8/f0dtffu3dvl8SwWi4YOHarffvtNv//+u9O+7777Tunp6Y6g8vvvv6tbt27aunWrnnrqKb366quqUaOGbrjhBi1atMjxvOzsbPXq1UszZ85Uv379NGPGDI0aNUrbtm3TgQMH1KpVKz3//POSpPvuu69IjfPmzdMtt9wib29vTZ48Wffee6+++OIL9ezZs0hItdls6t+/v+rVq6dXXnlFgwYNUn5+vvr37681a9ZozJgxmjVrlu677z7t3r27yPMB4KIYAIBSmThxoiHJ+PPPP13ub9OmjdGnTx/H9p49ewxJxrRp04wtW7YYkowff/zRMAzDmDVrlhEcHGycPHnSGDFihFGjRo0ix3vwwQeNyMhIw263G4ZhGIsXLzYkGZs2bXLql5iYaEgyli5davz555/G/v37jQULFhi1a9c2AgMDjQMHDpz3fT3++OOGJGP79u2OtszMTCMgIMAYMmSIU19JxsSJEx3bVqvVSEhIOO/xmzRpYowYMaJIe58+fZw+L5vNZuTl5Tn1OX78uFG/fn3jrrvuOm8dhZ/Bnj17XNbw559/Go0bNzbatWtnZGdnO9pPnTrl1C8/P99o27atccUVVzi116hRw+V7OPd1jx49avj5+Rn9+vUzCgoKHP3efPNNQ5Lx/vvvO71/ScY//vEPR1teXp4RHh5uDBo0yOX7KLR582ZDknHPPfc4tT/22GOGJGP58uWOtuK+vlz5/fffDUnG+PHjndpvu+02IyAgwMjMzDQMwzCuvPJKo127dkZubq6jj91uN7p3725ccskljrYJEyYYkowvvviiyGsVfl2vW7fOkGQkJiY67c/Pzzfq1atntG3b1sjJyXG0f/3114YkY8KECU7vUZLx1FNPOR1j06ZNhiRj4cKFJXr/AFBajFQBQAVq06aNLr30Usf1KvPnz9ff/vY3BQUFuexvs9n06aef6tZbb3Vcr3PFFVeoXr16xY5WxcfHq27duoqMjNRtt92m4OBgLVq0SA0bNjxvbYULD5w9WvX5558rNzfXaeqfK6GhoVq7dq3La7JKy9vb23Hdj91uV3p6umw2mzp16uQ0nbC0CgoKNGTIEJ04cUKLFi1SjRo1HPsCAwMd/z5+/LgyMzPVq1evi369pUuXKj8/X4888oi8vP73X+29996rkJAQpxEwSQoODnZa+MHPz09dunTR7t27z/s63377rSRp3LhxTu2PPvqoJBV5nZJq3bq1LrvsMqfpoCdPntS//vUvXXvttQoJCVF6erqWL1+uW265RSdOnNCxY8d07Ngx/fXXX+rfv7927tzpmJ73+eefq3379rrxxhuLvNaFrkNbv369jh49qgceeMBpMZeBAweqZcuWLt/j6NGjnbatVqsk6fvvvy92wRYAMINQBQDl4Hy/KA4dOlQLFy7Url27tHr16vNO/Vu8eLH+/PNPdenSRbt27dKuXbu0Z88excXF6ZNPPpHdbi/ynFmzZmnJkiVKSkrSH3/8od27d6t///4XrPnSSy9V27ZtnRYomD9/vurUqXPB50+dOlVbtmxRZGSkunTpokmTJl0wEJzPBx98oEsvvVQBAQGqXbu26tat67i262I988wzWr58uebPn6+YmBinfV9//bW6deumgIAAhYWFqW7dupo9e/ZFv17hynktWrRwavfz81PTpk2LrKzXqFGjIl8ztWrV0vHjxy/4Ol
5eXmrWrJlTe3h4uEJDQ4tdwa8kbr/9du3Zs0erV6+WJH355Zc6deqUI2Dv2rVLhmHo2WefVd26dZ0eEydOl
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.cluster import KMeans\n",
"from sklearn.metrics import silhouette_score\n",
"\n",
"# Assuming we choose an arbitrary number of clusters, e.g., 5\n",
"kmeans = KMeans(n_clusters=2, random_state=42)\n",
"cluster_labels = kmeans.fit_predict(longformer_vectors)\n",
"\n",
"# Calculate the silhouette score\n",
"sil_score = silhouette_score(longformer_vectors, cluster_labels)\n",
"print(f\"Silhouette Score: {sil_score}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "w6hYJUe2APTt",
"outputId": "0c923d0a-8b74-4e88-deec-f3c643039b86"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Silhouette Score: 0.7029939889907837\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"# Assuming vector_df has a column 'message_vector' with the vectors\n",
"vectors = np.stack(longformer_vector_df['message_vector'].to_numpy()) # Stack the vectors into a 2D array\n",
"\n",
"# Calculate the cosine similarity matrix\n",
"cosine_sim_matrix = cosine_similarity(longformer_vectors)\n",
"\n",
"# Extract the upper triangle values from the cosine similarity matrix, excluding the diagonal\n",
"upper_triangle_indices = np.triu_indices_from(cosine_sim_matrix, k=1)\n",
"cosine_sim_values = cosine_sim_matrix[upper_triangle_indices]\n",
"\n",
"# Plotting the KDE of cosine similarities\n",
"plt.figure(figsize=(10, 6))\n",
"sns.kdeplot(cosine_sim_values, color='blue', fill=True)\n",
"plt.title('KDE Plot of Cosine Similarities')\n",
"plt.xlabel('Cosine Similarity')\n",
"plt.ylabel('Density')\n",
"plt.grid(True)\n",
"plt.show()\n",
"\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"id": "wz-3_DNDBe-L",
"outputId": "486b54f0-101a-499d-d180-add867c012ae"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA0kAAAIjCAYAAADWYVDIAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABrtklEQVR4nO3dd3hUZd7G8XsS0kmhJ0DoIggIilJEQaWLiooVRMDKihUrryuCioi7slgQdBUQF9RVAVcRJNJsSBMsiEiHFUJNSCWE5Lx/PDuThCSQMpMzJ/l+rmuuITNnzvllHiYz9zzluCzLsgQAAAAAkCQF2F0AAAAAAPgTQhIAAAAA5ENIAgAAAIB8CEkAAAAAkA8hCQAAAADyISQBAAAAQD6EJAAAAADIh5AEAAAAAPkQkgAAAAAgH0ISADhckyZNNHz4cLvLKGDr1q3q06ePoqOj5XK5tGDBArtLKuTSSy/VpZdeancZpzV8+HA1adLEq/s89ffetWuXXC6XZs2a5dXjjBs3Ti6Xq0Tbzpo1Sy6XS7t27fJqDQBQVoQkAFWW+4PZunXrCtx+7NgxderUSaGhoVq8eLGkvA987kt4eLgaNWqkq666SjNnzlRWVlah/Q8fPrzAY/JfQkNDz1hf/u0DAgJUv3599enTRytWrPDK779v3z6NGzdOGzdu9Mr+8hs2bJh++eUXTZgwQe+9954uuOCC026fkpKi8ePHq3379qpevbrCwsLUtm1bPfHEE9q3b5/X67PboUOH9OCDD6pVq1YKCwtT3bp11alTJz3xxBNKS0uzuzyfeeGFF/wyMAPAqarZXQAA+JOUlBT16dNHP//8s+bPn69+/foVuH/atGmqXr26srKy9Oeff+rLL7/U7bffrilTpujzzz9XfHx8ge1DQkL09ttvFzpOYGBgierp3bu3brvtNlmWpZ07d+qNN97Q5ZdfroULF6p///5l/0VlQtL48ePVpEkTdejQoVz7yi8zM1OrVq3SU089pfvuu++M2+/YsUO9evXSnj17dMMNN+juu+9WcHCwfv75Z73zzjuaP3++/vjjD6/V57ZkyRKv77Mkjh49qgsuuEApKSm6/fbb1apVKx05ckQ///yzpk2bpr/85S+qXr26JOmf//yncnNzvXr8ivq9//rXv+rJJ58scNsLL7yg66+/Xtdcc02B24cOHaqbb75ZISEhFVIbAJwJIQkA/ic1NVV9+/bVxo0bNW/evCJDyPXXX6/atWt7fh47dqzmzJmj2267TTfccIN++OGHAttXq1ZNt956a5lratmyZYHHX3vttTr33HM1ZcqUcockXzl06JAkKSYm5ozbnjx5Utddd50OHDigFStW6OKLLy5w/4QJEzRp0iRflKng4GCf7PdM3nnnHe3Zs0ffffedLrroogL3paSkFKgrKCjI68f39e+dnp6uiIgIVatWTdWqlexjRmBgYIm/OACAisBwOwCQlJaWpn79+unHH3/UJ598ogEDBpT4sUOGDNGdd96p1atXKyEhwYdVSu3atVPt2rW1c+fO0263Y8cO3XDDDapZs6bCw8PVpUsXLVy40HP/ihUrdOGFF0qSRowY4RnWd6Z5KRs2bFD//v0VFRWl6tWrq2fPngWC4bhx49S4cWNJ0mOPPSaXy3XaOTWffPKJfvrpJz311FOFApIkRUVFacKECQVu++ijj9SxY0eFhYWpdu3auvXWW/Xnn38W2CYxMVEjRoxQw4YNFRISori4OA0cOLDAnJdT5+asWLFCLpdL//73vzVhwgQ1bNhQoaGh6tmzp7Zt21aottWrV6tfv36Kjo5WeHi4evTooe++++50T58kafv27QoMDFSXLl2K/H3zD8U8dU6Se/7Q3//+d02dOlXNmjVTeHi4+vTpo71798qyLD333HNq2LChwsLCNHDgQB09erTAMUoyF+vnn3/W8OHD1axZM4WGhio2Nla33367jhw5UmA79zDU3377TYMHD1
aNGjU87XjqnCSXy6X09HS9++67nv9v7rl0xc1JWrRokS655BJFREQoMjJSAwYM0KZNmwpsU5K2BoDSoicJQJWXnp6u/v37a+3atfr444915ZVXlnofQ4cO1VtvvaUlS5aod+/eBe47fPhwoe2Dg4MVFRVV6uMkJSUpKSlJLVq0KHabAwcO6KKLLlJGRoYeeOAB1apVS++++66uvvpqffzxx7r22mvVunVrPfvssxo7dqzuvvtuXXLJJZJUqGcjv02bNumSSy5RVFSUHn/8cQUFBenNN9/UpZdeqpUrV6pz58667rrrFBMTo4cffli33HKLrrjiCs/QsaL85z//kWSev5KYNWuWRowYoQsvvFATJ07UgQMH9Morr+i7777Thg0bPL1XgwYN0qZNm3T//ferSZMmOnjwoBISErRnz54zLoTw4osvKiAgQI8++qiOHTuml156SUOGDNHq1as92yxbtkz9+/dXx44d9cwzzyggIEAzZ87U5Zdfrm+++UadOnUqdv+NGzdWTk6O3nvvPQ0bNqxEv/ep5syZoxMnTuj+++/X0aNH9dJLL+nGG2/U5ZdfrhUrVuiJJ57Qtm3b9Nprr+nRRx/VjBkzSrX/hIQE7dixQyNGjFBsbKw2bdqkt956S5s2bdIPP/xQaEGGG264QWeddZZeeOEFWZZV5D7fe+893XnnnerUqZPuvvtuSVLz5s2LrcH9/PTt21eTJk1SRkaGpk2bposvvlgbNmzwtGN52hoAimUBQBU1c+ZMS5LVuHFjKygoyFqwYEGx2z7zzDOWJOvQoUNF3p+UlGRJsq699lrPbcOGDbMkFXnp27fvGeuTZN1xxx3WoUOHrIMHD1qrV6+2evbsaUmyXn75Zc92jRs3toYNG+b5+aGHHrIkWd98843nttTUVKtp06ZWkyZNrJycHMuyLGvt2rWWJGvmzJlnrMWyLOuaa66xgoODre3bt3tu27dvnxUZGWl1797dc9vOnTstSdbf/va3M+7zvPPOs6Kjo0t0/BMnTlh169a12rZta2VmZnpu//zzzy1J1tixYy3LymuLMx2/R48eVo8ePTw/L1++3JJktW7d2srKyvLc/sorr1iSrF9++cWyLMvKzc21zjrrLKtv375Wbm6uZ7uMjAyradOmVu/evU973MTERKtOnTqWJKtVq1bWyJEjrblz51rJycmFth02bJjVuHFjz8/u57ZOnToFth8zZowlyWrfvr2VnZ3tuf2WW26xgoODrePHjxf7e7v3mf//QUZGRqFa3n//fUuS9fXXX3tuc78ubrnllkLbu+/LLyIiosD/VTf3a3Hnzp2WZZn/rzExMdZdd91VYLvExEQrOjrac3tJ2xoASovhdgCqvAMHDig0NLTQogul4e4tSU1NLXB7aGioEhISCl1efPHFEu33nXfeUZ06dVS3bl117txZ3333nUaPHq2HHnqo2Md88cUX6tSpU4Hha9WrV9fdd9+tXbt26bfffiv175eTk6MlS5bommuuUbNmzTy3x8XFafDgwfr222+VkpJS6v2mpKQoMjKyRNuuW7dOBw8e1L333ltgSNqAAQPUqlUrz3DCsLAwBQcHa8WKFUpKSip1TSNGjCgwb8fdy7Zjxw5J0saNG7V161YNHjxYR44c0eHDh3X48GGlp6erZ8+e+vrrr0+72EK9evX0008/aeTIkUpKStL06dM1ePBg1a1bV88991yxPTH53XDDDYqOjvb83LlzZ0nSrbfeWmAeUOfOnXXixIlCwxHPJCwszPPv48eP6/Dhw57hgT/++GOh7UeOHFmq/Z9JQkKCkpOTdcstt3ie38OHDyswMFCdO3fW8uXLPXWWp60BoDgMtwNQ5b355psaPXq0+vXrp2+++UZnn312qffhXrb51A/8gYGB6tWrV5lrGzhwoO677z65XC5FRkaqTZs2ioiIOO1jdu/e7fnQnF/r1q0997dt27ZUdRw6dEgZGRlFPjetW7dWbm6u9u7dqzZt2p
Rqv1FRUZ7wcSa7d++WpCJraNWqlb799ltJZkXBSZMm6ZFHHlG9evXUpUsXXXnllbrtttsUGxt7xuM0atSow
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Assuming `X` is the matrix of extracted features from your deep learning model\n",
"# Replace this with your actual feature matrix\n",
"X = linformer_vectors\n",
"\n",
"# Scale the features before applying PCA\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"# Apply PCA to retain 95% of the variance\n",
"pca = PCA(n_components=0.95)\n",
"X_pca = pca.fit_transform(X_scaled)\n",
"\n",
"# Output the results\n",
"num_features_before = X.shape[1]\n",
"num_features_after = X_pca.shape[1]\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"total_explained_variance = np.sum(explained_variance_ratio)\n",
"\n",
2024-07-29 12:01:01 +00:00
"print(\"For Standard Scaler\")\n",
"print(f\"Number of features before PCA: {num_features_before}\")\n",
"print(f\"Number of features after PCA: {num_features_after}\")\n",
"print(f\"Explained variance ratio of each component: {explained_variance_ratio}\")\n",
"print(f\"Total explained variance: {total_explained_variance}\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "pf6Ff46rJenw",
2024-07-29 12:01:01 +00:00
"outputId": "652570a0-2de1-40c2-aaf5-a106d87ff2c2"
},
2024-07-29 12:01:01 +00:00
"execution_count": 23,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of features before PCA: 30000\n",
"Number of features after PCA: 16\n",
"Explained variance ratio of each component: [0.7322417 0.10258386 0.01919063 0.01368955 0.01286633 0.01047001\n",
" 0.00869492 0.00842645 0.00745646 0.00658019 0.00584541 0.00514998\n",
" 0.00491944 0.00477372 0.00384961 0.00355314]\n",
"Total explained variance: 0.9502913951873779\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import RobustScaler\n",
"\n",
"# Assuming `X` is the matrix of extracted features from your deep learning model\n",
"# Replace this with your actual feature matrix\n",
"X = linformer_vectors\n",
"\n",
"# Scale the features before applying PCA\n",
"scaler = RobustScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"# Apply PCA to retain 95% of the variance\n",
"pca = PCA(n_components=0.95)\n",
"X_pca = pca.fit_transform(X_scaled)\n",
"\n",
"# Output the results\n",
"num_features_before = X.shape[1]\n",
"num_features_after = X_pca.shape[1]\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"total_explained_variance = np.sum(explained_variance_ratio)\n",
"\n",
"print(\"For Robust Scaler:\")\n",
"print(f\"Number of features before PCA: {num_features_before}\")\n",
"print(f\"Number of features after PCA: {num_features_after}\")\n",
"print(f\"Explained variance ratio of each component: {explained_variance_ratio}\")\n",
"print(f\"Total explained variance: {total_explained_variance}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "eK8wgP0IRPnr",
"outputId": "8758e5dd-9616-4faf-9fd5-9e87c9625784"
},
"execution_count": 24,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"For Robust Scaler:\n",
"Number of features before PCA: 30000\n",
"Number of features after PCA: 12\n",
"Explained variance ratio of each component: [0.780261 0.10317536 0.0133233 0.00962587 0.00900306 0.00735111\n",
" 0.00611507 0.00592202 0.00531727 0.00460212 0.00413254 0.00365149]\n",
"Total explained variance: 0.952480137348175\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"from sklearn.preprocessing import QuantileTransformer\n",
"from sklearn.decomposition import PCA\n",
"import numpy as np\n",
"\n",
"# Assuming X is your feature matrix\n",
"X = linformer_vectors\n",
"\n",
"# Initialize and fit QuantileTransformer\n",
"qt = QuantileTransformer(n_quantiles=1000, output_distribution='normal')\n",
"X_scaled = qt.fit_transform(X)\n",
"\n",
"# Apply PCA\n",
"pca = PCA(n_components=0.95)\n",
"X_pca = pca.fit_transform(X_scaled)\n",
"\n",
"# Output results\n",
"num_features_before = X.shape[1]\n",
"num_features_after = X_pca.shape[1]\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"total_explained_variance = np.sum(explained_variance_ratio)\n",
"\n",
"print(\"For QuantileTransformer:\")\n",
"print(f\"Number of features before PCA: {num_features_before}\")\n",
"print(f\"Number of features after PCA: {num_features_after}\")\n",
"print(f\"Explained variance ratio of each component: {explained_variance_ratio}\")\n",
"print(f\"Total explained variance: {total_explained_variance}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "GNpC0dqQR0c1",
"outputId": "ecfdbd8a-f243-4d20-b52f-f71e2d88f372"
},
"execution_count": 25,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
2024-07-29 12:01:01 +00:00
"For QuantileTransformer:\n",
"Number of features before PCA: 30000\n",
2024-07-29 12:01:01 +00:00
"Number of features after PCA: 44\n",
"Explained variance ratio of each component: [0.5520516 0.06681303 0.03466679 0.02603397 0.02438432 0.01991471\n",
" 0.01807965 0.01676909 0.01452703 0.01407685 0.01204457 0.01079712\n",
" 0.01011274 0.00953582 0.00818911 0.00781149 0.00750947 0.00684127\n",
" 0.00666759 0.00635126 0.0060087 0.00576267 0.00530005 0.00508402\n",
" 0.0044879 0.00436854 0.00426688 0.0040915 0.00363791 0.00351764\n",
" 0.00315828 0.00294013 0.00270898 0.00254636 0.0024569 0.00226332\n",
" 0.0022466 0.00212357 0.00194851 0.00179851 0.0016702 0.00165001\n",
" 0.00159225 0.00152319]\n",
"Total explained variance: 0.9503300786018372\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# Assuming `X` is the matrix of extracted features from your deep learning model\n",
"# Replace this with your actual feature matrix\n",
"X = longformer_vectors\n",
"\n",
"# Scale the features before applying PCA\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"# Apply PCA to retain 95% of the variance\n",
"pca = PCA(n_components=0.95)\n",
"X_pca = pca.fit_transform(X_scaled)\n",
"\n",
"# Output the results\n",
"num_features_before = X.shape[1]\n",
"num_features_after = X_pca.shape[1]\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"total_explained_variance = np.sum(explained_variance_ratio)\n",
"\n",
"print(f\"Number of features before PCA: {num_features_before}\")\n",
"print(f\"Number of features after PCA: {num_features_after}\")\n",
"print(f\"Explained variance ratio of each component: {explained_variance_ratio}\")\n",
"print(f\"Total explained variance: {total_explained_variance}\")\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TeaiZpy_H67h",
"outputId": "91fb9f4c-6c74-4511-9f22-9ce8192bf948"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of features before PCA: 768\n",
"Number of features after PCA: 12\n",
"Explained variance ratio of each component: [0.61036086 0.20625743 0.0349943 0.02877296 0.01811446 0.01343036\n",
" 0.01034578 0.00865575 0.00768904 0.0061328 0.00457506 0.00375366]\n",
"Total explained variance: 0.9530825018882751\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# X is the features matrix\n",
"x = linformer_vectors\n",
"\n",
"\n",
"# Standardize the features\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"# Apply PCA\n",
"pca = PCA(n_components=0.95)\n",
"X_pca = pca.fit_transform(X_scaled)\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"cumulative_variance = np.cumsum(explained_variance_ratio)\n",
"\n",
"# Plotting the Scree plot\n",
"plt.figure(figsize=(10, 4))\n",
"plt.bar(range(1, len(explained_variance_ratio) + 1), explained_variance_ratio, alpha=0.6, color='g', label='Individual explained variance')\n",
"plt.ylabel('Explained variance ratio')\n",
"plt.xlabel('Principal components')\n",
"plt.title('Scree Plot')\n",
"plt.legend(loc='best')\n",
"plt.tight_layout()\n",
"\n",
"# Plotting the Cumulative Variance plot\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(range(1, len(cumulative_variance) + 1), cumulative_variance, marker='o', linestyle='-', color='b', label='Cumulative explained variance')\n",
"plt.xlabel('Number of components')\n",
"plt.ylabel('Cumulative explained variance')\n",
"plt.title('Cumulative Variance Plot')\n",
"plt.axhline(y=0.95, color='r', linestyle='--', label='95% Explained Variance')\n",
"plt.legend(loc='best')\n",
"plt.tight_layout()\n",
"\n",
"# Displaying the plots\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 797
},
"id": "qOa9OTV6DpOK",
"outputId": "cfede8d1-e334-426f-c517-186035938323"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOc0lEQVR4nO3de3zP9f//8ft7Y5udMTY025xizjaEj0N9phURHRwaY5V85NhSLDnHVA5L5BhT8aE+IaVIc6iYQ5gIM8dJ5pBYm2zs/f794ef97d029mZvb5vb9XJ5XS7ez9fz9Xo9Xi/vDvf36/l6vgwmk8kkAAAAAABQ6BzsXQAAAAAAAMUVoRsAAAAAABshdAMAAAAAYCOEbgAAAAAAbITQDQAAAACAjRC6AQAAAACwEUI3AAAAAAA2QugGAAAAAMBGCN0AAAAAANgIoRsAANw18fHxMhgMOn78uL1LAQDgriB0AwBgJ3v37tUzzzyjgIAAubi4qFKlSmrbtq3ef/99e5dmtTFjxshgMJgXV1dXBQcH680331R6enqhHGPJkiWKi4srlH0BAHC3lLB3AQAA3I+2bNmihx9+WJUrV1afPn3k5+enkydPauvWrXrvvfc0cOBAe5d4W2bNmiV3d3dlZGTo22+/1YQJE7R+/Xpt3rxZBoPhjva9ZMkS7du3T0OGDCmcYgEAuAsI3QAA2MGECRPk5eWlHTt2yNvb22Ld2bNn73j/JpNJV65cUalSpe54X9Z45pln5OPjI0n6z3/+o6efflrLly/X1q1b1axZs7taCwAA9wKGlwMAYAdHjhxR7dq1cwVuSSpfvnyutk8++URNmjSRq6urSpcurVatWunbb781rw8MDNQTTzyhtWvXKjQ0VKVKldKcOXMkSRcvXtSQIUPk7+8vZ2dnVatWTW+//baMRqPFMYxGo+Li4lS7dm25uLjI19dXffv21R9//HHb5/nII49Iko4dO3bTfh988IFq164tZ2dnVaxYUf3799fFixfN69u0aaPVq1frxIkT5iHsgYGBt10XAAB3C3e6AQCwg4CAACUmJmrfvn2qU6fOTfuOHTtWY8aMUfPmzTVu3Dg5OTlp27ZtWr9+vR599FFzv+TkZHXv3l19+/ZVnz599OCDD+ry5ctq3bq1Tp06pb59+6py5crasmWLYmJidPr0aYtnpPv27av4+HhFRUVp0KBBOnbsmGbMmKHdu3dr8+bNKlmypNXneeTIEUlS2bJl8+0zZswYjR07VmFhYerXr5+Sk5M1a9Ys7dixw3zcESNG6NKlS/r11181bdo0SZK7u7vV9QAAcNeZAADAXfftt9+aHB0dTY6OjqZmzZqZXn/9ddPatWtN2dnZFv1SUlJMDg4Ops6dO5tycnIs1hmNRvOfAwICTJJMa9assegzfvx4k5ubm+nQoUMW7cOHDzc5OjqaUlNTTSaTyfTDDz+YJJkWL15s0W/NmjV5tv/T6NGjTZJMycnJpnPnzpmOHTtmmjNnjsnZ2dnk6+tryszMNJlMJtPChQtNkkzHjh0zmUwm09mzZ01OTk6mRx991OL8ZsyYYZJkWrBggbmtffv2poCAgJvWAQDAvYbh5QAA2EHbtm2VmJiojh07as+ePXrnnXcUHh6uSpUqadWqVeZ+K1eulNFo1KhRo+TgYPmf7X9OTBYUFKTw8HCLts8++0wtW7ZU6dKldf78efMSFhamnJwcff/99+Z+Xl5eatu2rUW/kJAQubu7a8OGDQU6rwcffFDlypVTUFCQ+vbtq2rVqmn16tVydXXNs/93332n7OxsDRkyxOL8+vTpI09PT61evbpAxwUA4F7F8HIAAOykcePGWr58ubKzs7Vnzx6tWLFC06ZN0zPPPKOkpCQFBwfryJEjcnBwUHBw8C33FxQUlKstJSVFP//8s8qVK5fnNjcmbUtJSdGlS5fyfJ787/1u5fPPP5enp6dKliypBx54QFWrVr1p/xMnTki6Ht
b/zsnJSVWqVDGvBwCgqCJ0AwBgZ05OTmrcuLEaN26sGjVqKCoqSp999plGjx5t1X7ymqncaDSqbdu2ev311/PcpkaNGuZ+5cuX1+LFi/Psl19o/6dWrVqZZy8HAACEbgAA7imhoaGSpNOnT0uSqlatKqPRqP3796tBgwZW769q1arKyMhQWFjYLft99913atGixV19zVhAQICk65PAValSxdyenZ2tY8eOWdR9p+/5BgDAHnimGwAAO9iwYYNMJlOu9q+//lrS/w237tSpkxwcHDRu3Lhcr/jKa/t/6tKlixITE7V27dpc6y5evKhr166Z++Xk5Gj8+PG5+l27ds3i9V2FKSwsTE5OTpo+fbrF+Xz44Ye6dOmS2rdvb25zc3PTpUuXbFIHAAC2wp1uAADsYODAgbp8+bI6d+6smjVrKjs7W1u2bNGyZcsUGBioqKgoSVK1atU0YsQIjR8/Xi1bttRTTz0lZ2dn7dixQxUrVlRsbOxNj/Paa69p1apVeuKJJ9S7d2+FhIQoMzNTe/fu1f/+9z8dP35cPj4+at26tfr27avY2FglJSXp0UcfVcmSJZWSkqLPPvtM7733np555plCvw7lypVTTEyMxo4dq8cee0wdO3ZUcnKyPvjgAzVu3Fg9evQw9w0JCdGyZcsUHR2txo0by93dXR06dCj0mgAAKEyEbgAA7GDy5Mn67LPP9PXXX2vu3LnKzs5W5cqV9fLLL+vNN9+Ut7e3ue+4ceMUFBSk999/XyNGjJCrq6vq1aunnj173vI4rq6u2rRpkyZOnKjPPvtMH330kTw9PVWjRg2NHTtWXl5e5r6zZ89WSEiI5syZozfeeEMlSpRQYGCgevTooRYtWtjiMki6/p7ucuXKacaMGXrllVdUpkwZvfTSS5o4caLFu8FffvllJSUlaeHChZo2bZoCAgII3QCAe57BVJCxaQAAAAAAwGo80w0AAAAAgI0QugEAAAAAsBFCNwAAAAAANkLoBgAAAADARgjdAAAAAADYCKEbAAAAAAAbue/e0200GvXbb7/Jw8NDBoPB3uUAAAAAAIogk8mkP//8UxUrVpSDQ/73s++70P3bb7/J39/f3mUAAAAAAIqBkydP6oEHHsh3/X0Xuj08PCRdvzCenp52rgYAAAAAUBSlp6fL39/fnDHzc9+F7htDyj09PQndAAAAAIA7cqvHlplIDQAAAAAAGyF0AwAAAABgI4RuAAAAAABs5L57phsAAAC4E0ajUdnZ2fYuA4CNlSxZUo6Ojne8H0I3AAAAUEDZ2dk6duyYjEajvUsBcBd4e3vLz8/vlpOl3QyhGwAAACgAk8mk06dPy9HRUf7+/nJw4ElNoLgymUy6fPmyzp49K0mqUKHCbe+L0A0AAAAUwLVr13T58mVVrFhRrq6u9i4HgI2VKlVKknT27FmVL1/+toea8/McAAAAUAA5OTmSJCcnJztXAuBuufED29WrV297H4RuAAAAwAp38mwngKKlMP55J3QDAAAAAGAjhG4AAAAAt2QwGLRy5UpJ0vHjx2UwGJSUlHRb2+fldvZZEIGBgYqLiyvUfVpr48aNMhgMunjxYoG3adOmjYYMGWKzmm7o3bu3OnXqZPPj5OVunaO9MZEaAAAAcAf6ftn3rh5vToc5VvXv3bu3Ll68eNPAay1/f3+dPn1aPj4+Bd7m9OnTKl26dKHVUNwtX75cJUuWtHcZNnU/nKN0D4TumTNn6t1331VaWprq16+v999/X02aNMm3/8WLFzVixAgtX75cFy5cUEBAgOLi4tSuXbu7WPXdcbf/BW4P1v5HAwAAAPbn6OgoPz8/q7axtv/9rkyZMvYuwWays7Pl5ORUrM/x7+w6vHzZsmWKjo7W6NGjtWvXLtWvX1/h4eHmd6H9U3Z2ttq2bavjx4/rf//7n5KTkzVv3jxVqlTpLlcOAAAAFE1t2rTRoEGD9Prrr6tMmTLy8/PTmDFjLPqkpKSoVatWcnFxUXBwsNatW2
ex/u9DwY1Gox544AHNmjXLos/u3bvl4OCgEydOSMo9vHz79u1q2LChXFxcFBoaqt27d1tsHx8fL29vb4u2l
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACK9ElEQVR4nOzdd1hTZxsG8DsE2UNFQEAExb03BWeVilp3VVTqoI62aq2CVm3F2bqqFle1WpXWutCits5axL0X6qeiIogiiAsQZJmc749TApGACRDDuH/XlYvkPe95z5MQI0/eJREEQQARERERERERFTk9XQdAREREREREVFox6SYiIiIiIiLSEibdRERERERERFrCpJuIiIiIiIhIS5h0ExEREREREWkJk24iIiIiIiIiLWHSTURERERERKQlTLqJiIiIiIiItIRJNxEREREREZGWMOkmIiIqYsOHD4ezs3ORthkYGAiJRIKoqKgibbc4OXr0KCQSCY4eParrUN6LsvZ8iYjKKibdRERULEVERODzzz9H9erVYWRkBAsLC7Ru3RrLli1DamqqrsPTmnnz5mH37t26DgMA0LNnT5iYmODVq1d51vH29oaBgQGeP3/+HiMr3rK+IMm6GRkZoVatWhg3bhyePHlSJNfYv38/Zs2aVSRtERGRdjHpJiKiYmffvn1o2LAhgoKC0KNHD6xYsQLz589H1apVMXnyZHz99de6DlFr8kq6hwwZgtTUVDg5Ob23WLy9vZGamopdu3apPP769Wvs2bMHXbp0gZWVVaGv165dO6SmpqJdu3aFbqs4mDNnDjZt2oSVK1fC3d0dq1evhpubG16/fl3otvfv34/Zs2cXQZRERKRt+roOgIiIKKfIyEgMHDgQTk5OOHLkCOzs7BTHxo4di3v37mHfvn06jFA3pFIppFLpe71mz549YW5uji1btmDo0KG5ju/ZswcpKSnw9vYu1HXS0tJgYGAAPT09GBkZFaqt4qRr165o0aIFAGDkyJGwsrLC0qVLsWfPHgwaNEjH0RER0fvCnm4iIipWFi1ahOTkZKxfv14p4c5So0YNRU93VFQUJBIJAgMDc9WTSCRKw29nzZoFiUSCO3fu4NNPP4WlpSWsra3h7+8PQRDw8OFD9OrVCxYWFqhcuTKWLFmi1F5ec6rVnZe7ePFiuLu7w8rKCsbGxmjevDl27tyZK+aUlBT89ttviqHJw4cPV3n97t27o3r16iqv5ebmpkj2svzxxx9o3rw5jI2NUbFiRQwcOBAPHz7MN2ZjY2P07dsXISEhiI+Pz3V8y5YtMDc3R8+ePfHixQtMmjQJDRs2hJmZGSwsLNC1a1eEhYUpnZP1em3btg3Tp0+Hg4MDTExMkJSUpPK1PHHiBPr374+qVavC0NAQjo6OmDhxYq4pBsOHD4eZmRliYmLQu3dvmJmZwdraGpMmTYJMJlOqK5fLsWzZMjRs2BBGRkawtrZGly5dcPHixUK/Zvnp2LEjAPGLpfzs2LFDcd1KlSrh008/RUxMjNJzXbVqFQAoDWMnIqLiiUk3EREVK3///TeqV68Od3d3rbTv5eUFuVyOBQsWwNXVFd9//z0CAgLw0UcfwcHBAQsXLkSNGjUwadIkHD9+vMiuu2zZMjRt2hRz5szBvHnzoK+vj/79+yv12m/atAmGhoZo27YtNm3ahE2bNuHzzz/P83lERkbiwoULSuUPHjzA2bNnMXDgQEXZDz/8gKFDh6JmzZpYunQpJkyYgJCQELRr1w4JCQn5xu3t7Y03b94gKChIqfzFixc4dOgQ+vTpA2NjY9y/fx+7d+9G9+7dsXTpUkyePBnXr19H+/bt8fjx41ztzp07F/v27cOkSZMwb948GBgYqLz+jh078Pr1a3z55ZdYsWIFPD09sWLFCpU97zKZDJ6enrCyssLixYvRvn17LFmyBGvXrlWqN2LECEyYMAGOjo5YuHAhpk6dCiMjI5w9e7ZIXrO8RE
REAEC+Q/EDAwMxYMAASKVSzJ8/H6NGjUJwcDDatGmjuO7nn3+Ojz76CAAU75NNmzYVKCYiInoPBCIiomIiMTFRACD06tVLrfqRkZECAGHjxo25jgEQZs6cqXg8c+ZMAYAwevRoRdmbN2+EKlWqCBKJRFiwYIGi/OXLl4KxsbEwbNgwRdnGjRsFAEJkZKTSdUJDQwUAQmhoqKJs2LBhgpOTk1K9169fKz3OyMgQGjRoIHTs2FGp3NTUVOm6eV0/MTFRMDQ0FPz8/JTqLVq0SJBIJMKDBw8EQRCEqKgoQSqVCj/88INSvevXrwv6+vq5yt/25s0bwc7OTnBzc1MqX7NmjQBAOHTokCAIgpCWlibIZDKlOpGRkYKhoaEwZ84cRVnW61W9evVcr4mq1/LtOoIgCPPnz1d6joIgvuYAlK4lCILQtGlToXnz5orHR44cEQAI48ePz9WuXC4XBKHwr1nW7+rff/8Vnj59Kjx8+FDYtm2bYGVlJRgbGwuPHj1S+XwzMjIEGxsboUGDBkJqaqqivb179woAhBkzZijKxo4dK/DPOCKikoE93UREVGwkJSUBAMzNzbV2jZEjRyruS6VStGjRAoIgYMSIEYry8uXLo3bt2rh//36RXdfY2Fhx/+XLl0hMTETbtm1x+fLlArWXNXw7KCgIgiAoyrdv344PPvgAVatWBQAEBwdDLpdjwIABePbsmeJWuXJl1KxZE6GhofleRyqVYuDAgThz5ozS0PotW7bA1tYWnTp1AgAYGhpCT0/8s0Imk+H58+cwMzND7dq1VT7HYcOGKb0meclZJyUlBc+ePYO7uzsEQcCVK1dy1f/iiy+UHrdt21bp9/jnn39CIpFg5syZuc7NGqJd2Ncsi4eHB6ytreHo6IiBAwfCzMwMu3btgoODg8r6Fy9eRHx8PMaMGaM0t/3jjz9GnTp1yuRaBkREpQEXUiMiomLDwsICAPLdoqqwspLRLJaWljAyMkKlSpVylRflNlh79+7F999/j6tXryI9PV1RXpi5uF5eXti9ezfOnDkDd3d3RERE4NKlSwgICFDUuXv3LgRBQM2aNVW2Ua5cuXdex9vbGz/99BO2bNmCb7/9Fo8ePcKJEycwfvx4xeJuWfOkf/75Z0RGRirNo1Y1nLpatWpqPcfo6GjMmDEDf/31F16+fKl0LDExUelx1vzsnCpUqKB0XkREBOzt7VGxYsU8r1kUrxkArFq1CrVq1YK+vj5sbW1Ru3ZtxRcTqjx48AAAULt27VzH6tSpg5MnT6p1XSIiKl6YdBMRUbFhYWEBe3t73LhxQ636eSWsby+clZOqFcDzWhU8Zw9yQa6V5cSJE+jZsyfatWuHn3/+GXZ2dihXrhw2btyILVu2vPP8vPTo0QMmJiYICgqCu7s7goKCoKenh/79+yvqyOVySCQSHDhwQOXzNDMze+d1mjdvjjp16mDr1q349ttvsXXrVgiCoLRq+bx58+Dv74/PPvsMc+fORcWKFaGnp4cJEyZALpfnalOdXm6ZTIaPPvoIL168wJQpU1CnTh2YmpoiJiYGw4cPz9VuUa3uXhSvGQC0atUq14J2RERU9jDpJiKiYqV79+5Yu3Ytzpw5Azc3t3zrVqhQAQByLWyV1WNYlApzrT///BNGRkY4dOgQDA0NFeUbN27MVVeTnm9TU1N0794dO3bswNKlS7F9+3a0bdsW9vb2ijouLi4QBAHVqlVDrVq11G77bd7e3vD398e1a9ewZcsW1KxZEy1btlQc37lzJz788EOsX79e6byEhIRcowjUdf36ddy5cwe//fab0sJphw8fLtiTgPh6HDp0CC9evMizt7uoXjNNZe3BHh4erljpPEt4eLjSHu1crZyIqOTgnG4iIipWvvnmG5iammLkyJF48uRJruMRERFYtmwZALFnvFKlSrlWGf/555+LPC4XFxcAULqWTCbLtTK2KlKpFBKJRKlXPCoqCrt3785V19TUVKPVsb28vPD48WP8+u
uvCAsLg5eXl9Lxvn37QiqVYvbs2Uo994DYk6/uEPqsXu0ZM2bg6tWrufbmlkqludrfsWOH0lZXmsrqZc7Zr
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"# X is the features matrix\n",
"x = longformer_vectors\n",
"\n",
"\n",
"# Standardize the features\n",
"scaler = StandardScaler()\n",
"X_scaled = scaler.fit_transform(X)\n",
"\n",
"# Apply PCA\n",
"pca = PCA(n_components=0.95)\n",
"X_pca = pca.fit_transform(X_scaled)\n",
"explained_variance_ratio = pca.explained_variance_ratio_\n",
"cumulative_variance = np.cumsum(explained_variance_ratio)\n",
"\n",
"# Plotting the Scree plot\n",
"plt.figure(figsize=(10, 4))\n",
"plt.bar(range(1, len(explained_variance_ratio) + 1), explained_variance_ratio, alpha=0.6, color='g', label='Individual explained variance')\n",
"plt.ylabel('Explained variance ratio')\n",
"plt.xlabel('Principal components')\n",
"plt.title('Scree Plot')\n",
"plt.legend(loc='best')\n",
"plt.tight_layout()\n",
"\n",
"# Plotting the Cumulative Variance plot\n",
"plt.figure(figsize=(10, 4))\n",
"plt.plot(range(1, len(cumulative_variance) + 1), cumulative_variance, marker='o', linestyle='-', color='b', label='Cumulative explained variance')\n",
"plt.xlabel('Number of components')\n",
"plt.ylabel('Cumulative explained variance')\n",
"plt.title('Cumulative Variance Plot')\n",
"plt.axhline(y=0.95, color='r', linestyle='--', label='95% Explained Variance')\n",
"plt.legend(loc='best')\n",
"plt.tight_layout()\n",
"\n",
"# Displaying the plots\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 797
},
"id": "KYRsJBkkD7ep",
"outputId": "0169dcbe-d0e5-48a8-bcce-471d83ff095c"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAABOc0lEQVR4nO3de3zP9f//8ft7Y5udMTY025xizjaEj0N9phURHRwaY5V85NhSLDnHVA5L5BhT8aE+IaVIc6iYQ5gIM8dJ5pBYm2zs/f794ef97d029mZvb5vb9XJ5XS7ez9fz9Xo9Xi/vDvf36/l6vgwmk8kkAAAAAABQ6BzsXQAAAAAAAMUVoRsAAAAAABshdAMAAAAAYCOEbgAAAAAAbITQDQAAAACAjRC6AQAAAACwEUI3AAAAAAA2QugGAAAAAMBGCN0AAAAAANgIoRsAANw18fHxMhgMOn78uL1LAQDgriB0AwBgJ3v37tUzzzyjgIAAubi4qFKlSmrbtq3ef/99e5dmtTFjxshgMJgXV1dXBQcH680331R6enqhHGPJkiWKi4srlH0BAHC3lLB3AQAA3I+2bNmihx9+WJUrV1afPn3k5+enkydPauvWrXrvvfc0cOBAe5d4W2bNmiV3d3dlZGTo22+/1YQJE7R+/Xpt3rxZBoPhjva9ZMkS7du3T0OGDCmcYgEAuAsI3QAA2MGECRPk5eWlHTt2yNvb22Ld2bNn73j/JpNJV65cUalSpe54X9Z45pln5OPjI0n6z3/+o6efflrLly/X1q1b1axZs7taCwAA9wKGlwMAYAdHjhxR7dq1cwVuSSpfvnyutk8++URNmjSRq6urSpcurVatWunbb781rw8MDNQTTzyhtWvXKjQ0VKVKldKcOXMkSRcvXtSQIUPk7+8vZ2dnVatWTW+//baMRqPFMYxGo+Li4lS7dm25uLjI19dXffv21R9//HHb5/nII49Iko4dO3bTfh988IFq164tZ2dnVaxYUf3799fFixfN69u0aaPVq1frxIkT5iHsgYGBt10XAAB3C3e6AQCwg4CAACUmJmrfvn2qU6fOTfuOHTtWY8aMUfPmzTVu3Dg5OTlp27ZtWr9+vR599FFzv+TkZHXv3l19+/ZVnz599OCDD+ry5ctq3bq1Tp06pb59+6py5crasmWLYmJidPr0aYtnpPv27av4+HhFRUVp0KBBOnbsmGbMmKHdu3dr8+bNKlmypNXneeTIEUlS2bJl8+0zZswYjR07VmFhYerXr5+Sk5M1a9Ys7dixw3zcESNG6NKlS/r11181bdo0SZK7u7vV9QAAcNeZAADAXfftt9+aHB0dTY6OjqZmzZqZXn/9ddPatWtN2dnZFv1SUlJMDg4Ops6dO5tycnIs1hmNRvOfAwICTJJMa9assegzfvx4k5ubm+nQoUMW7cOHDzc5OjqaUlNTTSaTyfTDDz+YJJkWL15s0W/NmjV5tv/T6NGjTZJMycnJpnPnzpmOHTtmmjNnjsnZ2dnk6+tryszMNJlMJtPChQtNkkzHjh0zmUwm09mzZ01OTk6mRx991OL8ZsyYYZJkWrBggbmtffv2poCAgJvWAQDAvYbh5QAA2EHbtm2VmJiojh07as+ePXrnnXcUHh6uSpUqadWqVeZ+K1eulNFo1KhRo+TgYPmf7X9OTBYUFKTw8HCLts8++0wtW7ZU6dKldf78efMSFhamnJwcff/99+Z+Xl5eatu2rUW/kJAQubu7a8OGDQU6rwcffFDlypVTUFCQ+vbtq2rVqmn16tVydXXNs/93332n7OxsDRkyxOL8+vTpI09PT61evbpAxwUA4F7F8HIAAOykcePGWr58ubKzs7Vnzx6tWLFC06ZN0zPPPKOkpCQFBwfryJEjcnBwUHBw8C33FxQUlKstJSVFP//8s8qVK5fnNjcmbUtJSdGlS5fyfJ787/1u5fPPP5enp6dKliypBx54QFWrVr1p/xMnTki6Ht
b/zsnJSVWqVDGvBwCgqCJ0AwBgZ05OTmrcuLEaN26sGjVqKCoqSp999plGjx5t1X7ymqncaDSqbdu2ev311/PcpkaNGuZ+5cuX1+LFi/Psl19o/6dWrVqZZy8HAACEbgAA7imhoaGSpNOnT0uSqlatKqPRqP3796tBgwZW769q1arKyMhQWFjYLft99913atGixV19zVhAQICk65PAValSxdyenZ2tY8eOWdR9p+/5BgDAHnimGwAAO9iwYYNMJlOu9q+//lrS/w237tSpkxwcHDRu3Lhcr/jKa/t/6tKlixITE7V27dpc6y5evKhr166Z++Xk5Gj8+PG5+l27ds3i9V2FKSwsTE5OTpo+fbrF+Xz44Ye6dOmS2rdvb25zc3PTpUuXbFIHAAC2wp1uAADsYODAgbp8+bI6d+6smjVrKjs7W1u2bNGyZcsUGBioqKgoSVK1atU0YsQIjR8/Xi1bttRTTz0lZ2dn7dixQxUrVlRsbOxNj/Paa69p1apVeuKJJ9S7d2+FhIQoMzNTe/fu1f/+9z8dP35cPj4+at26tfr27avY2FglJSXp0UcfVcmSJZWSkqLPPvtM7733np555plCvw7lypVTTEyMxo4dq8cee0wdO3ZUcnKyPvjgAzVu3Fg9evQw9w0JCdGyZcsUHR2txo0by93dXR06dCj0mgAAKEyEbgAA7GDy5Mn67LPP9PXXX2vu3LnKzs5W5cqV9fLLL+vNN9+Ut7e3ue+4ceMUFBSk999/XyNGjJCrq6vq1aunnj173vI4rq6u2rRpkyZOnKjPPvtMH330kTw9PVWjRg2NHTtWXl5e5r6zZ89WSEiI5syZozfeeEMlSpRQYGCgevTooRYtWtjiMki6/p7ucuXKacaMGXrllVdUpkwZvfTSS5o4caLFu8FffvllJSUlaeHChZo2bZoCAgII3QCAe57BVJCxaQAAAAAAwGo80w0AAAAAgI0QugEAAAAAsBFCNwAAAAAANkLoBgAAAADARgjdAAAAAADYCKEbAAAAAAAbue/e0200GvXbb7/Jw8NDBoPB3uUAAAAAAIogk8mkP//8UxUrVpSDQ/73s++70P3bb7/J39/f3mUAAAAAAIqBkydP6oEHHsh3/X0Xuj08PCRdvzCenp52rgYAAAAAUBSlp6fL39/fnDHzc9+F7htDyj09PQndAAAAAIA7cqvHlplIDQAAAAAAGyF0AwAAAABgI4RuAAAAAABs5L57phsAAAC4E0ajUdnZ2fYuA4CNlSxZUo6Ojne8H0I3AAAAUEDZ2dk6duyYjEajvUsBcBd4e3vLz8/vlpOl3QyhGwAAACgAk8mk06dPy9HRUf7+/nJw4ElNoLgymUy6fPmyzp49K0mqUKHCbe+L0A0AAAAUwLVr13T58mVVrFhRrq6u9i4HgI2VKlVKknT27FmVL1/+toea8/McAAAAUAA5OTmSJCcnJztXAuBuufED29WrV297H4RuAAAAwAp38mwngKKlMP55J3QDAAAAAGAjhG4AAAAAt2QwGLRy5UpJ0vHjx2UwGJSUlHRb2+fldvZZEIGBgYqLiyvUfVpr48aNMhgMunjxYoG3adOmjYYMGWKzmm7o3bu3OnXqZPPj5OVunaO9MZEaAAAAcAf6ftn3rh5vToc5VvXv3bu3Ll68eNPAay1/f3+dPn1aPj4+Bd7m9OnTKl26dKHVUNwtX75cJUuWtHcZNnU/nKN0D4TumTNn6t1331VaWprq16+v999/X02aNMm3/8WLFzVixAgtX75cFy5cUEBAgOLi4tSuXbu7WPXdcbf/BW4P1v5HAwAAAPbn6OgoPz8/q7axtv/9rkyZMvYuwWays7Pl5ORUrM/x7+w6vHzZsmWKjo7W6NGjtWvXLtWvX1/h4eHmd6H9U3Z2ttq2bavjx4/rf//7n5KTkzVv3jxVqlTpLlcOAAAAFE1t2rTRoEGD9Prrr6tMmTLy8/PTmDFjLPqkpKSoVatWcnFxUXBwsNatW2
ex/u9DwY1Gox544AHNmjXLos/u3bvl4OCgEydOSMo9vHz79u1q2LChXFxcFBoaqt27d1tsHx8fL29vb4u2l
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x400 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA90AAAGGCAYAAABmGOKbAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACK9ElEQVR4nOzdd1hTZxsG8DsE2UNFQEAExb03BWeVilp3VVTqoI62aq2CVm3F2bqqFle1WpXWutCits5axL0X6qeiIogiiAsQZJmc749TApGACRDDuH/XlYvkPe95z5MQI0/eJREEQQARERERERERFTk9XQdAREREREREVFox6SYiIiIiIiLSEibdRERERERERFrCpJuIiIiIiIhIS5h0ExEREREREWkJk24iIiIiIiIiLWHSTURERERERKQlTLqJiIiIiIiItIRJNxEREREREZGWMOkmIiIqYsOHD4ezs3ORthkYGAiJRIKoqKgibbc4OXr0KCQSCY4eParrUN6LsvZ8iYjKKibdRERULEVERODzzz9H9erVYWRkBAsLC7Ru3RrLli1DamqqrsPTmnnz5mH37t26DgMA0LNnT5iYmODVq1d51vH29oaBgQGeP3/+HiMr3rK+IMm6GRkZoVatWhg3bhyePHlSJNfYv38/Zs2aVSRtERGRdjHpJiKiYmffvn1o2LAhgoKC0KNHD6xYsQLz589H1apVMXnyZHz99de6DlFr8kq6hwwZgtTUVDg5Ob23WLy9vZGamopdu3apPP769Wvs2bMHXbp0gZWVVaGv165dO6SmpqJdu3aFbqs4mDNnDjZt2oSVK1fC3d0dq1evhpubG16/fl3otvfv34/Zs2cXQZRERKRt+roOgIiIKKfIyEgMHDgQTk5OOHLkCOzs7BTHxo4di3v37mHfvn06jFA3pFIppFLpe71mz549YW5uji1btmDo0KG5ju/ZswcpKSnw9vYu1HXS0tJgYGAAPT09GBkZFaqt4qRr165o0aIFAGDkyJGwsrLC0qVLsWfPHgwaNEjH0RER0fvCnm4iIipWFi1ahOTkZKxfv14p4c5So0YNRU93VFQUJBIJAgMDc9WTSCRKw29nzZoFiUSCO3fu4NNPP4WlpSWsra3h7+8PQRDw8OFD9OrVCxYWFqhcuTKWLFmi1F5ec6rVnZe7ePFiuLu7w8rKCsbGxmjevDl27tyZK+aUlBT89ttviqHJw4cPV3n97t27o3r16iqv5ebmpkj2svzxxx9o3rw5jI2NUbFiRQwcOBAPHz7MN2ZjY2P07dsXISEhiI+Pz3V8y5YtMDc3R8+ePfHixQtMmjQJDRs2hJmZGSwsLNC1a1eEhYUpnZP1em3btg3Tp0+Hg4MDTExMkJSUpPK1PHHiBPr374+qVavC0NAQjo6OmDhxYq4pBsOHD4eZmRliYmLQu3dvmJmZwdraGpMmTYJMJlOqK5fLsWzZMjRs2BBGRkawtrZGly5dcPHixUK/Zvnp2LEjAPGLpfzs2LFDcd1KlSrh008/RUxMjNJzXbVqFQAoDWMnIqLiiUk3EREVK3///TeqV68Od3d3rbTv5eUFuVyOBQsWwNXVFd9//z0CAgLw0UcfwcHBAQsXLkSNGjUwadIkHD9+vMiuu2zZMjRt2hRz5szBvHnzoK+vj/79+yv12m/atAmGhoZo27YtNm3ahE2bNuHzzz/P83lERkbiwoULSuUPHjzA2bNnMXDgQEXZDz/8gKFDh6JmzZpYunQpJkyYgJCQELRr1w4JCQn5xu3t7Y03b94gKChIqfzFixc4dOgQ+vTpA2NjY9y/fx+7d+9G9+7dsXTpUkyePBnXr19H+/bt8fjx41ztzp07F/v27cOkSZMwb948GBgYqLz+jh078Pr1a3z55ZdYsWIFPD09sWLFCpU97zKZDJ6enrCyssLixYvRvn17LFmyBGvXrlWqN2LECEyYMAGOjo5YuHAhpk6dCiMjI5w9e7ZIXrO8RE
REAEC+Q/EDAwMxYMAASKVSzJ8/H6NGjUJwcDDatGmjuO7nn3+Ojz76CAAU75NNmzYVKCYiInoPBCIiomIiMTFRACD06tVLrfqRkZECAGHjxo25jgEQZs6cqXg8c+ZMAYAwevRoRdmbN2+EKlWqCBKJRFiwYIGi/OXLl4KxsbEwbNgwRdnGjRsFAEJkZKTSdUJDQwUAQmhoqKJs2LBhgpOTk1K9169fKz3OyMgQGjRoIHTs2FGp3NTUVOm6eV0/MTFRMDQ0FPz8/JTqLVq0SJBIJMKDBw8EQRCEqKgoQSqVCj/88INSvevXrwv6+vq5yt/25s0bwc7OTnBzc1MqX7NmjQBAOHTokCAIgpCWlibIZDKlOpGRkYKhoaEwZ84cRVnW61W9evVcr4mq1/LtOoIgCPPnz1d6joIgvuYAlK4lCILQtGlToXnz5orHR44cEQAI48ePz9WuXC4XBKHwr1nW7+rff/8Vnj59Kjx8+FDYtm2bYGVlJRgbGwuPHj1S+XwzMjIEGxsboUGDBkJqaqqivb179woAhBkzZijKxo4dK/DPOCKikoE93UREVGwkJSUBAMzNzbV2jZEjRyruS6VStGjRAoIgYMSIEYry8uXLo3bt2rh//36RXdfY2Fhx/+XLl0hMTETbtm1x+fLlArWXNXw7KCgIgiAoyrdv344PPvgAVatWBQAEBwdDLpdjwIABePbsmeJWuXJl1KxZE6GhofleRyqVYuDAgThz5ozS0PotW7bA1tYWnTp1AgAYGhpCT0/8s0Imk+H58+cwMzND7dq1VT7HYcOGKb0meclZJyUlBc+ePYO7uzsEQcCVK1dy1f/iiy+UHrdt21bp9/jnn39CIpFg5syZuc7NGqJd2Ncsi4eHB6ytreHo6IiBAwfCzMwMu3btgoODg8r6Fy9eRHx8PMaMGaM0t/3jjz9GnTp1yuRaBkREpQEXUiMiomLDwsICAPLdoqqwspLRLJaWljAyMkKlSpVylRflNlh79+7F999/j6tXryI9PV1RXpi5uF5eXti9ezfOnDkDd3d3RERE4NKlSwgICFDUuXv3LgRBQM2aNVW2Ua5cuXdex9vbGz/99BO2bNmCb7/9Fo8ePcKJEycwfvx4xeJuWfOkf/75Z0RGRirNo1Y1nLpatWpqPcfo6GjMmDEDf/31F16+fKl0LDExUelx1vzsnCpUqKB0XkREBOzt7VGxYsU8r1kUrxkArFq1CrVq1YK+vj5sbW1Ru3ZtxRcTqjx48AAAULt27VzH6tSpg5MnT6p1XSIiKl6YdBMRUbFhYWEBe3t73LhxQ636eSWsby+clZOqFcDzWhU8Zw9yQa6V5cSJE+jZsyfatWuHn3/+GXZ2dihXrhw2btyILVu2vPP8vPTo0QMmJiYICgqCu7s7goKCoKenh/79+yvqyOVySCQSHDhwQOXzNDMze+d1mjdvjjp16mDr1q349ttvsXXrVgiCoLRq+bx58+Dv74/PPvsMc+fORcWKFaGnp4cJEyZALpfnalOdXm6ZTIaPPvoIL168wJQpU1CnTh2YmpoiJiYGw4cPz9VuUa3uXhSvGQC0atUq14J2RERU9jDpJiKiYqV79+5Yu3Ytzpw5Azc3t3zrVqhQAQByLWyV1WNYlApzrT///BNGRkY4dOgQDA0NFeUbN27MVVeTnm9TU1N0794dO3bswNKlS7F9+3a0bdsW9vb2ijouLi4QBAHVqlVDrVq11G77bd7e3vD398e1a9ewZcsW1KxZEy1btlQc37lzJz788EOsX79e6byEhIRcowjUdf36ddy5cwe//fab0sJphw8fLtiTgPh6HDp0CC9evMizt7uoXjNNZe3BHh4erljpPEt4eLjSHu1crZyIqOTgnG4iIipWvvnmG5iammLkyJF48uRJruMRERFYtmwZALFnvFKlSrlWGf/555+LPC4XFxcAULqWTCbLtTK2KlKpFBKJRKlXPCoqCrt3785V19TUVKPVsb28vPD48WP8+u
uvCAsLg5eXl9Lxvn37QiqVYvbs2Uo994DYk6/uEPqsXu0ZM2bg6tWrufbmlkqludrfsWOH0lZXmsrqZc7Zr
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"\n",
"X_1 = longformer_vectors\n",
"X_2 = linformer_vectors\n",
"\n",
"# Standardize the features independently\n",
"scaler1 = StandardScaler()\n",
"X_1_scaled = scaler1.fit_transform(X_1)\n",
"\n",
"scaler2 = StandardScaler()\n",
"X_2_scaled = scaler2.fit_transform(X_2)\n",
"\n",
"# Apply PCA\n",
"pca = PCA(n_components=2) # Reducing to 2 components for a 2D plot\n",
"X_1_pca = pca.fit_transform(X_1_scaled)\n",
"pca2 = PCA(n_components=2)\n",
"X_2_pca = pca2.fit_transform(X_2_scaled)\n",
"\n",
"# Plotting\n",
"plt.figure(figsize=(10, 6))\n",
"plt.scatter(X_1_pca[:, 0], X_1_pca[:, 1], color='blue', alpha=0.5, label='Longformer Model')\n",
"plt.scatter(X_2_pca[:, 0], X_2_pca[:, 1], color='red', alpha=0.5, label='Linformer Model')\n",
"plt.xlabel('Principal Component 1')\n",
"plt.ylabel('Principal Component 2')\n",
"plt.title('PCA Comparison of Two Models')\n",
"plt.legend()\n",
"plt.grid(True)\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"id": "6X6EyalGEe4B",
"outputId": "568c07f4-cea8-40ff-c4ae-1a1d1039120c"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1000x600 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA14AAAIjCAYAAAATE8pZAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAADaxUlEQVR4nOzdd3hUVf4/8PedmcxMJskkpEEghNATEKkiCAgqiIoi+MXe664uoqKr66qAbW0/yyqKdcXeRVdFF0S6iIo0IfQSUkiD9DLl3t8fH+6UFJghmTTer+fJk8ydOzPnziXhvuec8zmKpmkaiIiIiIiIKGQMLd0AIiIiIiKi9o7Bi4iIiIiIKMQYvIiIiIiIiEKMwYuIiIiIiCjEGLyIiIiIiIhCjMGLiIiIiIgoxBi8iIiIiIiIQozBi4iIiIiIKMQYvIiIiIiIiEKMwYuIiNqlcePGYdy4cS3djEb77bffcNpppyEiIgKKomDDhg0t3aRWb/78+VAUBfv27Qv6sXPmzIGiKE3fKCI64TF4EREdJ/3iTv+yWq3o06cPpk+fjry8vDr75+Xl4Z577kFaWhpsNhsiIiIwdOhQPPbYYyguLq73NYYPHw5FUTBv3ryg21daWoqHH34YAwcORGRkJMLDw3HSSSfhvvvuQ05OTtDPR83P6XTi4osvxqFDh/D888/jvffeQ7du3ersl5qa6vdvsaGv+fPnN2v7x40bB0VR0Lt373rvX7x4sadtn3/+ebO2jYiouZlaugFERG3dI488gu7du6O6uhqrVq3CvHnzsHDhQvz555+w2WwApNfivPPOQ3l5Oa666ioMHToUAPD777/jySefxIoVK7Bo0SK/5925cyd+++03pKam4oMPPsCtt94acJv27NmD8ePHIzMzExdffDFuueUWmM1mbNq0CW+99RYWLFiAHTt2NN2b0ArVfj/bot27d2P//v144403cNNNNzW43wsvvIDy8nLP7YULF+Kjjz7C888/j/j4eM/20047LaTtrY/VasWuXbvw66+/Yvjw4X73ffDBB7Baraiurm72dhERNTcGLyKiRjr33HMxbNgwAMBNN92EuLg4PPfcc/j6669x+eWXo7i4GFOnToXRaMT69euRlpbm9/jHH38cb7zxRp3nff/995GYmIhnn30W06ZNw759+5CamnrM9rhcLlx00UXIy8vDsmXLMHr06Dqv99RTTx3/AbdylZWVsNlsMJvNLd2URsvPzwcAxMTEHHW/KVOm+N0+ePAgPvroI0yZMiWgfzOh1LNnT7hcLnz00Ud+wau6uhoLFizApEmT8MUXX7RgC4mImgeHGhIRNbEzzzwTALB3714AwGuvvYbs7Gw899xzdUIXAHTs2BEPPvhgne0ffvghpk2bhvPPPx/R0dH48MMPA3r9L774Ahs3bsQDDzxQJ3QBgN1ux+OPP+637bPPPsPQoUMRHh6O+Ph4XHXVVcjOzvbb57rrrkNkZCQyMzNx/vnnIzIyEl26dMHLL78MANi8eTPOPPNMREREoFu3bnXaqw/NXLFiBf7yl78gLi4Odrsd11xzDQ4fPuy379dff41Jkyahc+fOsFgs6NmzJx599FG43W6//caNG4eTTjoJ69atw+mnnw6bzYZ//vOfnvtqz/F66aWX0L9/f9hsNnTo0AHDhg2r087169fj3HPPhd1uR2RkJM466yz88ssv9R7L6tWrMXPmTCQkJCAiIgJTp05FQUFBfaeljp9++gljxoxBREQEYmJicOGFFyIjI8Pv/R47diwA4OKLL4aiKMc9Z23mzJmIi4uDpmmebbfffjsURcGLL77o2ZaXl1dnaGt+fj5uvPFGdOzYEVarFQMHDsQ777wT1Otffvnl+OSTT6CqqmfbN998g8rKSlxyySX1PiaQ8wAAW7ZswZlnnonw8HAkJyfjscce83sdX99//73nPY+KisKkSZOwZcuWY7Z/8eLFGD16NGJiYhAZGYm+ff
t6/p0REQWKwYuIqInt3r0bABAXFwcA+O9//4vw8HBMmzYt4OdYu3Ytdu3ahcsvvxxmsxkXXXQRPvjgg4Ae+9///hcAcPXVVwe0//z583HJJZfAaDTiiSeewM0334wvv/wSo0ePrjP3zO1249xzz0XXrl3x9NNPIzU1FdOnT8f8+fNxzjnnYNiwYXjqqacQFRWFa665xhM+fU2fPh0ZGRmYM2cOrrnmGnzwwQeYMmWKXyiYP38+IiMjMXPmTPz73//G0KFDMWvWLPzjH/+o83xFRUU499xzMWjQILzwwgs444wz6j3ON954AzNmzEC/fv3wwgsv4OGHH8agQYOwdu1azz5btmzBmDFjsHHjRtx777146KGHsHfvXowbN85vP93tt9+OjRs3Yvbs2bj11lvxzTffYPr06cd8z3/88UdMnDgR+fn5mDNnDmbOnImff/4Zo0aN8hSE+Mtf/uK5uJ8xYwbee+89PPDAA8d87vqMGTMGhw4d8gsZK1euhMFgwMqVK/22AcDpp58OAKiqqsK4cePw3nvv4corr8QzzzyD6OhoXHfddfj3v/8d8OtfccUVyM3NxbJlyzzbPvzwQ5x11llITEyss3+g5+HgwYM444wzsGHDBvzjH//AnXfeiXfffbfetr333nuYNGkSIiMj8dRTT+Ghhx7C1q1bMXr06KMW4diyZQvOP/981NTU4JFHHsGzzz6LyZMnY/Xq1QEfPxERAEAjIqLj8vbbb2sAtB9//FErKCjQDhw4oH388cdaXFycFh4ermVlZWmapmkdOnTQBg4cGNRzT58+XevataumqqqmaZq2aNEiDYC2fv36Yz528ODBWnR0dECv43A4tMTERO2kk07SqqqqPNu//fZbDYA2a9Ysz7Zrr71WA6D961//8mw7fPiwFh4erimKon388cee7du2bdMAaLNnz/Zs09+voUOHag6Hw7P96aef1gBoX3/9tWdbZWVlnbb+5S9/0Ww2m1ZdXe3ZNnbsWA2A9uqrr9bZf+zYsdrYsWM9ty+88EKtf//+R30/pkyZopnNZm337t2ebTk5OVpUVJR2+umn1zmW8ePHe86RpmnaXXfdpRmNRq24uPiorzNo0CAtMTFRKyoq8mzbuHGjZjAYtGuuucazbenSpRoA7bPPPjvq89X2zDPPaAC0vXv3apqmafn5+RoA7ZVXXtE0TdOKi4s1g8GgXXzxxVrHjh09j5sxY4YWGxvrOaYXXnhBA6C9//77nn0cDoc2cuRILTIyUistLT1qO8aOHet5z4cNG6bdeOONmqbJvxuz2ay988479R5joOfhzjvv1ABoa9eu9WzLz8/XoqOj/Y6/rKxMi4mJ0W6++Wa/9h08eFCLjo722z579mzN9/Lo+eef1wBoBQUFRz1WIqJjYY8XEVEjjR8/HgkJCejatSsuu+wyREZGYsGCBejSpQsAqS4YFRUV8PO5XC588sknuPTSSz1lrc8880wkJiYG1OsVzOv9/vvvyM/Px2233Qar1erZPmnSJKSlpeG7776r8xjfIg8xMTHo27cvIiIi/IaM9e3bFzExMdizZ0+dx99yyy0ICwvz3L711lthMpmwcOFCz7bw8HDPz2VlZSgsLMSYMWNQWVmJbdu2+T2fxWLB9ddff8xjjYmJQVZWFn777bd673e73Vi0aBGmTJmCHj16eLYnJSXhiiuuwKpVq1BaWlrnWHxLj48ZMwZutxv79+9vsB25ubnYsGEDrrvuOsTGxnq2n3zyyZgwYYLf+9BUEhISkJaWhhUrVgAAVq9eDaPRiL///e/Iy8vDzp07AUiP1+jRoz3HtHDhQnTq1AmXX36557nCwsIwY8YMlJeXY/ny5QG34YorrsCXX34Jh8OBzz//HEajEVOnTq2zXzDnYeHChRgxYoTf3LGEhARceeWVfs+5ePFiFBcX4/LLL0dhYaHny2g04tRTT8XSpUsbbLc+v+7rr79ucAgjEVEgGLyIiBrp5ZdfxuLFi7F06V
Js3boVe/bswcSJEz332+12lJWVBfx8ixYtQkFBAYYPH45du3Zh165d2Lt3L8444wx89NFHx7z4C+b19IDQt
},
"metadata": {}
}
]
},
{
"cell_type": "code",
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"from sklearn.metrics.pairwise import cosine_similarity\n",
"\n",
"# Assuming longformer_vector_df and linformer_vector_df have a column 'message_vector' with the vectors\n",
"# Stacking vectors into a 2D array for each model\n",
"longformer_vectors = np.stack(longformer_vector_df['message_vector'].to_numpy())\n",
"linformer_vectors = np.stack(linformer_vector_df['message_vector'].to_numpy())\n",
"\n",
"# Calculate the cosine similarity matrix for each model\n",
"cosine_sim_matrix_long = cosine_similarity(longformer_vectors)\n",
"cosine_sim_matrix_lin = cosine_similarity(linformer_vectors)\n",
"\n",
"# Extract the upper triangle values from each cosine similarity matrix, excluding the diagonal\n",
"upper_triangle_indices_long = np.triu_indices_from(cosine_sim_matrix_long, k=1)\n",
"cosine_sim_values_long = cosine_sim_matrix_long[upper_triangle_indices_long]\n",
"\n",
"upper_triangle_indices_lin = np.triu_indices_from(cosine_sim_matrix_lin, k=1)\n",
"cosine_sim_values_lin = cosine_sim_matrix_lin[upper_triangle_indices_lin]\n",
"\n",
"# Plotting the KDE of cosine similarities for both models\n",
"plt.figure(figsize=(12, 6))\n",
"sns.kdeplot(cosine_sim_values_long, color='blue', fill=True, label='Longformer Model')\n",
"sns.kdeplot(cosine_sim_values_lin, color='green', fill=True, label='Linformer Model')\n",
"plt.title('Comparison of Cosine Similarity Distributions')\n",
"plt.xlabel('Cosine Similarity')\n",
"plt.ylabel('Density')\n",
"plt.grid(True)\n",
"plt.legend()\n",
"plt.show()\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 564
},
"id": "hosMIGN4E5Vj",
"outputId": "2f481f84-219b-4289-866f-4e81b04d9db8"
},
"execution_count": null,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 1200x600 with 1 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA+QAAAIjCAYAAACKx9GpAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAACVRElEQVR4nOzdd3xT9f7H8XeapG3SPYC2UGjZQwSVIYIiylAR8boVFRTHveAW/blB3HpBvDLcuNCrXLdyVcSLOBARB6jI3i2U2d0mTc7vj2MCoStt0xbo6/l45JH2nJNzvk170Hc+32ExDMMQAAAAAABoUGGN3QAAAAAAAJoiAjkAAAAAAI2AQA4AAAAAQCMgkAMAAAAA0AgI5AAAAAAANAICOQAAAAAAjYBADgAAAABAIyCQAwAAAADQCAjkAAAAAAA0AgI5AKDWLBaLJk2a1NjNqLPXXntNnTt3lt1uV3x8fGM3p5yNGzfKYrHo5ZdfbuymVCnUfw8V/dyTJk2SxWIJ2TV8MjIyNGbMmJCfNxROPvlknXzyyQ1yrYN/h773e9euXQ1y/UP59wAA9YFADgB1sG7dOl177bVq27atIiMjFRsbq/79++upp55ScXFxYzcPQfjzzz81ZswYtWvXTs8//7yee+65al/zyy+/6NJLL1V6eroiIiKUmJiowYMHa/bs2fJ4PA3Q6ob1zTff6PTTT1fLli0VGRmp1q1ba8SIEXrjjTcau2n15o8//tCkSZO0cePGkJ53zJgxslgs/kd0dLTatm2r8847T++88468Xm9IrvPdd99p0qRJ2rdvX0jOF0qHctsAoKHZGrsBAHC4+uSTT3T++ecrIiJCl19+uY466ii5XC598803uu222/T7778HFe4OZ8XFxbLZDu//lCxcuFBer1dPPfWU2rdvX+3xL7zwgv7+97+rRYsWuuyyy9ShQwfl5+drwYIFGjt2rLKzs3XXXXeFtI1t2rRRcXGx7HZ7SM8bjLlz5+rCCy9Uz549deONNyohIUEbNmzQokWL9Pzzz+uSSy7xHxvqv4eG/LlXrVqlsLD9dYo//vhD999/v04++WRlZGSE9FoRERF64YUXJJnv2aZNm/TRRx/pvPPO08knn6wPPvhAsbGx/uM///zzGl/ju+++0/33368xY8bUqNdHQ9zTVbXt4N8DABzpDu//iwKARrJhwwZddNFFatOmjb788kulpqb6940fP15r167VJ5980ogtrD9er1cul0uRkZGKjIxs7ObUWU5OjiQFFVq+//57/f3vf1e/fv00b948xcTE+PfddNNN+vHHH/Xbb7+FvI0Wi6XR3utJkyapa9eu+v777xUeHh6wz/fe+YS6jfX9cxuGoZKSEjkcDkVERNTbdQ5ms9l06aWXBmx78MEH9eijj+rOO+/U1Vdfrbfeesu/7+D3PdQOpXu6IX8PAHAo4CNIAKiFxx9/XAUFBXrxxRcDwrhP+/btdeONN/q/Lysr0wMPPKB27dopIiJCGRkZuuuuu1RaWhrwuoyMDJ155plauHChevXqJYfDoe7du2vhwoWSpHfffVfdu3dXZGSkjjvuOP38888Brx8zZoyio6O1fv16DRs2TFFRUUpLS9PkyZNlGEbAsf/85z91wgknKCkpSQ6HQ8cdd5z+85//lPtZLBaLrrvuOs2ZM0fdunVTRESEPv30U/++A8eb5ufn66abblJGRoYiIiLUvHlzDRkyRD/99FPAOefOnavjjjtODodDycnJuvTSS7Vt27YKf5Zt27bp7LPPVnR0tJo1a6YJEyYE3S185syZ/janpaVp/PjxAd1kMzIyNHHiRElSs2bNqh0Dff/998tisWjOnDkBYdynV69eAeNfCwsLdeutt/q7tnfq1En//Oc/y/0u5s+frwEDBig+Pl7R0dHq1KlTQJW9orHUNXl/vF6vpk2bpm7duikyMlItWrTQtddeq71791b7Hq5bt069e/euMBQ2b9484P
vKxh+vXr1al156qeLi4tSsWTPde++9MgxDW7Zs0ciRIxUbG6uUlBRNmTIl4HzBjp2fPXu2TjnlFDVv3lwRERHq2rWrZs2aVe443/312Wef+e+vZ5991r/P97t7+eWXdf7550uSBg0a5O9evnDhQo0ePVrJyclyu93lzj906FB16tSpyrZW5Y477tDQoUM1d+5crV692r+9ojHkTz/9tLp16yan06mEhAT16tXLP4Rg0qRJuu222yRJmZmZ/vb7ut/X5J722bVrly644ALFxsYqKSlJN954o0pKSvz7q/pdHXjO6tpW0Rjy9evX6/zzz1diYqKcTqeOP/74ch94Lly4UBaLRW+//bYeeughtWrVSpGRkTr11FO1du3agGPXrFmjc889VykpKYqMjFSrVq100UUXKTc3t1zbAaC+EcgBoBY++ugjtW3bVieccEJQx1911VW67777dOyxx+rJJ5/UwIED9cgjj+iiiy4qd+zatWt1ySWXaMSIEXrkkUe0d+9ejRgxQnPmzNHNN9+sSy+9VPfff7/WrVunCy64oNyYU4/Ho9NOO00tWrTQ448/ruOOO04TJ070B0+fp556Ssccc4wmT56shx9+WDabTeeff36Flf0vv/xSN998sy688EI99dRTlXbh/fvf/65Zs2bp3HPP1cyZMzVhwgQ5HA6tXLnSf8zLL7+sCy64QFarVY888oiuvvpqvfvuuxowYEC5MaUej0fDhg1TUlKS/vnPf2rgwIGaMmVKUEMBJk2apPHjxystLU1TpkzRueeeq2effVZDhw71h6lp06bpb3/7myRp1qxZeu2113TOOedUeL6ioiItWLBAJ510klq3bl3t9Q3D0FlnnaUnn3xSp512mqZOnapOnTrptttu0y233OI/7vfff9eZZ56p0tJSTZ48WVOmTNFZZ52lb7/9ttprBPv+XHvttbrtttv88xtcccUVmjNnjoYNG1ZhsDxQmzZttGDBAm3durXa9lTmwgsvlNfr1aOPPqq+ffvqwQcf1LRp0zRkyBC1bNlSjz32mNq3b68JEyZo0aJFNT7/rFmz1KZNG911112aMmWK0tPTNW7cOM2YMaPcsatWrdLFF1+sIUOG6KmnnlLPnj3LHXPSSSfphhtukCTdddddeu211/Taa6+pS5cuuuyyy7R792599tlnAa/Zvn27vvzyy3KV75q67LLLZBiG5s+fX+kxzz//vG644QZ17dpV06ZN0/3336+ePXtqyZIlkqRzzjlHF198sSTpySef9Le/WbNm/nMEe0/7XHDBBSopKdEjjzyiM844Q//61790zTXX1PjnC6ZtB9qxY4dOOOEEffbZZxo3bpweeughlZSU6KyzztJ7771X7vhHH31U7733niZMmKA777xT33//vUaNGuXf73K5NGzYMH3//fe6/vrrNWPGDF1zzTVav349Y9oBNA4DAFAjubm5hiRj5MiRQR3/yy+/GJKMq666KmD7hAkTDEnGl19+6d/Wpk0bQ5Lx3Xff+bd99tlnhiTD4XAYmzZt8m9/9tlnDUnG//73P/+20aNHG5KM66+/3r/N6/Uaw4cPN8LDw42dO3f6txcVFQW0x+VyGUcddZRxyimnBGyXZISFhRm///57uZ9NkjFx4kT/93Fxccb48eMrfS9cLpfRvHlz46ijjjKKi4v92z/++GNDknHfffeV+1kmT54ccI5jjjnGOO644yq9hmEYRk5OjhEeHm4MHTrU8Hg8/u3Tp083JBkvvfSSf9vEiRMNSQHvTUV+/fVXQ5Jx4403Vnmcz/vvv29IMh588MGA7eedd55hsViMtWvXGoZhGE8++WS119+wYYMhyZg9e7Z/W7Dvz9dff21IMubMmRNw3Kefflrh9oO9+OKLhiQjPDzcGDRokHHvvfcaX3/9dcD76nPw34Pvvb3mmmv828rKyoxWrVoZFovFePTRR/3b9+7dazgcDmP06NFV/ty+cx7o4L9lwzCMYcOGGW3btg3Y5r
u/Pv3003LHt2nTJuDac+fOLXd/GYZheDweo1WrVsaFF14YsH3q1KmGxWIx1q9fX+7cBxo9erQRFRVV6f6ff
},
"metadata": {}
}
]
},
{
"cell_type": "markdown",
"source": [
"# Applying the Quantile Transformer for feature scaling of Linformer model vectors"
],
"metadata": {
"id": "SxhOQqiHhPju"
}
},
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
2024-07-29 12:01:01 +00:00
"from sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
"\n",
2024-07-29 12:01:01 +00:00
"# X is the feature matrix (the Linformer vectors)\n",
"X = linformer_vectors\n",
"\n",
"def process_and_pca(X, scaler):\n",
" X_scaled = scaler.fit_transform(X)\n",
" pca = PCA(n_components=0.95)\n",
" X_pca = pca.fit_transform(X_scaled)\n",
" explained_variance_ratio = pca.explained_variance_ratio_\n",
" cumulative_variance = np.cumsum(explained_variance_ratio)\n",
" return explained_variance_ratio, cumulative_variance\n",
"\n",
"# Process for StandardScaler\n",
"explained_variance_ratio_std, cumulative_variance_std = process_and_pca(X, StandardScaler())\n",
"\n",
"# Process for QuantileTransformer\n",
"explained_variance_ratio_qt, cumulative_variance_qt = process_and_pca(X, QuantileTransformer(n_quantiles=1000, output_distribution='normal'))\n",
"\n",
"# Create subplots\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 12))\n",
"\n",
2024-07-29 12:01:01 +00:00
"# Scree plot StandardScaler\n",
"axes[0, 0].bar(range(1, len(explained_variance_ratio_std) + 1), explained_variance_ratio_std, alpha=0.6, color='g', label='StandardScaler: Individual explained variance')\n",
"axes[0, 0].set_title('StandardScaler Scree Plot')\n",
"axes[0, 0].set_xlabel('Principal components')\n",
"axes[0, 0].set_ylabel('Explained variance ratio')\n",
"axes[0, 0].legend()\n",
"\n",
2024-07-29 12:01:01 +00:00
"# Cumulative Variance plot StandardScaler\n",
"axes[1, 0].plot(range(1, len(cumulative_variance_std) + 1), cumulative_variance_std, marker='o', linestyle='-', color='b', label='StandardScaler: Cumulative explained variance')\n",
"axes[1, 0].axhline(y=0.95, color='r', linestyle='--', label='95% Explained Variance')\n",
2024-07-29 12:01:01 +00:00
"axes[1, 0].set_title('StandardScaler Cumulative Variance Plot')\n",
"axes[1, 0].set_xlabel('Number of components')\n",
"axes[1, 0].set_ylabel('Cumulative explained variance')\n",
"axes[1, 0].legend()\n",
"\n",
2024-07-29 12:01:01 +00:00
"# Scree plot QuantileTransformer\n",
"axes[0, 1].bar(range(1, len(explained_variance_ratio_qt) + 1), explained_variance_ratio_qt, alpha=0.6, color='g', label='QuantileTransformer: Individual explained variance')\n",
"axes[0, 1].set_title('QuantileTransformer Scree Plot')\n",
"axes[0, 1].set_xlabel('Principal components')\n",
"axes[0, 1].set_ylabel('Explained variance ratio')\n",
"axes[0, 1].legend()\n",
"\n",
2024-07-29 12:01:01 +00:00
"# Cumulative Variance plot QuantileTransformer\n",
"axes[1, 1].plot(range(1, len(cumulative_variance_qt) + 1), cumulative_variance_qt, marker='o', linestyle='-', color='b', label='QuantileTransformer: Cumulative explained variance')\n",
"axes[1, 1].axhline(y=0.95, color='r', linestyle='--', label='95% Explained Variance')\n",
2024-07-29 12:01:01 +00:00
"axes[1, 1].set_title('QuantileTransformer Cumulative Variance Plot')\n",
"axes[1, 1].set_xlabel('Number of components')\n",
"axes[1, 1].set_ylabel('Cumulative explained variance')\n",
"axes[1, 1].legend()\n",
"\n",
"# Adjust layout and display the plots\n",
"plt.tight_layout()\n",
2024-07-29 12:01:01 +00:00
"plt.show()\n",
"\n",
"# Print the number of components needed to explain 95% of the variance\n",
"n_components_std = np.argmax(cumulative_variance_std >= 0.95) + 1\n",
"n_components_qt = np.argmax(cumulative_variance_qt >= 0.95) + 1\n",
"\n",
"print(f\"Number of components needed to explain 95% of variance (StandardScaler): {n_components_std}\")\n",
"print(f\"Number of components needed to explain 95% of variance (QuantileTransformer): {n_components_qt}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "-n1_yRcaGYZD",
2024-07-29 12:01:01 +00:00
"outputId": "0a4bff53-25a4-4ca1-8030-7ee5d94ccb50"
},
2024-07-29 12:01:01 +00:00
"execution_count": 27,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 2000x1200 with 4 Axes>"
],
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8YAAASlCAYAAADaj7M5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1fv38c8mpJBGC0kghIQOAaQK0lHAKAKCSkdCaCoEkAgCKh2NiFL0izTpoiBSLICUCEhHadKFQGhCaNICBEzm+YMn+2NJYQNZEjfv13XtJXvmzDlnZpO497lnzpgMwzAEAAAAAAAAAAAAAICdcsjsAQAAAAAAAAAAAAAAYEskxgEAAAAAAAAAAAAAdo3EOAAAAAAAAAAAAADArpEYBwAAAAAAAAAAAADYNRLjAAAAAAAAAAAAAAC7RmIcAAAAAAAAAAAAAGDXSIwDAAAAAAAAAAAAAOwaiXEAAAAAAAAAAAAAgF0jMQ4AAAAAAAAAAAAAsGskxgEAeAxBQUHq1KnTE+mrU6dOCgoKeiJ9pWXdunUymUxat25dZg/FbnGOAQAAANir+vXrq379+ub3MTExMplMmjVrVqaN6Uk6cuSInn/+eeXKlUsmk0lLly7N7CHBRkwmk4YNG5bZwwAA3IfEOAAgy9m7d69ee+01BQYGytXVVf7+/mrUqJG++OILi3offfRRtg8gb9y4oaFDh6pcuXJyd3dXvnz5VLFiRfXp00d///13Zg8vQ8TExCgsLEzFihWTq6ur/Pz8VLduXQ0dOjSzh5Zus2bNkslkMr9cXV1VsmRJhYeHKzY2NkP6WL58OYE3AAAAYKf279+vDh06yN/fXy4uLipYsKA6dOigAwcOZPbQLBw4cEDDhg1TTEyMzfoICgqyiK9Se2W1hHtoaKj27t2rDz/8UHPnzlXVqlUze0g2Z09zF0kXciS9HB0dVbhwYbVo0UK7d+/OkD6exO8PAGRXOTJ7AAAA3G/z5s169tlnVbhwYXXr1k1+fn46deqUtm7dqgkTJqhXr17muh999JFee+01NW/ePPMGnInu3r2runXr6tChQwoNDVWvXr1048YN7d+/X998841atGihggULZvYwH8vRo0f19NNPK2fOnOrcubOCgoJ09uxZ7dy5U6NHj9bw4cMze4iPZMSIESpSpIhu376tjRs3atKkSVq+fLn27dsnNze3x2p7+fLlmjhxIslxAAAAwM4sXrxYbdu2Vd68edWlSxcVKVJEMTExmj59ur7//nstWLBAL7/8cmYPU9K9xN7w4cNVv379ZCufrVq1KkP6GD9+vG7cuGF+v3z5cn377bcaN26cvL29zeU1a9bMkP4ywq1bt7Rlyxa9//77Cg8Pz+zhPBH2OnfRtm1bNW7cWAkJCTp48KAmTZqkFStWaOvWrapYseJjtZ3W7w8A4PGQGAcAZCkffvihcuXKpd9//125c+e22Hb+/PnMGdQTcvv2bTk7O8vBwboFXZYuXapdu3Zp3rx5ateuXbK27ty5Y4thZri4uDi5u7unuG3cuHG6ceOGdu/ercDAQIttGfXzcPPmzcdORqfXiy++aL4roGvXrsqXL5/Gjh2rH374QW3btn2iYwEAAACQ9UVHR+v1119X0aJF9dtvvyl//vzmbX369FGdOnXUoUMH/fnnnypSpEgmjvThnJ2dM6SdBy+SP3funL799ls1b948zWRiWjGorV24cEGSks13PI7MPJ4kac1nPIm5i8w4B5UrV1aHDh3M72vVqqVmzZpp0qRJmjJlyhMdCwDAeiylDgDIUqKjo1W2bNkUg0QfHx/zv00mk+Li4jR79mzz8lVJz/o+ceKEevTooVKlSilnzpzKly+fWrZsmWwJqqRlrTdt2qSIiAjlz59f7u7uatGihTlYTWIYhkaNGqVChQrJzc1Nzz77rPbv359sjJcvX1a/fv1Uvnx5eXh4yMvLSy+++KL27NljUS/pGdLz58/XBx
98IH9/f7m5uenatWuS7gWO5cqVk6urq8qVK6clS5akeK6ke8HXg1xdXeXl5WVRdujQIbVq1Ur58+dXzpw5VapUKb3//vvm7daet9Rs27ZNL7zwgnLlyiU3NzfVq1dPmzZtsqgzbNgwmUwmHThwQO3atVOePHlUu3btVNuMjo5WoUKFkiXFJcufhyQrVqxQvXr15OnpKS8vLz399NP65ptvzNvr16+vcuXKaceOHapbt67c3Nz03nvvSZLi4+M1dOhQFS9eXC4uLgoICNC7776r+Pj4ZP18/fXXqlKlinLmzKm8efOqTZs2OnXqlFXnKSXPPfecJOn48eNp1lu4cKG5X29vb3Xo0EFnzpwxb+/UqZMmTpwoSRZLuwEAAAD4bxszZoxu3rypqVOnWiTFJcnb21tTpkzRjRs3NGbMGHN5p06dUkwQJ8Vl95s5c6aee+45+fj4yMXFRcHBwZo0aVKyfYOCgtSkSRNt3LhR1apVk6urq4oWLao5c+aY68yaNUstW7aUJD377LPmuGTdunWSkj9jPDWHDh3Sa6+9prx588rV1VVVq1bVjz/++ND97tepUyd5eHgoOjpajRs3lqenp9q3by9J2rBhg1q2bKnChQubY8C+ffvq1q1bKbZx5swZNW/eXB4eHsqfP7/69eunhIQEi7rz589XlSpVzDFp+fLlNWHCBEn3zntSbNu/f3+ZTCaLz2fXrl168cUX5eXlJQ8PDzVo0EBbt261aD9pHmP9+vXq0aOHfHx8VKhQIfN5LVeunP7880/Vq1dPbm5uKl68uL7//ntJ0vr161W9enXzfMCaNWuSna8zZ86oc+fO8vX1lYuLi8qWLasZM2ZY1HnYfMaDMnru4mHzCtbG69bMYaSHtXH9wz7nh/3+AAAeD3eMAwCylMDAQG3ZskX79u1TuXLlUq03d+5cde3aVdWqVVP37t0lScWKFZMk/f7779q8ebPatGmjQoUKKSYmRpMmTVL9+vV14MCBZHcH9+rVS3ny5NHQoUMVExOj8ePHKzw8XAsWLDDXGTJkiEaNGqXGjRurcePG2rlzp55//vlkVzYfO3ZMS5cuVcuWLVWkSBHFxsZqypQpqlevng4cOJBsebCRI0fK2dlZ/fr1U3x8vJydnbVq1Sq9+uqrCg4OVmRkpC5duqSwsDBzsHv/uZKkOXPm6IMPPkgzAfrnn3+qTp06cnJyUvfu3RUUFKTo6Gj99NNP+vDDDx/pvN3v119/1YsvvqgqVapo6NChcnBwME+sbNiwQdWqVbOo37JlS5UoUUIfffSRDMNItd3AwECtWbNGv/76qznITM2sWbPUuXNnlS1bVoMGDVLu3Lm1a9cu/fLLLxZXpV+6dEkvvvii2rRpow4dOsjX11eJiYlq1qyZNm7cqO7du6tMmTLau3evxo0bp7/++sviWfYffvihBg8erFatWqlr1666cOGCvvjiC9WtW1e7du16pCv/kyYK8uXLl+bxhYWF6emnn1ZkZKRiY2M1YcIEbdq0ydzvG2+8ob///lurV6/W3Llz0z0OAAAAAFnTTz/9pKCgINWpUyfF7XXr1lVQUJB++uknffnll+luf9KkSSpbtqyaNWumHDly6KefflKPHj2UmJionj17WtQ9evSoXnvtNXXp0kWhoaGaMWOGOnXqpCpVqqhs2bKqW7euevfurc8//1zvvfeeypQpI0nm/1pj//79qlWrlvz9/TVw4EC5u7vru+++U/PmzbVo0SK1aNHC6rb+/fdfhYSEqHbt2vr000/Nse3ChQt18+ZNvfXWW8qXL5+2b9+uL774QqdPn9bChQst2khISFBISIiqV6+uTz/9VGvWrNFnn32mYsWK6a233pIkrV69Wm3btlWDBg00evRoSdLBgwe1adMm9enTR6+88opy586tvn37mpfh9vDwMB9vnTp15OXlpXfffVdOTk6aMmWK6tevb05o369Hjx7Knz+/hgwZori4OHP5P//8oyZNmqhNmzZq2b
KlJk2apDZt2mjevHl6++239eabb6pdu3YaM2aMXnvtNZ06dUqenp6SpNjYWD3zzDMymUwKDw9X/vz5tWLFC
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of components needed to explain 95% of variance (StandardScaler): 16\n",
"Number of components needed to explain 95% of variance (QuantileTransformer): 44\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# Quantile Transformer vs. Standard Scaler for feature scaling"
],
"metadata": {
"id": "kgdy9N8lhEaF"
}
},
2024-07-29 12:01:01 +00:00
{
"cell_type": "code",
"source": [
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler, QuantileTransformer\n",
"\n",
"# X is the feature matrix (the Linformer vectors)\n",
"X = linformer_vectors\n",
"\n",
"# Scale with StandardScaler, then fit a PCA that keeps 95% of the variance\n",
"scaler_std = StandardScaler()\n",
"X_scaled_std = scaler_std.fit_transform(X)\n",
"pca_std = PCA(n_components=0.95)\n",
"X_pca_std = pca_std.fit_transform(X_scaled_std)\n",
"explained_variance_ratio_std = pca_std.explained_variance_ratio_\n",
"cumulative_variance_std = np.cumsum(explained_variance_ratio_std)\n",
"\n",
"# Scale with QuantileTransformer, then fit a PCA that keeps 95% of the variance\n",
"scaler_qt = QuantileTransformer(n_quantiles=1000, output_distribution='normal')\n",
"X_scaled_qt = scaler_qt.fit_transform(X)\n",
"pca_qt = PCA(n_components=0.95)\n",
"X_pca_qt = pca_qt.fit_transform(X_scaled_qt)\n",
"explained_variance_ratio_qt = pca_qt.explained_variance_ratio_\n",
"cumulative_variance_qt = np.cumsum(explained_variance_ratio_qt)\n",
"\n",
"# Create subplots\n",
"fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(20, 12))\n",
"\n",
"# Scree plot StandardScaler\n",
"axes[0, 0].bar(range(1, len(explained_variance_ratio_std) + 1), explained_variance_ratio_std, alpha=0.6, color='g', label='StandardScaler: Individual explained variance')\n",
"axes[0, 0].set_title('StandardScaler Scree Plot')\n",
"axes[0, 0].set_xlabel('Principal components')\n",
"axes[0, 0].set_ylabel('Explained variance ratio')\n",
"axes[0, 0].legend()\n",
"\n",
"# Cumulative Variance plot StandardScaler\n",
"axes[1, 0].plot(range(1, len(cumulative_variance_std) + 1), cumulative_variance_std, marker='o', linestyle='-', color='b', label='StandardScaler: Cumulative explained variance')\n",
"axes[1, 0].axhline(y=0.95, color='r', linestyle='--', label='95% Explained Variance')\n",
"axes[1, 0].set_title('StandardScaler Cumulative Variance Plot')\n",
"axes[1, 0].set_xlabel('Number of components')\n",
"axes[1, 0].set_ylabel('Cumulative explained variance')\n",
"axes[1, 0].legend()\n",
"\n",
"# Scree plot QuantileTransformer\n",
"axes[0, 1].bar(range(1, len(explained_variance_ratio_qt) + 1), explained_variance_ratio_qt, alpha=0.6, color='g', label='QuantileTransformer: Individual explained variance')\n",
"axes[0, 1].set_title('QuantileTransformer Scree Plot')\n",
"axes[0, 1].set_xlabel('Principal components')\n",
"axes[0, 1].set_ylabel('Explained variance ratio')\n",
"axes[0, 1].legend()\n",
"\n",
"# Cumulative Variance plot QuantileTransformer\n",
"axes[1, 1].plot(range(1, len(cumulative_variance_qt) + 1), cumulative_variance_qt, marker='o', linestyle='-', color='b', label='QuantileTransformer: Cumulative explained variance')\n",
"axes[1, 1].axhline(y=0.95, color='r', linestyle='--', label='95% Explained Variance')\n",
"axes[1, 1].set_title('QuantileTransformer Cumulative Variance Plot')\n",
"axes[1, 1].set_xlabel('Number of components')\n",
"axes[1, 1].set_ylabel('Cumulative explained variance')\n",
"axes[1, 1].legend()\n",
"\n",
"# Adjust layout and display the plots\n",
"plt.tight_layout()\n",
"plt.show()\n",
"\n",
"# Print the number of components needed to explain 95% of the variance\n",
"n_components_std = np.argmax(cumulative_variance_std >= 0.95) + 1\n",
"n_components_qt = np.argmax(cumulative_variance_qt >= 0.95) + 1\n",
"\n",
"print(f\"Number of components needed to explain 95% of variance (StandardScaler): {n_components_std}\")\n",
"print(f\"Number of components needed to explain 95% of variance (QuantileTransformer): {n_components_qt}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 702
},
"id": "dVlcnu0YSiIT",
"outputId": "26e305a3-a0fe-44d2-886c-fa49b8afc4de"
},
"execution_count": 28,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": [
"<Figure size 2000x1200 with 4 Axes>"
],
2024-07-29 12:01:01 +00:00
"image/png": "iVBORw0KGgoAAAANSUhEUgAAB8YAAASlCAYAAADaj7M5AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1fv38c8mpJBGC0kghIQOAaQK0lHAKAKCSkdCaCoEkAgCKh2NiFL0izTpoiBSLICUCEhHadKFQGhCaNICBEzm+YMn+2NJYQNZEjfv13XtJXvmzDlnZpO497lnzpgMwzAEAAAAAAAAAAAAAICdcsjsAQAAAAAAAAAAAAAAYEskxgEAAAAAAAAAAAAAdo3EOAAAAAAAAAAAAADArpEYBwAAAAAAAAAAAADYNRLjAAAAAAAAAAAAAAC7RmIcAAAAAAAAAAAAAGDXSIwDAAAAAAAAAAAAAOwaiXEAAAAAAAAAAAAAgF0jMQ4AAAAAAAAAAAAAsGskxgEAeAxBQUHq1KnTE+mrU6dOCgoKeiJ9pWXdunUymUxat25dZg/FbnGOAQAAANir+vXrq379+ub3MTExMplMmjVrVqaN6Uk6cuSInn/+eeXKlUsmk0lLly7N7CHBRkwmk4YNG5bZwwAA3IfEOAAgy9m7d69ee+01BQYGytXVVf7+/mrUqJG++OILi3offfRRtg8gb9y4oaFDh6pcuXJyd3dXvnz5VLFiRfXp00d///13Zg8vQ8TExCgsLEzFihWTq6ur/Pz8VLduXQ0dOjSzh5Zus2bNkslkMr9cXV1VsmRJhYeHKzY2NkP6WL58OYE3AAAAYKf279+vDh06yN/fXy4uLipYsKA6dOigAwcOZPbQLBw4cEDDhg1TTEyMzfoICgqyiK9Se2W1hHtoaKj27t2rDz/8UHPnzlXVqlUze0g2Z09zF0kXciS9HB0dVbhwYbVo0UK7d+/OkD6exO8PAGRXOTJ7AAAA3G/z5s169tlnVbhwYXXr1k1+fn46deqUtm7dqgkTJqhXr17muh999JFee+01NW/ePPMGnInu3r2runXr6tChQwoNDVWvXr1048YN7d+/X998841atGihggULZvYwH8vRo0f19NNPK2fOnOrcubOCgoJ09uxZ7dy5U6NHj9bw4cMze4iPZMSIESpSpIhu376tjRs3atKkSVq+fLn27dsnNze3x2p7+fLlmjhxIslxAAAAwM4sXrxYbdu2Vd68edWlSxcVKVJEMTExmj59ur7//nstWLBAL7/8cmYPU9K9xN7w4cNVv379ZCufrVq1KkP6GD9+vG7cuGF+v3z5cn377bcaN26cvL29zeU1a9bMkP4ywq1bt7Rlyxa9//77Cg8Pz+zhPBH2OnfRtm1bNW7cWAkJCTp48KAmTZqkFStWaOvWrapYseJjtZ3W7w8A4PGQGAcAZCkffvihcuXKpd9//125c+e22Hb+/PnMGdQTcvv2bTk7O8vBwboFXZYuXapdu3Zp3rx5ateuXbK27ty5Y4thZri4uDi5u7unuG3cuHG6ceOGdu/ercDAQIttGfXzcPPmzcdORqfXiy++aL4roGvXrsqXL5/Gjh2rH374QW3btn2iYwEAAACQ9UVHR+v1119X0aJF9dtvvyl//vzmbX369FGdOnXUoUMH/fnnnypSpEgmjvThnJ2dM6SdBy+SP3funL799ls1b948zWRiWjGorV24cEGSks13PI7MPJ4kac1nPIm5i8w4B5UrV1aHDh3M72vVqqVmzZpp0qRJmjJlyhMdCwDAeiylDgDIUqKjo1W2bNkUg0QfHx/zv00mk+Li4jR79mzz8lVJz/o+ceKEevTooVKlSilnzpzKly+fWrZsmWwJqqRlrTdt2qSIiAjlz59f7u7uatGihTlYTWIYhkaNGqVChQrJzc1Nzz77rPbv359sjJcvX1a/fv1Uvnx5eXh4yMvLSy+++KL27NljUS/pGdLz58/XBx
98IH9/f7m5uenatWuS7gWO5cqVk6urq8qVK6clS5akeK6ke8HXg1xdXeXl5WVRdujQIbVq1Ur58+dXzpw5VapUKb3//vvm7daet9Rs27ZNL7zwgnLlyiU3NzfVq1dPmzZtsqgzbNgwmUwmHThwQO3atVOePHlUu3btVNuMjo5WoUKFkiXFJcufhyQrVqxQvXr15OnpKS8vLz399NP65ptvzNvr16+vcuXKaceOHapbt67c3Nz03nvvSZLi4+M1dOhQFS9eXC4uLgoICNC7776r+Pj4ZP18/fXXqlKlinLmzKm8efOqTZs2OnXqlFXnKSXPPfecJOn48eNp1lu4cKG5X29vb3Xo0EFnzpwxb+/UqZMmTpwoSRZLuwEAAAD4bxszZoxu3rypqVOnWiTFJcnb21tTpkzRjRs3NGbMGHN5p06dUkwQJ8Vl95s5c6aee+45+fj4yMXFRcHBwZo0aVKyfYOCgtSkSRNt3LhR1apVk6urq4oWLao5c+aY68yaNUstW7aUJD377LPmuGTdunWSkj9jPDWHDh3Sa6+9prx588rV1VVVq1bVjz/++ND97tepUyd5eHgoOjpajRs3lqenp9q3by9J2rBhg1q2bKnChQubY8C+ffvq1q1bKbZx5swZNW/eXB4eHsqfP7/69eunhIQEi7rz589XlSpVzDFp+fLlNWHCBEn3zntSbNu/f3+ZTCaLz2fXrl168cUX5eXlJQ8PDzVo0EBbt261aD9pHmP9+vXq0aOHfHx8VKhQIfN5LVeunP7880/Vq1dPbm5uKl68uL7//ntJ0vr161W9enXzfMCaNWuSna8zZ86oc+fO8vX1lYuLi8qWLasZM2ZY1HnYfMaDMnru4mHzCtbG69bMYaSHtXH9wz7nh/3+AAAeD3eMAwCylMDAQG3ZskX79u1TuXLlUq03d+5cde3aVdWqVVP37t0lScWKFZMk/f7779q8ebPatGmjQoUKKSYmRpMmTVL9+vV14MCBZHcH9+rVS3ny5NHQoUMVExOj8ePHKzw8XAsWLDDXGTJkiEaNGqXGjRurcePG2rlzp55//vlkVzYfO3ZMS5cuVcuWLVWkSBHFxsZqypQpqlevng4cOJBsebCRI0fK2dlZ/fr1U3x8vJydnbVq1Sq9+uqrCg4OVmRkpC5duqSwsDBzsHv/uZKkOXPm6IMPPkgzAfrnn3+qTp06cnJyUvfu3RUUFKTo6Gj99NNP+vDDDx/pvN3v119/1YsvvqgqVapo6NChcnBwME+sbNiwQdWqVbOo37JlS5UoUUIfffSRDMNItd3AwECtWbNGv/76qznITM2sWbPUuXNnlS1bVoMGDVLu3Lm1a9cu/fLLLxZXpV+6dEkvvvii2rRpow4dOsjX11eJiYlq1qyZNm7cqO7du6tMmTLau3evxo0bp7/++sviWfYffvihBg8erFatWqlr1666cOGCvvjiC9WtW1e7du16pCv/kyYK8uXLl+bxhYWF6emnn1ZkZKRiY2M1YcIEbdq0ydzvG2+8ob///lurV6/W3Llz0z0OAAAAAFnTTz/9pKCgINWpUyfF7XXr1lVQUJB++uknffnll+luf9KkSSpbtqyaNWumHDly6KefflKPHj2UmJionj17WtQ9evSoXnvtNXXp0kWhoaGaMWOGOnXqpCpVqqhs2bKqW7euevfurc8//1zvvfeeypQpI0nm/1pj//79qlWrlvz9/TVw4EC5u7vru+++U/PmzbVo0SK1aNHC6rb+/fdfhYSEqHbt2vr000/Nse3ChQt18+ZNvfXWW8qXL5+2b9+uL774QqdPn9bChQst2khISFBISIiqV6+uTz/9VGvWrNFnn32mYsWK6a233pIkrV69Wm3btlWDBg00evRoSdLBgwe1adMm9enTR6+88opy586tvn37mpfh9vDwMB9vnTp15OXlpXfffVdOTk6aMmWK6tevb05o369Hjx7Knz+/hgwZori4OHP5P//8oyZNmqhNmzZq2b
KlJk2apDZt2mjevHl6++239eabb6pdu3YaM2aMXnvtNZ06dUqenp6SpNjYWD3zzDMymUwKDw9X/vz5tWLFC
},
"metadata": {}
},
{
"output_type": "stream",
"name": "stdout",
"text": [
"Number of components needed to explain 95% of variance (StandardScaler): 16\n",
"Number of components needed to explain 95% of variance (QuantileTransformer): 44\n"
]
}
]
}
]
}