{ "cells": [ { "metadata": {}, "cell_type": "markdown", "source": [ "# Foundational code for TPOT\n", "\n", "This notebook explores the foundations of TPOT, a Genetic Programming (GP) library that automates the selection of a machine learning pipeline (model and hyperparameters) for a given dataset. It demonstrates the following foundational concepts:\n", "\n", "* Loading data from Elasticsearch\n", "* Preparing nested data for the data pipeline\n", "* Filtering out irrelevant information from traces (problem model)\n", "* Vectorizing text data with BERT (`bert-base-uncased`, English), treating the log messages as semi-natural language\n", "* Training and selecting a model with TPOT\n", "* Evaluating the model and exporting the pipeline\n", "* Visualizing the frequency of models tested by TPOT (GP internals)\n", "* Loading the trained model and making predictions (todo)" ], "id": "9090fc8231b5aa47" }, { "metadata": {}, "cell_type": "code", "source": [ "import requests\n", "import pandas as pd\n", "\n", "# Recursively flatten nested columns (dicts and lists) into a flat DataFrame\n", "def recursively_normalize(data):\n", "    df = pd.json_normalize(data)\n", "    while True:\n", "        nested_cols = [col for col in df.columns if isinstance(df[col].iloc[0], (dict, list))]\n", "        if not nested_cols:\n", "            break\n", "        for col in nested_cols:\n", "            if isinstance(df[col].iloc[0], list):\n", "                # One row per list element; reset the index so the join below aligns row by row\n", "                df = df.explode(col).reset_index(drop=True)\n", "            if isinstance(df[col].iloc[0], dict):\n", "                # Expand the dict values into separate columns\n", "                normalized = pd.json_normalize(df[col].tolist())\n", "                df = df.drop(columns=[col]).join(normalized)\n", "    return df\n", "\n", "# Fetch the next batch of results from the Elasticsearch SQL API using the cursor\n", "def fetch_next_batch(cursor):\n", "    response = requests.post(\n", "        f\"{base_url}/_sql?format=json\",\n", "        headers={\"Content-Type\": \"application/json\"},\n", "        json={\"cursor\": cursor}\n", "    ).json()\n", "    return response\n", "\n", "# Elasticsearch base 
URL\n", "base_url = \"http://192.168.20.106:9200\"\n", "# Index pattern (informational; the query below names one daily index explicitly)\n", "index = \"winlogbeat-*\"\n", "\n", "from datetime import datetime, timedelta\n", "\n", "# Calculate the current time and the time one hour ago (UTC)\n", "current_time = datetime.utcnow()\n", "one_hour_ago = current_time - timedelta(hours=1)\n", "\n", "# Format times in ISO 8601, as expected by Elasticsearch\n", "current_time_iso = current_time.strftime('%Y-%m-%dT%H:%M:%SZ')\n", "one_hour_ago_iso = one_hour_ago.strftime('%Y-%m-%dT%H:%M:%SZ')\n", "\n", "# SQL query with a time filter; the FROM clause targets the 2024.06.23 index,\n", "# so adjust it when querying a different day\n", "sql_query = f\"\"\"\n", "SELECT \"@timestamp\", host.hostname, host.ip, log.level, winlog.event_id, winlog.task, message\n", "FROM \"winlogbeat-7.10.0-2024.06.23-*\"\n", "WHERE host.hostname = 'win10'\n", "AND winlog.provider_name = 'Microsoft-Windows-Sysmon'\n", "AND \"@timestamp\" >= '{one_hour_ago_iso}'\n", "AND \"@timestamp\" <= '{current_time_iso}'\n", "\"\"\"\n", "\n", "# Initial SQL request; the response carries a cursor for fetching further pages\n", "initial_response = requests.post(\n", "    f\"{base_url}/_sql?format=json\",\n", "    headers={\"Content-Type\": \"application/json\"},\n", "    json={\n", "        \"query\": sql_query,\n", "        \"field_multi_value_leniency\": True\n", "    }\n", ").json()\n", "\n", "# Extract the cursor, the first page of rows, and the column names\n", "cursor = initial_response.get('cursor')\n", "rows = initial_response.get('rows') or []\n", "columns = [col['name'] for col in initial_response['columns']]\n", "\n", "# Create (or overwrite) the CSV and JSON Lines files with the first batch\n", "if rows:\n", "    df = pd.DataFrame(rows, columns=columns)\n", "    df = recursively_normalize(df.to_dict(orient='records'))\n", "    df.to_csv(\"lab_logs_blindtest_activity.csv\", mode='w', index=False, header=True)\n", "    # Start a fresh JSON file so that re-running the cell does not duplicate records\n", "    with open(\"lab_logs_blindtest_activity.json\", 'w') as file:\n", "        for line in df.to_json(orient='records', lines=True).splitlines():\n", "            file.write(line + '\\n')\n", "\n", "# Track total documents retrieved\n", "total_documents_retrieved = len(rows)\n", "print(f\"Retrieved {total_documents_retrieved} documents.\")\n", "\n", "# Fetch subsequent batches until the cursor is exhausted\n", "while cursor:\n", "    # Fetch next 
batch of documents using the cursor\n", "    response = fetch_next_batch(cursor)\n", "\n", "    # Update the cursor for the next batch\n", "    cursor = response.get('cursor')\n", "    rows = response.get('rows')\n", "\n", "    # Stop when a batch comes back empty\n", "    if not rows:\n", "        break\n", "\n", "    # Normalize the batch\n", "    df = pd.DataFrame(rows, columns=columns)\n", "    df = recursively_normalize(df.to_dict(orient='records'))\n", "\n", "    # Append to the CSV file without headers\n", "    df.to_csv(\"lab_logs_blindtest_activity.csv\", mode='a', index=False, header=False)\n", "\n", "    # Convert the batch to JSON Lines and append it to the JSON file\n", "    json_lines = df.to_json(orient='records', lines=True).splitlines()\n", "    with open(\"lab_logs_blindtest_activity.json\", 'a') as file:\n", "        for line in json_lines:\n", "            file.write(line + '\\n')  # one JSON record per line\n", "\n", "    # Update total documents retrieved\n", "    total_documents_retrieved += len(rows)\n", "    print(f\"Retrieved {total_documents_retrieved} documents.\")\n", "\n", "print(\"Files have been written.\")" ], "id": "initial_id", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Load data from a CSV file\n", "\n", "Load the data from the CSV file into a DataFrame using Polars, a fast DataFrame library written in Rust. 
This step is necessary to prepare the data for further processing and filtering.\n" ], "id": "7dc4287c4b67a923" }, { "metadata": {}, "cell_type": "code", "source": [ "import polars as pl\n", "\n", "# Define the path to your CSV file\n", "csv_file_path = 'lab_logs_blindtest_activity.csv'\n", "\n", "# Load the CSV file into a DataFrame\n", "df = pl.read_csv(csv_file_path)\n", "\n", "# Show the DataFrame to confirm it's loaded correctly\n", "print(df)\n" ], "id": "847862813f6a8c74", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Data filtering and transformation\n", "\n", "Filter out irrelevant information from the traces to focus on the key details. Lines that start with one of the selected field names followed by a colon (for example `UtcTime:` or `Hashes:`) are removed; `User` and `SourceHostname` lines are shortened instead of removed. The goal is to reduce noise in the messages and make them more manageable for further processing." ], "id": "6fb9c9c06da8a061" }, { "metadata": {}, "cell_type": "code", "source": [ "def remove_keyword_lines(batch, keywords):\n", "    def modify_line(line):\n", "        # A keyword matches only at the start of the line, followed by a colon\n", "        for keyword in keywords:\n", "            if line.startswith(f\"{keyword}:\"):\n", "                if keyword == 'User':\n", "                    # Keep only the account name after the domain backslash\n", "                    parts = line.split('\\\\')\n", "                    if len(parts) > 1:\n", "                        return f\"User: {parts[1]}\"\n", "                elif keyword == 'SourceHostname':\n", "                    # Keep only the host name before the first dot (drop the domain suffix)\n", "                    value = line.split(':', 1)[1].strip()\n", "                    return f\"{keyword}: {value.split('.')[0]}\"\n", "                return None  # Remove the line (other keywords, and User lines without a backslash)\n", "        return line  # Keep lines that match no keyword\n", "\n", "    # Apply modify_line to every line of every message in the batch\n", "    return batch.map_elements(lambda message: '\\n'.join(\n", "        filter(None, (modify_line(line) for line in 
message.split('\\n')))), \n", "        return_dtype=pl.Utf8)\n", "\n", "\n", "# Keywords whose lines are removed or shortened\n", "keywords_to_filter = [\"UtcTime\", \"SourceProcessGUID\", \"ProcessGuid\", \"TargetProcessGUID\", \"TargetObject\", \"FileVersion\", \"Hashes\", \"LogonGuid\", \"LogonId\", \"CreationUtcTime\", \"User\", \"ParentProcessGuid\", \"SourceHostname\"]\n", "\n", "\n", "# Apply the transformation to the 'message' column using map_batches\n", "df_f = df.with_columns(\n", "    pl.col(\"message\").map_batches(lambda batch: remove_keyword_lines(batch, keywords_to_filter), return_dtype=pl.Utf8).alias(\"filtered_message\")\n", ")\n", "\n", "# Fetch the first 200 messages from the 'filtered_message' column\n", "first_messages = df_f[\"filtered_message\"].head(200)\n", "\n", "# Print each message completely\n", "for i, message in enumerate(first_messages):\n", "    print(f\"Message {i+1}:\")\n", "    print(message)\n", "    print(\"-\" * 50)  # Separator for readability\n" ], "id": "fc93fe038bcb00c5", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Select specific columns and write to a CSV file\n", "\n", "This is a data reduction step: only the columns needed for further processing are kept, and the result is written to a new CSV file for use in subsequent steps." 
], "id": "fa298e1c9d0999bd" }, { "metadata": {}, "cell_type": "code", "source": [ "# df_f is the DataFrame from the previous step, including the 'filtered_message' column\n", "# Keep only the columns needed for the following steps\n", "selected_columns_df = df_f.select([\"log.level\", \"winlog.event_id\", \"winlog.task\", \"filtered_message\"])\n", "\n", "# Write the selected columns to a CSV file\n", "selected_columns_df.write_csv('lab_logs_blindtest_activity_filtered.csv')\n" ], "id": "ff54936e81a933fd", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "selected_columns_df.head(5)", "id": "da3c38ca8c474ba", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "## Indexing and inserting a new column\n", "\n", "The following code numbers the events (0, 1, 2, ...) and inserts these numbers as the first column, named `index`. The index gives every event a stable identifier: the labeling step refers to a specific event by its index, and individual events can be looked up by it later." ], "id": "b5eb69ab1b69523f" }, { "metadata": {}, "cell_type": "code", "source": [ "# Create a Series with the row numbers 0..n-1\n", "index_series = pl.Series(\"index\", range(selected_columns_df.height))\n", "\n", "# Insert the series as the first column (insert_column works in place and also returns the DataFrame)\n", "selected_columns_df = selected_columns_df.insert_column(0, index_series)\n", "\n", "# Write the DataFrame to a CSV file, including the new index column\n", "selected_columns_df.write_csv('lab_logs_blindtest_activity_filtered.csv')\n" ], "id": "35cd4cc645761608", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "## TPOT model training and evaluation\n", "\n", "The following steps train a TPOT classifier on the vectorized text data prepared in the previous steps, evaluate its performance, and export the best pipeline for future use." 
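, "\n", "\n", "A minimal sketch of the training step, assuming the classic TPOT 0.12 API (`TPOTClassifier` with `fit`, `score` and `export`) and the `vectorized_texts.parquet` file produced by the vectorization step below. The search budget (`generations`, `population_size`, `cv`), the F1 scoring, the split ratio and the output name `tpot_best_pipeline.py` are illustrative assumptions, not tuned values. With only a handful of `bad` events, `stratify=y` needs at least two of them, and cross-validation may warn about folds without a positive example.\n", "\n", "```python\n", "import numpy as np\n", "import pandas as pd\n", "from sklearn.model_selection import train_test_split\n", "from tpot import TPOTClassifier\n", "\n", "# Load the BERT vectors and labels saved earlier in the notebook\n", "df = pd.read_parquet('vectorized_texts.parquet')\n", "\n", "# Stack the per-event 768-dimensional vectors into a 2-D feature matrix\n", "X = np.vstack(df['text_vector'].to_numpy())\n", "# Encode 'bad' as the positive class (1) so that F1 measures malicious-event detection\n", "y = (df['label'] == 'bad').astype(int).to_numpy()\n", "\n", "# Stratify so the rare 'bad' events appear in both splits\n", "X_train, X_test, y_train, y_test = train_test_split(\n", "    X, y, test_size=0.25, stratify=y, random_state=42)\n", "\n", "# Small, illustrative search budget for a quick CPU run\n", "tpot = TPOTClassifier(generations=5, population_size=20, cv=3,\n", "                      scoring='f1', random_state=42, verbosity=2, n_jobs=-1)\n", "tpot.fit(X_train, y_train)\n", "print('Test F1:', tpot.score(X_test, y_test))\n", "\n", "# Write the best pipeline found as a standalone Python script\n", "tpot.export('tpot_best_pipeline.py')\n", "```"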
], "id": "2173f7e8f3ae63a9" }, { "metadata": {}, "cell_type": "markdown", "source": "### Install necessary libraries", "id": "2fbe4ebc4d9038a2" }, { "metadata": {}, "cell_type": "code", "source": "%conda install numpy scipy scikit-learn pandas joblib pytorch", "id": "b3f6a7f89fb1f92e", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "%pip install deap update_checker tqdm stopit xgboost transformers", "id": "47de32d351fad54f", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "%pip install tpot", "id": "737d462c559936e2", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Initialize TPOT for Genetic Programming on the CPU\n", "\n", "The following cell imports the libraries used in the rest of the notebook: Polars and `re` for feature extraction, `transformers` and PyTorch for BERT vectorization, and TPOT with scikit-learn utilities for training the classifier on the CPU." ], "id": "ddf2807e5c8a393b" }, { "metadata": {}, "cell_type": "code", "source": [ "import os\n", "\n", "# Allow duplicate OpenMP runtimes; works around 'OMP: Error #15' when several libraries load libiomp5\n", "os.environ[\"KMP_DUPLICATE_LIB_OK\"] = \"TRUE\"\n", "\n", "import polars as pl\n", "import re\n", "from transformers import BertTokenizer, BertModel\n", "import torch\n", "from tpot import TPOTClassifier\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder" ], "id": "ae96e41f08c7908b", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Building the feature vector\n", "\n", "Here the relevant features are extracted from the Sysmon traces: the executable named in the `Image` field, the executable named in the `TargetFilename` field, and the full filtered message text. The first two are used to label the events; the text is vectorized with BERT and used to train the classifier with TPOT." 
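, "\n", "\n", "As an alternative to the `map_elements` and `pl.Object` round trip used in the next cells, the same two fields can be extracted with Polars' native regular-expression support, which avoids calling a Python function for every row. A minimal sketch, assuming the `filtered_message` column from the previous steps; the two sample messages are made up for illustration:\n", "\n", "```python\n", "import polars as pl\n", "\n", "# Hypothetical sample messages; in the notebook, apply the expressions to selected_columns_df instead\n", "df = pl.DataFrame({'filtered_message': [\n", "    'File created:\\nImage: C:\\\\Program Files\\\\Microsoft Office\\\\EXCEL.EXE\\nTargetFilename: C:\\\\Users\\\\student\\\\payload.exe',\n", "    'Dns query:\\nQueryName: example.com',\n", "]})\n", "\n", "# (?i) makes the match case-insensitive, like re.IGNORECASE in extract_info;\n", "# rows without a match give null, which is replaced by an empty string\n", "df = df.with_columns(\n", "    pl.col('filtered_message').str.extract(r'(?i)Image: (.*?\\.exe)', 1).fill_null('').alias('image'),\n", "    pl.col('filtered_message').str.extract(r'(?i)TargetFilename: (.*?\\.exe)', 1).fill_null('').alias('target_filename'),\n", ")\n", "print(df.select('image', 'target_filename'))\n", "```\n", "\n", "Polars uses Rust regex syntax, which supports the lazy `.*?` and the `(?i)` flag used here but not every Python `re` feature (for example, look-around)."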
], "id": "33c422b756ff0d9b" }, { "metadata": {}, "cell_type": "code", "source": [ "# Extract relevant information using regular expressions\n", "def extract_info(text):\n", " image = re.search(r\"Image: (.*?\\.exe)\", text, re.IGNORECASE)\n", " target_filename = re.search(r\"TargetFilename: (.*?\\.exe)\", text, re.IGNORECASE)\n", " return {\n", " \"image\": image.group(1) if image else \"\",\n", " \"target_filename\": target_filename.group(1) if target_filename else \"\",\n", " \"text\": text\n", " }" ], "id": "5cecd995c579cd0f", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Apply extraction to the Polars DataFrame using map_elements\n", "selected_columns_df = selected_columns_df.with_columns(\n", " pl.col(\"filtered_message\").map_elements(lambda x: extract_info(x), return_dtype=pl.Object).alias(\"extracted_info\")\n", ")" ], "id": "c2f84d1d644f9111", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Extract fields from the extracted_info column using map_elements with return_dtype\n", "selected_columns_df = selected_columns_df.with_columns(\n", " pl.col(\"extracted_info\").map_elements(lambda x: x['image'], return_dtype=pl.Utf8).alias(\"image\"),\n", " pl.col(\"extracted_info\").map_elements(lambda x: x['target_filename'], return_dtype=pl.Utf8).alias(\"target_filename\"),\n", " pl.col(\"extracted_info\").map_elements(lambda x: x['text'], return_dtype=pl.Utf8).alias(\"text\")\n", ").drop(\"extracted_info\")" ], "id": "b4c8e805cdb9b634", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "print(selected_columns_df)", "id": "c700056897cc8dd8", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "#### Define the label based on conditions\n", "\n", "The following code defines the label based on specific conditions. 
An event is labeled `bad` if Excel (`EXCEL.EXE` in the `image` column) writes an executable (`.exe` in the `target_filename` column), or if it is the manually selected event with index 874; every other event is labeled `good`. These labels are the training targets for the TPOT classifier.\n", "\n", "This is a binary classification problem: each event receives exactly one of two labels (`good` or `bad`)." ], "id": "3df9414538271fdc" }, { "metadata": {}, "cell_type": "code", "source": [ "def define_label(row):\n", "    # Each rule evaluates to True for a malicious ('bad') event\n", "    rules = [\n", "        \"EXCEL.EXE\" in row['image'] and \".exe\" in row['target_filename'],  # Excel writing an .exe file\n", "        row['index'] == 874,  # manually selected event\n", "        # Add more rules here if needed\n", "    ]\n", "    return \"bad\" if any(rules) else \"good\"" ], "id": "8d21ff3214accd7a", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "# Apply the define_label function\n", "selected_columns_df = selected_columns_df.with_columns(\n", " pl.struct([\"index\", \"image\", \"target_filename\"]).map_elements(define_label, return_dtype=pl.Utf8).alias(\"label\")\n", ")" ], "id": "3017223325f75d03", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "print(selected_columns_df)", "id": "feac611ac2db9fb", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "bad_rows = selected_columns_df.filter(pl.col(\"label\") == \"bad\")\n", "print(bad_rows)" ], "id": "5d634a8db0b99c4", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "markdown", "source": [ "### Vectorizing the text data using BERT\n", "\n", "The following code vectorizes the `text` column with BERT (`bert-base-uncased`): each message is tokenized (truncated to BERT's 512-token limit), and the token embeddings of the last hidden layer are averaged into a single 768-dimensional vector. These vectors are the input features for the TPOT classifier." 
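, "\n", "\n", "Encoding about a thousand messages one at a time can be slow. A minimal sketch of batched encoding, assuming the same `bert-base-uncased` model; `vectorize_batch` and `batch_size=32` are illustrative names and values, not part of the original pipeline. In a padded batch the shorter messages contain padding tokens, so the mean is taken only over real tokens using the attention mask; for a single unpadded message this equals the plain mean used in the next cell.\n", "\n", "```python\n", "import numpy as np\n", "import torch\n", "from transformers import BertTokenizer, BertModel\n", "\n", "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "model = BertModel.from_pretrained('bert-base-uncased')\n", "\n", "def vectorize_batch(texts, batch_size=32):\n", "    # Returns an array of shape (len(texts), 768): one mean-pooled vector per text\n", "    vectors = []\n", "    with torch.no_grad():  # inference only, no gradients needed\n", "        for start in range(0, len(texts), batch_size):\n", "            batch = texts[start:start + batch_size]\n", "            inputs = tokenizer(batch, return_tensors='pt', truncation=True, padding=True)\n", "            hidden = model(**inputs).last_hidden_state  # (batch, tokens, 768)\n", "            mask = inputs['attention_mask'].unsqueeze(-1).float()  # (batch, tokens, 1)\n", "            # Average over real tokens only, so padding does not dilute the vector\n", "            pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1)\n", "            vectors.append(pooled.numpy())\n", "    return np.concatenate(vectors)\n", "\n", "# Example: vectors = vectorize_batch(selected_columns_df['text'].to_list())\n", "```"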
], "id": "a4697a39b64b182f" }, { "metadata": {}, "cell_type": "code", "source": [ "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "model = BertModel.from_pretrained('bert-base-uncased')\n", "\n", "def vectorize_text(text):\n", "    # Tokenize one message, truncated to BERT's 512-token limit\n", "    inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)\n", "    # No gradients are needed for feature extraction, which saves memory and time\n", "    with torch.no_grad():\n", "        outputs = model(**inputs)\n", "    # Average the token embeddings of the last hidden layer into one 768-dimensional vector\n", "    return outputs.last_hidden_state.mean(dim=1).numpy()\n", "\n", "# Apply vectorization to the Polars DataFrame using map_elements\n", "selected_columns_df = selected_columns_df.with_columns(\n", "    pl.col(\"text\").map_elements(lambda x: vectorize_text(x).flatten(), return_dtype=pl.Object).alias(\"text_vector\")\n", ")\n", "\n", "print(selected_columns_df)" ], "id": "9262f948e3361ee9", "outputs": [], "execution_count": null }, { "metadata": { "ExecuteTime": { "end_time": "2024-06-25T10:14:44.851973Z", "start_time": "2024-06-25T10:14:43.503508Z" } }, "cell_type": "code", "source": [ "from transformers import BertTokenizer\n", "\n", "# Load the tokenizer\n", "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n", "\n", "# Get the maximum number of tokens\n", "max_tokens = tokenizer.model_max_length\n", "print(\"Maximum number of tokens:\", max_tokens)" ], "id": "3fc605ccdcba223d", "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/mc/anaconda3/lib/python3.11/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. 
If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Maximum number of tokens: 512\n" ] } ], "execution_count": 10 }, { "metadata": {}, "cell_type": "code", "source": [ "df = selected_columns_df.to_pandas()\n", "\n", "# Save the Pandas DataFrame to a Parquet file\n", "df.to_parquet(\"vectorized_texts.parquet\")" ], "id": "91e007e2b208dc7f", "outputs": [], "execution_count": null }, { "metadata": { "ExecuteTime": { "end_time": "2024-06-25T10:00:03.308042Z", "start_time": "2024-06-25T10:00:01.192778Z" } }, "cell_type": "code", "source": [ "import pandas as pd\n", "# Load the DataFrame from the Parquet file\n", "loaded_df = pd.read_parquet(\"vectorized_texts.parquet\")\n", "\n", "# Verify the loaded DataFrame\n", "print(loaded_df)" ], "id": "48a10b20636b4a2d", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " index log.level winlog.event_id \\\n", "0 0 information 10 \n", "1 1 information 10 \n", "2 2 information 1 \n", "3 3 information 13 \n", "4 4 information 1 \n", "... ... ... ... \n", "1022 1022 information 1 \n", "1023 1023 information 10 \n", "1024 1024 information 1 \n", "1025 1025 information 22 \n", "1026 1026 information 1 \n", "\n", " winlog.task \\\n", "0 Process accessed (rule: ProcessAccess) \n", "1 Process accessed (rule: ProcessAccess) \n", "2 Process Create (rule: ProcessCreate) \n", "3 Registry value set (rule: RegistryEvent) \n", "4 Process Create (rule: ProcessCreate) \n", "... ... \n", "1022 Process Create (rule: ProcessCreate) \n", "1023 Process accessed (rule: ProcessAccess) \n", "1024 Process Create (rule: ProcessCreate) \n", "1025 Dns query (rule: DnsQuery) \n", "1026 Process Create (rule: ProcessCreate) \n", "\n", " filtered_message \\\n", "0 Process accessed:\\nRuleName: -\\nSourceProcessI... \n", "1 Process accessed:\\nRuleName: -\\nSourceProcessI... \n", "2 Process Create:\\nRuleName: -\\nProcessId: 5196\\... 
\n", "3 Registry value set:\\nRuleName: Tamper-Winlogon... \n", "4 Process Create:\\nRuleName: -\\nProcessId: 6140\\... \n", "... ... \n", "1022 Process Create:\\nRuleName: -\\nProcessId: 5312\\... \n", "1023 Process accessed:\\nRuleName: -\\nSourceProcessI... \n", "1024 Process Create:\\nRuleName: -\\nProcessId: 5000\\... \n", "1025 Dns query:\\nRuleName: -\\nProcessId: 9568\\nQuer... \n", "1026 Process Create:\\nRuleName: -\\nProcessId: 8728\\... \n", "\n", " image target_filename \\\n", "0 C:\\Windows\\system32\\svchost.exe \n", "1 C:\\Windows\\system32\\svchost.exe \n", "2 C:\\Windows\\servicing\\TrustedInstaller.exe \n", "3 C:\\Windows\\servicing\\TrustedInstaller.exe \n", "4 C:\\Windows\\WinSxS\\amd64_microsoft-windows-serv... \n", "... ... ... \n", "1022 C:\\Program Files (x86)\\Microsoft\\EdgeUpdate\\Mi... \n", "1023 C:\\Program Files (x86)\\Microsoft\\EdgeUpdate\\Mi... \n", "1024 C:\\Windows\\System32\\taskhostw.exe \n", "1025 \n", "1026 C:\\Program Files\\RUXIM\\PLUGScheduler.exe \n", "\n", " text label \\\n", "0 Process accessed:\\nRuleName: -\\nSourceProcessI... good \n", "1 Process accessed:\\nRuleName: -\\nSourceProcessI... good \n", "2 Process Create:\\nRuleName: -\\nProcessId: 5196\\... good \n", "3 Registry value set:\\nRuleName: Tamper-Winlogon... good \n", "4 Process Create:\\nRuleName: -\\nProcessId: 6140\\... good \n", "... ... ... \n", "1022 Process Create:\\nRuleName: -\\nProcessId: 5312\\... good \n", "1023 Process accessed:\\nRuleName: -\\nSourceProcessI... good \n", "1024 Process Create:\\nRuleName: -\\nProcessId: 5000\\... good \n", "1025 Dns query:\\nRuleName: -\\nProcessId: 9568\\nQuer... good \n", "1026 Process Create:\\nRuleName: -\\nProcessId: 8728\\... good \n", "\n", " text_vector \n", "0 [-0.32128870487213135, -0.008510575629770756, ... \n", "1 [-0.3122658133506775, -0.00911662820726633, 0.... \n", "2 [-0.3229663372039795, -0.0005048469174653292, ... \n", "3 [-0.21143896877765656, -0.12685905396938324, 0... 
\n", "4 [-0.3781927824020386, 0.12961240112781525, 0.4... \n", "... ... \n", "1022 [-0.3417365550994873, -0.07532583177089691, 0.... \n", "1023 [-0.2859322726726532, 0.0036172550171613693, 0... \n", "1024 [-0.3556979298591614, -0.038922905921936035, 0... \n", "1025 [-0.2601829469203949, -0.17018236219882965, 0.... \n", "1026 [-0.34428584575653076, -0.09368517994880676, 0... \n", "\n", "[1027 rows x 10 columns]\n" ] } ], "execution_count": 1 }, { "metadata": { "ExecuteTime": { "end_time": "2024-06-25T10:02:36.442233Z", "start_time": "2024-06-25T10:02:36.434607Z" } }, "cell_type": "code", "source": "print(loaded_df.iloc[833])", "id": "53ace25ac9c9a884", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "index 833\n", "log.level information\n", "winlog.event_id 1\n", "winlog.task Process Create (rule: ProcessCreate)\n", "filtered_message Process Create:\\nRuleName: -\\nProcessId: 7680\\...\n", "image C:\\Users\\student\\AppData\\Local\\Temp\\file.exe\n", "target_filename \n", "text Process Create:\\nRuleName: -\\nProcessId: 7680\\...\n", "label good\n", "text_vector [-0.26701870560646057, 0.045040227472782135, 0...\n", "Name: 833, dtype: object\n" ] } ], "execution_count": 2 }, { "metadata": { "ExecuteTime": { "end_time": "2024-06-25T10:09:28.090846Z", "start_time": "2024-06-25T10:09:28.083155Z" } }, "cell_type": "code", "source": [ "# Select row 833\n", "row_833 = loaded_df.iloc[833]\n", "\n", "# Open a file in write mode\n", "with open('output_row_833.txt', 'w') as file:\n", " # Iterate over each item in the row and write it to the file\n", " for column, value in row_833.items():\n", " file.write(f\"{column}: {value}\\n\")\n", "\n", "print(\"Row 833 has been printed to 'output_row_833.txt'\")" ], "id": "9d30ec0bbe695fa3", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Row 833 has been printed to 'output_row_833.txt'\n" ] } ], "execution_count": 9 }, { "metadata": { "ExecuteTime": { "end_time": "2024-06-25T10:09:05.919819Z", 
"start_time": "2024-06-25T10:09:05.913175Z" } }, "cell_type": "code", "source": [ "# Retrieve the 'text_vector' and print its content and length\n", "text_vector = row_833['text_vector']\n", "print(\"text_vector:\", text_vector)\n", "print(\"Number of elements in 'text_vector':\", len(text_vector))" ], "id": "541fd9246ace12b5", "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "text_vector: [-2.67018706e-01 4.50402275e-02 4.92597967e-01 -1.67466640e-01\n", " 5.49281299e-01 -9.46303830e-02 1.31402090e-02 2.45818749e-01\n", " -3.08580045e-02 -5.99840544e-02 -4.08833206e-01 -3.07607472e-01\n", " -2.50869066e-01 3.46208543e-01 6.52808100e-02 4.03921425e-01\n", " -1.08293675e-01 1.94034234e-01 -1.76812530e-01 1.86863258e-01\n", " 2.65023381e-01 -6.35761321e-02 -1.79275349e-01 4.30784822e-01\n", " 4.68610853e-01 -4.75924127e-02 -5.95399998e-02 -1.90991446e-01\n", " -4.59939897e-01 3.03420693e-01 2.12900817e-01 -1.35921072e-02\n", " 8.98968354e-02 -1.90729022e-01 5.30772842e-02 -1.82377368e-01\n", " -3.18339169e-02 -7.27270097e-02 1.21271454e-01 3.73404026e-01\n", " -2.88292766e-01 -7.65656829e-01 1.46813408e-01 -1.36955306e-01\n", " -8.59620795e-02 -5.23494661e-01 2.55036026e-01 2.86784172e-01\n", " 1.17207281e-01 1.32331714e-01 -5.52943230e-01 1.82576746e-01\n", " -1.92080103e-02 -1.37846798e-01 2.42283568e-01 5.64341307e-01\n", " 5.23214936e-01 -5.78589380e-01 -3.80595475e-01 -2.89306790e-01\n", " 1.57396853e-01 2.33618170e-01 5.47610037e-03 -2.83375513e-02\n", " -2.05104738e-01 3.46614942e-02 1.02004237e-01 6.59699738e-01\n", " -1.08752000e+00 -1.89360574e-01 -2.04591796e-01 2.27338776e-01\n", " 9.44029614e-02 2.02513877e-02 9.50051397e-02 2.69410044e-01\n", " -9.46513638e-02 3.40882629e-01 3.03654343e-01 3.72446954e-01\n", " -5.61144017e-03 2.12720528e-01 1.53695956e-01 4.43269908e-01\n", " 4.89115492e-02 5.01248091e-02 4.95022386e-02 -4.29542996e-02\n", " -6.52586877e-01 2.04600200e-01 -1.13314502e-01 -2.85417408e-01\n", " 6.96188398e-03 
2.27123648e-01 5.33699930e-01 2.62654662e-01\n", " -3.85321081e-01 -1.77746072e-01 -9.27561820e-02 -1.84348021e-02\n", " -2.73407519e-01 -2.92661160e-01 -6.63963780e-02 1.52780414e-01\n", " -6.08130038e-01 -2.49575093e-01 -1.89034671e-01 5.01903772e-01\n", " -5.29249646e-02 5.60552955e-01 2.33052485e-03 -1.14192143e-01\n", " -4.86897789e-02 -5.37705064e-01 2.11079076e-01 -1.91006884e-02\n", " -2.13935465e-01 1.59453765e-01 2.82036662e-01 -8.20130631e-02\n", " -1.14828445e-01 -1.57852583e-02 9.31767821e-02 8.90546679e-01\n", " 1.58424422e-01 3.17291170e-01 -2.32521594e-01 3.79392534e-01\n", " 1.57995850e-01 -8.01902786e-02 2.77487546e-01 6.63853526e-01\n", " 2.02542543e-01 1.92820743e-01 -2.98547029e-01 -6.23865426e-03\n", " -2.24696010e-01 -1.12392403e-01 -1.16663694e-01 -2.93985337e-01\n", " 1.10340931e-01 -6.40198588e-01 1.19767703e-01 3.31612915e-01\n", " -4.03458662e-02 -2.13825703e-01 -1.55758616e-02 1.88647553e-01\n", " 3.78850214e-02 2.62739778e-01 -3.14275064e-02 7.16081187e-02\n", " -1.59955516e-01 -2.33498782e-01 1.16682462e-01 7.07069188e-02\n", " 1.99481528e-02 1.67005315e-01 -2.67922860e-02 -1.51443243e-01\n", " 4.75167215e-01 2.98228055e-01 -4.16279286e-01 1.52619034e-01\n", " 2.29130775e-01 -2.76472181e-01 -9.07065049e-02 4.01566863e-01\n", " -3.44911754e-01 -2.54283696e-01 -2.69026697e-01 -6.32005274e-01\n", " 3.29785615e-01 3.28085631e-01 7.28618950e-02 4.50593531e-01\n", " 5.99252284e-01 -3.49201411e-01 6.84104681e-01 -9.08417925e-02\n", " -4.26395625e-01 6.78869784e-01 -1.14344638e-02 -3.80911008e-02\n", " 8.67788568e-02 6.36232719e-02 3.15140396e-01 -4.56063092e-01\n", " 4.09926146e-01 -3.85044441e-02 -3.96032304e-01 -5.23748934e-01\n", " 5.83351962e-02 6.17916696e-02 3.35887998e-01 -2.56321043e-01\n", " -3.37718010e-01 -3.04874420e-01 -1.84115171e-01 -9.61723551e-02\n", " 5.34990847e-01 1.52843669e-01 2.12507904e-01 -3.45869631e-01\n", " -2.10483983e-01 -1.62834480e-01 -2.61416376e-01 -6.48462832e-01\n", " -6.30479634e-01 -1.89883504e-02 
-5.69964230e-01 2.21523076e-01\n", " 3.93228568e-02 4.26025271e-01 1.52408615e-01 4.67580482e-02\n", " -1.33270025e-01 -4.19659883e-01 3.19966525e-01 -7.34397303e-03\n", " -3.26589853e-01 6.67044967e-02 -2.77049989e-01 5.33060491e-01\n", " -6.39602005e-01 9.96551275e-01 2.61471361e-01 -5.72765708e-01\n", " 6.08807981e-01 3.22332233e-01 -2.96456724e-01 -5.70633411e-01\n", " 2.50657015e-02 -4.21076864e-01 -2.39366159e-01 2.82484680e-01\n", " -4.70239632e-02 -5.96059561e-02 1.62214309e-01 -1.46467626e-01\n", " -2.85191298e-01 4.87600356e-01 7.80523941e-02 -2.40019396e-01\n", " -1.00024633e-01 -2.66046822e-01 -3.14182967e-01 -1.29386354e-02\n", " 1.45505860e-01 -4.24869448e-01 -3.58255208e-01 1.09599195e-01\n", " 1.08492278e-01 -4.42916781e-01 -8.82621296e-03 -3.69780749e-01\n", " -3.94908078e-02 -2.16823459e-01 -3.14186126e-01 3.29045177e-01\n", " 3.04894775e-01 -5.81190772e-02 1.92793369e-01 1.05493270e-01\n", " 4.00548466e-02 -5.53802609e-01 1.49122447e-01 4.65973318e-01\n", " 5.04793048e-01 1.90268219e-01 -3.01060498e-01 8.74963179e-02\n", " -1.38923228e-01 6.07473552e-01 -2.98727393e-01 -2.53881574e-01\n", " 3.15812156e-02 3.51313591e-01 -3.37296072e-03 -3.98383260e-01\n", " 1.90306306e-02 4.74248469e-01 3.50109525e-02 7.59407952e-02\n", " -1.66543007e-01 -1.51751995e-01 3.15138370e-01 -1.26559790e-02\n", " -3.93086821e-01 -3.61495733e-01 -2.07677290e-01 1.42010570e-01\n", " -3.46750170e-01 -1.39620225e-03 6.34464979e-01 1.79754898e-01\n", " 5.54298684e-02 7.12397397e-02 7.98927248e-02 9.68952402e-02\n", " -2.38104612e-01 4.86172810e-02 -1.19116485e-01 -1.46843791e-01\n", " -1.14842623e-01 2.60607451e-02 -2.00201645e-01 -3.53976756e-01\n", " -3.05691338e+00 -1.97098270e-01 -3.32592458e-01 -2.39055544e-01\n", " -8.48524347e-02 -3.63588184e-01 9.99084562e-02 3.65674160e-02\n", " -3.24371845e-01 2.53723919e-01 -2.17285275e-01 -1.56254023e-01\n", " 7.01330379e-02 -2.64307000e-02 6.06196374e-02 3.07962388e-01\n", " 5.35847187e-01 1.05104328e-03 -2.68761903e-01 
5.44907629e-01\n", " -3.44546527e-01 -7.85830259e-01 -6.53781444e-02 1.12628629e-02\n", " 7.69731581e-01 2.55240619e-01 -1.69711173e-01 -2.69975424e-01\n", " -4.25650209e-01 -2.45908737e-01 -8.45987573e-02 -3.33810121e-01\n", " -1.78606287e-01 4.63501394e-01 -8.45592767e-02 1.80356771e-01\n", " 5.39632104e-02 -2.70663440e-01 -2.72808224e-01 -6.24837756e-01\n", " 2.52348691e-01 -1.13212860e+00 -2.23641261e-01 -4.17149335e-01\n", " 7.78771996e-01 -2.98168570e-01 -6.18152581e-02 3.05032164e-01\n", " 1.84052348e-01 3.22197855e-01 -4.59894240e-01 2.41766665e-02\n", " -3.09486777e-01 2.71266922e-02 -4.36577909e-02 -1.54916435e-01\n", " 1.70625895e-01 6.71002865e-01 -5.52631319e-01 -1.30260691e-01\n", " 5.30620515e-01 -2.95754939e-01 -4.66853023e-01 -4.96686220e-01\n", " -2.56180793e-01 -2.70336747e-01 -8.99529755e-01 -3.69107187e-01\n", " -1.47936121e-01 3.16734612e-01 -2.55744904e-01 5.11437297e-01\n", " -5.53669691e-01 -2.04824597e-01 2.02070717e-02 -5.02038836e-01\n", " 1.57953084e-01 -5.00334650e-02 4.21860209e-03 -2.44256444e-02\n", " -3.80785346e-01 -2.30885297e-01 1.54570043e-01 -4.54314053e-02\n", " -4.31950688e-01 -2.24424489e-02 -2.09175229e-01 -3.37632835e-01\n", " -9.06778276e-02 -7.65353292e-02 4.85478938e-01 1.07486293e-01\n", " -1.91006269e-02 8.64723995e-02 2.12791160e-01 -1.12376571e-01\n", " 6.12910688e-01 -2.27594852e-01 2.41480902e-01 -5.89063823e-01\n", " 3.01331013e-01 -9.39759687e-02 5.01071572e-01 -4.15869176e-01\n", " -1.80203408e-01 7.34310299e-02 -6.65124536e-01 -5.82042448e-02\n", " 4.00682539e-01 1.73466071e-01 -3.30419913e-02 2.84719933e-02\n", " 4.09189403e-01 -7.00836182e-01 -3.14641923e-01 6.34882972e-02\n", " 4.62329119e-01 7.15641975e-01 3.61046731e-01 -6.83900356e-01\n", " -4.31078523e-01 4.07016009e-01 -3.25080842e-01 -4.50711042e-01\n", " -6.18534565e-01 3.97195667e-01 -4.49945852e-02 1.95479855e-01\n", " -3.71466905e-01 -2.61483312e-01 -5.66799462e-01 -2.68759340e-01\n", " -1.11281881e-02 1.69699028e-01 1.95249468e-01 
-3.53737548e-02\n", " 1.43572360e-01 -3.94385725e-01 -9.68108103e-02 7.33911842e-02\n", " 1.22984558e-01 1.95549816e-01 7.62628242e-02 -1.62749082e-01\n", " 4.00625952e-02 3.39616209e-01 1.87836900e-01 7.70032555e-02\n", " 9.84009579e-02 2.37510577e-01 -3.85231704e-01 -7.11839437e-01\n", " 6.21556379e-02 1.18853919e-01 -1.90843031e-01 3.52634341e-01\n", " 3.28800559e-01 -1.81894362e-01 1.32143423e-02 -5.42842984e-01\n", " 7.44588971e-01 -1.24051608e-01 9.95109901e-02 9.61814374e-02\n", " -1.69666228e-03 7.10986555e-01 5.19338906e-01 -6.60278574e-02\n", " -2.73238301e-01 -3.43201578e-01 -4.30085868e-01 2.00812265e-01\n", " -5.85207231e-02 -2.60092467e-01 -1.12541668e-01 5.20969391e-01\n", " -2.98785549e-02 -2.45256156e-01 1.54213458e-01 7.47456849e-01\n", " 2.11172149e-01 8.73897001e-02 -3.36751372e-01 -2.28837356e-01\n", " 2.90950000e-01 2.33124614e-01 -6.97619617e-02 -1.05824389e-01\n", " 2.29767874e-01 1.42880091e-02 -4.27708745e-01 3.25902164e-01\n", " 1.23169057e-01 -1.22042544e-01 1.31285295e-01 -1.56314909e-01\n", " 3.55436057e-01 5.85144050e-02 -6.84823021e-02 3.34698677e-01\n", " -2.22972199e-01 -4.21880148e-02 -5.25956213e-01 3.95876216e-03\n", " 2.57475451e-02 2.13791475e-01 1.46186620e-01 3.05620939e-01\n", " 1.72737658e-01 1.87649921e-01 -2.57006716e-02 -2.62150764e-01\n", " -2.20731944e-01 -3.92923653e-01 -4.10969891e-02 -3.94284487e-01\n", " 1.25443399e-01 -4.58293818e-02 -2.98196971e-01 1.29661441e-01\n", " -4.10887659e-01 1.30000532e-01 9.74752903e-02 9.46554095e-02\n", " -3.63030657e-02 -1.02588110e-01 -4.85813916e-02 1.79652736e-01\n", " -2.09566951e-01 -2.71597773e-01 1.71052456e-01 -6.08415604e-01\n", " 3.20643261e-02 2.64066100e-01 2.37517804e-01 1.44562602e-01\n", " -6.38284460e-02 -5.08399129e-01 7.96856508e-02 1.05304541e-02\n", " 7.58946314e-02 -1.82433039e-01 -4.07906443e-01 -5.48315167e-01\n", " 5.03587544e-01 -5.18710911e-02 -2.37072140e-01 3.64776939e-01\n", " -2.57527471e-01 3.79561543e-01 1.13509797e-01 -5.83635494e-02\n", " 
-1.32364985e-02 -2.53996015e-01 3.09146464e-01 -7.05221713e-01\n", " -1.50314182e-01 8.68031830e-02 -4.37532932e-01 1.45563170e-01\n", " -1.22643612e-01 9.08425450e-03 1.94618672e-01 2.91354567e-01\n", " 3.93040568e-01 5.08264005e-01 9.64119434e-02 4.05358881e-01\n", " 1.70785695e-01 -6.33388340e-01 -7.08788261e-02 -3.24917585e-01\n", " -1.15828209e-01 -3.88947241e-02 8.84301811e-02 -4.54051405e-01\n", " -5.41520268e-02 -2.07897663e-01 4.38057929e-01 -4.25215125e-01\n", " -2.27744356e-01 -6.45529926e-02 -4.73755524e-02 2.75590867e-01\n", " 3.33647728e-01 -3.95871699e-01 -5.84744290e-02 3.32015902e-01\n", " 7.79225752e-02 -1.45332381e-01 -6.35314807e-02 2.66312987e-01\n", " -3.48183624e-02 -8.77191201e-02 -4.48695242e-01 2.38731444e-01\n", " 8.37602377e-01 1.71276927e-01 2.29854379e-02 3.79543871e-01\n", " -1.59313351e-01 1.72441214e-01 3.91818672e-01 2.38760009e-01\n", " -3.35281645e-03 4.04623747e-01 1.93819061e-01 -3.36620361e-01\n", " -3.78950536e-01 -2.13610940e-03 -2.61642277e-01 -6.36941493e-01\n", " 5.27241886e-01 2.35063493e-01 -1.94114745e-01 -2.12649107e-01\n", " -1.77295357e-01 1.05155604e-02 -1.80615142e-01 -2.47896895e-01\n", " 7.88988639e-03 2.18618929e-01 3.94153148e-01 1.27877682e-01\n", " 2.46168748e-01 2.86875844e-01 -3.13660413e-01 -2.07360253e-01\n", " 6.35999367e-02 1.97521850e-01 -6.57110848e-03 2.90463027e-02\n", " -2.72324644e-02 6.83545291e-01 -1.96372122e-01 -2.69587375e-02\n", " -1.11118995e-01 3.03671286e-02 -3.13001186e-01 -1.17833935e-01\n", " 2.12221742e-01 3.11293155e-01 2.73206800e-01 1.63645178e-01\n", " 8.53751376e-02 4.16733742e-01 3.74694429e-02 -5.57099730e-02\n", " 6.29755437e-01 6.22547388e-01 6.87231794e-02 -1.88300475e-01\n", " -1.83759540e-01 -9.02508423e-02 -5.79310171e-02 3.23269188e-01\n", " 2.10681632e-01 3.39448228e-02 1.54202551e-01 7.15607107e-01\n", " 2.56150514e-01 8.09349120e-02 2.78974384e-01 -5.26742458e-01\n", " -3.03853303e-01 1.26324639e-01 8.42980564e-01 -8.46742019e-02\n", " 2.52615511e-02 9.13625360e-02 
4.60760072e-02 2.99819440e-01\n", " -6.11319542e-02 -1.79427490e-01 6.39528036e-02 -1.20934166e-01\n", " 6.33755550e-02 3.53766531e-01 -2.75289297e-01 -9.12263095e-02\n", " -3.77972484e-01 -2.97383815e-01 -1.29695993e-03 -2.98657745e-01\n", " -1.81566820e-01 -3.03760320e-01 5.15599661e-02 2.27060001e-02\n", " -6.88230246e-02 -5.33181429e-01 1.97003484e-01 -2.06039757e-01\n", " 1.46390632e-01 -2.36624375e-01 -8.99250656e-02 -7.15578049e-02\n", " -1.34127915e-01 3.78443688e-01 7.13742316e-01 4.28920053e-02\n", " 1.49822384e-01 -6.30574346e-01 -2.76538849e-01 -3.03445198e-02\n", " 2.55319506e-01 1.43449921e-02 8.51057097e-02 9.58990455e-02\n", " -3.84655952e-01 1.10321894e-01 -1.68833118e-02 3.01812626e-02\n", " -2.01801896e-01 -8.05913210e-02 -2.59703428e-01 3.84925336e-01\n", " 1.88098475e-01 1.00619376e-01 -5.08317411e-01 -3.72170620e-02\n", " -2.70777494e-01 -3.42400938e-01 1.07098818e-01 -2.33876809e-01\n", " 5.28669618e-02 -9.76444557e-02 -4.68501411e-02 2.59407789e-01\n", " 5.92825040e-02 -2.38227639e-02 2.69253582e-01 6.03045106e-01\n", " 1.63432300e-01 3.88628036e-01 -4.27469641e-01 -2.59980768e-01\n", " -4.64719757e-02 2.97374547e-01 -2.86296487e-01 3.56291205e-01\n", " -8.48512501e-02 1.18831411e-01 1.91253379e-01 2.94847518e-01\n", " -2.25561559e-02 -4.92864288e-03 -1.96359120e-02 -1.94969982e-01\n", " -4.33832109e-01 -5.63065827e-01 2.26254407e-02 3.34261507e-02\n", " 8.88495073e-02 -1.32661715e-01 -1.71310976e-01 -9.50060636e-02\n", " -3.76367658e-01 -5.01346588e-01 -1.32544145e-01 5.43285310e-01]\n", "Number of elements in 'text_vector': 768\n" ] } ], "execution_count": 8 }, { "metadata": {}, "cell_type": "code", "source": [ "import os\n", "import pandas as pd\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", "from tpot import TPOTClassifier\n", "\n", "# Load the DataFrame from the Parquet file\n", "df = pd.read_parquet(\"vectorized_texts.parquet\")\n", "\n", "# Ensure to 
use only CPU for PyTorch\n", "device = torch.device(\"cpu\")\n", "\n", "# Encode labels\n", "le = LabelEncoder()\n", "df['label_encoded'] = le.fit_transform(df['label'])\n", "\n", "# Split data\n", "X_train, X_test, y_train, y_test = train_test_split(df['text_vector'].tolist(), df['label_encoded'], test_size=0.2, random_state=42)\n", "\n", "# Convert lists to numpy arrays\n", "X_train = torch.tensor(X_train, device=device).numpy()\n", "X_test = torch.tensor(X_test, device=device).numpy()\n", "\n", "# Run TPOT's genetic search: 5 generations, 20 pipelines per population, full progress output (verbosity=3)\n", "tpot = TPOTClassifier(verbosity=3, generations=5, population_size=20)\n", "tpot.fit(X_train, y_train)\n", "\n", "# Evaluate the best pipeline on the held-out test set (accuracy by default)\n", "print(\"TPOT Score:\", tpot.score(X_test, y_test))\n", "\n", "# Export the best pipeline as Python source code (this does not save the fitted model)\n", "tpot.export('tpot_pipeline.py')\n", "\n", "# Print the exported pipeline\n", "with open('tpot_pipeline.py') as f:\n", "    print(f.read())\n", "\n", "# Predict with the best fitted pipeline\n", "predictions = tpot.predict(X_test)\n", "print(\"Predictions:\", predictions)\n" ], "id": "75d84e297b03eaf4", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "print(\"The accuracy of the best model is:\", tpot.score(X_test, y_test))\n", "id": "6cf76b5736411710", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "%pip install matplotlib", "id": "d99c8aa5529a72d1", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import os\n", "import pandas as pd\n", "import torch\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.preprocessing import LabelEncoder\n", "from tpot import TPOTClassifier\n", "from collections import Counter\n", "\n", "# Load the DataFrame from the Parquet file\n", "df = pd.read_parquet(\"vectorized_texts.parquet\")\n", "\n", "# Use only the CPU for PyTorch\n", "device = torch.device(\"cpu\")\n", "\n", "# Encode labels\n", "le = LabelEncoder()\n", 
"df['label_encoded'] = le.fit_transform(df['label'])\n", "\n", "# Split data\n", "X_train, X_test, y_train, y_test = train_test_split(df['text_vector'].tolist(), df['label_encoded'], test_size=0.2, random_state=42)\n", "\n", "# Convert lists to numpy arrays\n", "X_train = torch.tensor(X_train, device=device).numpy()\n", "X_test = torch.tensor(X_test, device=device).numpy()\n", "\n", "# Run TPOT's genetic search: 5 generations, 20 pipelines per population, full progress output (verbosity=3)\n", "tpot = TPOTClassifier(verbosity=3, generations=5, population_size=20)\n", "tpot.fit(X_train, y_train)\n", "\n", "# Evaluate the best pipeline on the held-out test set (accuracy by default)\n", "print(\"TPOT Score:\", tpot.score(X_test, y_test))\n", "\n", "# Export the best pipeline as Python source code (this does not save the fitted model)\n", "tpot.export('tpot_pipeline.py')\n", "\n", "# Print the exported pipeline\n", "with open('tpot_pipeline.py') as f:\n", "    print(f.read())\n", "\n", "# Predict with the best fitted pipeline\n", "predictions = tpot.predict(X_test)\n", "print(\"Predictions:\", predictions)\n", "\n", "# Every pipeline TPOT evaluated during the search, keyed by its pipeline string\n", "evaluated_pipelines = tpot.evaluated_individuals_\n" ], "id": "705690ce71dfda4c", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": [ "import re\n", "from collections import Counter\n", "import matplotlib.pyplot as plt\n", "\n", "# Count occurrences of each operator (model or preprocessor) across the evaluated pipelines.\n", "# Pipelines can be nested, e.g. 'Outer(Inner(input_matrix), ...)', so match every name followed by '('.\n", "model_counter = Counter()\n", "for pipeline_str in evaluated_pipelines.keys():\n", "    for model_name in re.findall(r'(\\w+)\\(', pipeline_str):\n", "        model_counter[model_name] += 1\n", "\n", "print(\"Models and their occurrences:\")\n", "for model, count in model_counter.items():\n", "    print(f\"{model}: {count}\")\n", "\n", "# Visualize the count of different models\n", "model_names = list(model_counter.keys())\n", "model_counts = list(model_counter.values())\n", "\n", "plt.figure(figsize=(12, 6))\n", "plt.barh(model_names, model_counts, color='skyblue')\n", "plt.xlabel('Number of Occurrences')\n", "plt.ylabel('Model')\n", "plt.title('Frequency of Models Tested by TPOT')\n", "plt.tight_layout()\n", 
"plt.show()" ], "id": "565066bf3b5f0820", "outputs": [], "execution_count": null }, { "metadata": {}, "cell_type": "code", "source": "", "id": "f6faa6d6265c094e", "outputs": [], "execution_count": null } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 5 }