{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Feature Selection using `chemml.optimization.GeneticAlgorithm`\n", "\n", "\n", "We use a sample dataset from the ChemML library that contains SMILES codes and Dragon molecular descriptors for 500 small organic molecules, together with their densities in $kg/m^3$. For simplicity, we perform feature selection with a genetic algorithm on a subset of 20 molecular descriptors.\n", "\n", "For more information on the genetic algorithm, please refer to our [paper](https://doi.org/10.26434/chemrxiv.9782387.v1)." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(500, 1) (500, 20)\n" ] } ], "source": [ "from chemml.datasets import load_organic_density\n", "_, density, features = load_organic_density()\n", "# Keep only the first 20 descriptor columns\n", "num_cols = 20\n", "col_names = features.iloc[:, :num_cols].columns\n", "features = features.iloc[:, :num_cols]\n", "print(density.shape, features.shape)\n", "density, features = density.values, features.values" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining hyperparameter space\n", "\n", "Each feature is encoded as one binary bit of the chromosome.\n", "\n", "0 indicates the feature is discarded.\n", "\n", "1 indicates the feature is selected." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "space = tuple([{i: {'choice': [0,1]}} for i in range(features.shape[1])])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Defining objective function\n", "The objective function receives one ‘individual’ of the genetic algorithm’s population, which is an ordered list of the hyperparameters defined in the `space` variable. Within the objective function, the user performs all the required calculations and returns the metric (as a tuple) to be optimized. 
If multiple metrics are returned, each one is optimized according to the `fitness` setting passed when initializing the `GeneticAlgorithm` class.\n", "\n", "Here, we use a simple linear regression model to fit the data." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import mean_absolute_error\n", "import pandas as pd\n", "from sklearn.linear_model import LinearRegression\n", "\n", "def obj(individual, features=features):\n", "    # Keep only the columns whose bit is 1 in the individual\n", "    df = pd.DataFrame(features)\n", "    new_cols = list(map(bool, individual))\n", "    df = df[df.columns[new_cols]]\n", "    features = df.values\n", "    # Fit on the first 400 molecules and score on the remaining 100\n", "    model = LinearRegression(n_jobs=1)\n", "    model.fit(features[:400], density[:400])\n", "    pred = model.predict(features[400:])\n", "    return mean_absolute_error(density[400:], pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Optimize the feature space" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "from chemml.optimization import GeneticAlgorithm\n", "ga = GeneticAlgorithm(evaluate=obj, space=space, fitness=(\"min\", ), crossover_type=\"Uniform\",\n", "                      pop_size=10, crossover_size=6, mutation_size=4, algorithm=3)\n", "fitness_df, final_best_features = ga.search(n_generations=20)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "`ga.search` returns:\n", "\n", "- a dataframe with the best individual of each generation, along with its fitness value and the time taken to evaluate the model\n", "\n", "- a dictionary containing the best individual (in this case, the selected features)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/html": [ "
<table>\n",
"  <thead>\n",
"    <tr><th></th><th>Best_individual</th><th>Fitness_values</th><th>Time (hours)</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>0</th><td>(1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, ...</td><td>14.231872</td><td>0.000518</td></tr>\n",
"    <tr><th>1</th><td>(1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, ...</td><td>13.511623</td><td>0.000494</td></tr>\n",
"    <tr><th>2</th><td>(1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ...</td><td>11.969715</td><td>0.000498</td></tr>\n",
"    <tr><th>3</th><td>(1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ...</td><td>11.969715</td><td>0.000511</td></tr>\n",
"    <tr><th>4</th><td>(1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ...</td><td>11.969715</td><td>0.000506</td></tr>\n",
"    <tr><th>5</th><td>(1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ...</td><td>11.969715</td><td>0.000495</td></tr>\n",
"    <tr><th>6</th><td>(1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ...</td><td>11.969715</td><td>0.000482</td></tr>\n",
"    <tr><th>7</th><td>(1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ...</td><td>11.626140</td><td>0.000487</td></tr>\n",
"    <tr><th>8</th><td>(1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ...</td><td>11.626140</td><td>0.000493</td></tr>\n",
"    <tr><th>9</th><td>(1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, ...</td><td>11.536057</td><td>0.000490</td></tr>\n",
"    <tr><th>10</th><td>(1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, ...</td><td>11.536057</td><td>0.000487</td></tr>\n",
"    <tr><th>11</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000486</td></tr>\n",
"    <tr><th>12</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000480</td></tr>\n",
"    <tr><th>13</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000484</td></tr>\n",
"    <tr><th>14</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000492</td></tr>\n",
"    <tr><th>15</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000489</td></tr>\n",
"    <tr><th>16</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000487</td></tr>\n",
"    <tr><th>17</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000494</td></tr>\n",
"    <tr><th>18</th><td>(1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.405468</td><td>0.000504</td></tr>\n",
"    <tr><th>19</th><td>(1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ...</td><td>11.404044</td><td>0.000476</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"
" ], "text/plain": [ " Best_individual Fitness_values \\\n", "0 (1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 1, 1, ... 14.231872 \n", "1 (1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, ... 13.511623 \n", "2 (1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ... 11.969715 \n", "3 (1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ... 11.969715 \n", "4 (1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ... 11.969715 \n", "5 (1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ... 11.969715 \n", "6 (1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, ... 11.969715 \n", "7 (1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ... 11.626140 \n", "8 (1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, ... 11.626140 \n", "9 (1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, ... 11.536057 \n", "10 (1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, ... 11.536057 \n", "11 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "12 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "13 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "14 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "15 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "16 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "17 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "18 (1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 11.405468 \n", "19 (1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, ... 
11.404044 \n", "\n", " Time (hours) \n", "0 0.000518 \n", "1 0.000494 \n", "2 0.000498 \n", "3 0.000511 \n", "4 0.000506 \n", "5 0.000495 \n", "6 0.000482 \n", "7 0.000487 \n", "8 0.000493 \n", "9 0.000490 \n", "10 0.000487 \n", "11 0.000486 \n", "12 0.000480 \n", "13 0.000484 \n", "14 0.000492 \n", "15 0.000489 \n", "16 0.000487 \n", "17 0.000494 \n", "18 0.000504 \n", "19 0.000476 " ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "fitness_df" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{0: 1, 1: 1, 2: 1, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1, 8: 1, 9: 1, 10: 1, 11: 0, 12: 0, 13: 1, 14: 1, 15: 1, 16: 1, 17: 0, 18: 0, 19: 1}\n" ] } ], "source": [ "print(final_best_features)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " MW AMW Sv Sp Si Mv Me Mp Mi GD \\\n", "0 285.54 7.932 22.324 25.801 39.995 0.620 0.980 0.717 1.111 0.140 \n", "1 240.24 9.240 18.569 18.455 29.242 0.714 1.032 0.710 1.125 0.131 \n", "2 313.42 8.471 24.643 25.938 41.697 0.666 1.009 0.701 1.127 0.108 \n", "3 218.32 9.492 15.153 16.466 25.544 0.659 1.024 0.716 1.111 0.179 \n", "4 319.45 9.396 23.303 24.824 38.202 0.685 1.015 0.730 1.124 0.114 \n", ".. ... ... ... ... ... ... ... ... ... ... \n", "495 296.36 7.799 24.676 25.500 42.903 0.649 1.010 0.671 1.129 0.108 \n", "496 328.50 8.645 25.621 27.887 42.324 0.674 0.996 0.734 1.114 0.108 \n", "497 373.52 8.120 30.982 33.006 51.318 0.674 0.995 0.718 1.116 0.088 \n", "498 323.50 8.513 24.158 26.142 43.404 0.636 1.007 0.688 1.142 0.114 \n", "499 348.42 8.103 30.445 31.796 47.112 0.708 0.993 0.739 1.096 0.088 \n", "\n", " nTA nBT nBO nBM RBF \n", "0 0.0 38.0 19.0 5.0 0.079 \n", "1 1.0 28.0 20.0 17.0 0.071 \n", "2 0.0 40.0 25.0 16.0 0.075 \n", "3 2.0 24.0 14.0 5.0 0.042 \n", "4 0.0 37.0 24.0 15.0 0.081 \n", ".. ... ... ... ... ... 
\n", "495 0.0 41.0 25.0 16.0 0.073 \n", "496 1.0 41.0 25.0 16.0 0.049 \n", "497 0.0 50.0 31.0 22.0 0.060 \n", "498 0.0 41.0 24.0 10.0 0.073 \n", "499 0.0 47.0 31.0 28.0 0.064 \n", "\n", "[500 rows x 15 columns]\n" ] } ], "source": [ "# Printing out which columns were selected \n", "\n", "df = pd.DataFrame(features, columns=col_names)\n", "new_cols = list(map(bool, tuple([v for k,v in final_best_features.items()])))\n", "print(df[df.columns[new_cols]])" ] } ], "metadata": { "kernelspec": { "display_name": "chemml_dev_env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.5" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 2 }