TabPFN is a transformer neural network for supervised classification of small tabular datasets. It's described in this paper and at this GitHub site, and it's built on PyTorch. TabPFN does not fit a new model from scratch for each dataset. Instead, it uses a large transformer that has been pre-trained on artificially generated classification tasks sampled from a prior. At prediction time, the training data is passed to the network together with the test points, and a single forward pass produces the predictions.
Rather than going into how TabPFN works, I want to test it on some simple datasets. See the paper for the details and for how the authors evaluated the method.
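To make "artificially generated classification tasks sampled from a prior" a little more concrete, here is a toy sketch of drawing one such task. The prior below (a random linear rule plus Gaussian noise) is purely hypothetical and far simpler than the one in the paper, which draws on structural causal models and Bayesian neural networks. `sample_toy_task` and its parameters are illustrative names, not part of TabPFN.

```python
import random


def sample_toy_task(n_rows=20, n_features=3, n_train=15, seed=None):
    """Draw one synthetic binary classification task from a toy prior.

    Hypothetical prior: each task gets a random linear decision rule and a
    random noise level. Returns (X_train, y_train, X_test, y_test).
    """
    rng = random.Random(seed)
    # Each task has its own randomly drawn "true" rule and noise level.
    weights = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
    noise_sd = rng.uniform(0.0, 0.5)
    X, y = [], []
    for _ in range(n_rows):
        row = [rng.gauss(0.0, 1.0) for _ in range(n_features)]
        score = sum(w * x for w, x in zip(weights, row)) + rng.gauss(0.0, noise_sd)
        X.append(row)
        y.append(1 if score > 0 else 0)
    # The first n_train rows act as labelled context; the rest are held out.
    return X[:n_train], y[:n_train], X[n_train:], y[n_train:]


X_tr, y_tr, X_te, y_te = sample_toy_task(seed=0)
print(len(X_tr), len(X_te))  # prints: 15 5
```

In pre-training, a network would see millions of tasks like this, receive the labelled training rows as context, and be scored on the held-out labels. That is why no per-dataset fitting is needed later.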
Diabetes Data Table
The first dataset comes from the National Institute of Diabetes and Digestive and Kidney Diseases. It was designed to predict whether a patient has diabetes, based on certain diagnostic measurements. All patients in the data are females at least 21 years old of Pima Indian heritage. The data table is included in the R package mlbench.
> library(tidyverse)
> library(mlbench)
> data(PimaIndiansDiabetes)
> df_diabetes <- PimaIndiansDiabetes %>% mutate(diabetes = ifelse(diabetes == 'pos', 1, 0))
> glimpse(df_diabetes)
Rows: 768
Columns: 9
$ pregnant <dbl> 6, 1, 8, 1, 0, 5, 3, 10, 2, 8, 4, 10, 10, 1, 5, 7, 0, 7, 1, 1, 3, 8, 7, 9, 11, 10, 7, 1, 13, 5, 5, 3, 3, 6, 10, 4, 11, 9, 2, 4, 3, 7, 7, 9, 7, …
$ glucose  <dbl> 148, 85, 183, 89, 137, 116, 78, 115, 197, 125, 110, 168, 139, 189, 166, 100, 118, 107, 103, 115, 126, 99, 196, 119, 143, 125, 147, 97, 145, 117…
$ pressure <dbl> 72, 66, 64, 66, 40, 74, 50, 0, 70, 96, 92, 74, 80, 60, 72, 0, 84, 74, 30, 70, 88, 84, 90, 80, 94, 70, 76, 66, 82, 92, 75, 76, 58, 92, 78, 60, 7…
$ triceps  <dbl> 35, 29, 0, 23, 35, 0, 32, 0, 45, 0, 0, 0, 0, 23, 19, 0, 47, 0, 38, 30, 41, 0, 0, 35, 33, 26, 0, 15, 19, 0, 26, 36, 11, 0, 31, 33, 0, 37, 42, 47…
$ insulin  <dbl> 0, 0, 0, 94, 168, 0, 88, 0, 543, 0, 0, 0, 0, 846, 175, 0, 230, 0, 83, 96, 235, 0, 0, 0, 146, 115, 0, 140, 110, 0, 0, 245, 54, 0, 0, 192, 0, 0, …
$ mass     <dbl> 33.6, 26.6, 23.3, 28.1, 43.1, 25.6, 31.0, 35.3, 30.5, 0.0, 37.6, 38.0, 27.1, 30.1, 25.8, 30.0, 45.8, 29.6, 43.3, 34.6, 39.3, 35.4, 39.8, 29.0, …
$ pedigree <dbl> 0.627, 0.351, 0.672, 0.167, 2.288, 0.201, 0.248, 0.134, 0.158, 0.232, 0.191, 0.537, 1.441, 0.398, 0.587, 0.484, 0.551, 0.254, 0.183, 0.529, 0.7…
$ age      <dbl> 50, 31, 32, 21, 33, 30, 26, 29, 53, 54, 30, 34, 57, 59, 51, 32, 31, 31, 33, 32, 27, 50, 41, 29, 51, 41, 43, 22, 57, 38, 60, 28, 22, 28, 45, 33,…
$ diabetes <dbl> 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,…
> write_csv(df_diabetes, 'data/diabetes.csv')
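import argparse
import time

import matplotlib.pyplot as plt
import pandas as pd
import torch
from matplotlib.colors import ListedColormap
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier  # API of the original tabpfn 0.1.x release


def GetArgs():
    # positional input CSV plus -o for the predictions file
    parser = argparse.ArgumentParser()
    parser.add_argument('data_file', help='CSV file; the last column is the class label')
    parser.add_argument('-o', '--out_file', required=True, help='CSV file for test-set predictions')
    return parser.parse_args()


def main():
    args = GetArgs()
    data_file = args.data_file
    output_file = args.out_file
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # features are all columns except the last; the target is the last column
    df = pd.read_csv(data_file)
    X = df.iloc[:, :-1].to_numpy()
    y = df.iloc[:, -1].to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)

    classifier = TabPFNClassifier(device=device, N_ensemble_configurations=4)
    start = time.time()
    classifier.fit(X_train, y_train)  # no gradient training: stores the training set as context
    y_eval, p_eval = classifier.predict(X_test, return_winning_probability=True)
    print('Prediction time: ', time.time() - start, 'Accuracy', accuracy_score(y_test, y_eval))

    # save the test features with the predicted class and its probability
    out_table = pd.DataFrame(X_test.copy().astype(str))
    out_table['y_eval'] = y_eval
    out_table['probability'] = p_eval
    out_table.to_csv(output_file, index=False)

    # PLOTTING
    # from
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)
    cm_bright = ListedColormap(["#FF0000", "#0000FF"])

    # Plot the training points; True marks the first class in classifier.classes_
    y_train_index = y_train == classifier.classes_[0]
    ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train_index, cmap=cm_bright)

    # refit on the first two features only so the boundary can be drawn in 2-D
    classifier.fit(X_train[:, 0:2], y_train_index)
    DecisionBoundaryDisplay.from_estimator(
        classifier, X_train[:, 0:2], alpha=0.6, ax=ax, eps=2.0,
        grid_resolution=25, response_method="predict_proba")
    plt.show()


if __name__ == '__main__':
    main()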
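The script passes `test_size=0.33` to `train_test_split`. A small sketch of the resulting row counts, assuming scikit-learn's convention that a fractional test size is rounded up and the training set gets the remaining rows; `split_sizes` is an illustrative helper, not a scikit-learn function. It is applied to the 768-row diabetes table and to the 303-row heart table used later.

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumes scikit-learn's rule: the test count is rounded up (ceil) and the
    training set receives whatever rows remain.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


for name, n_rows in [("diabetes", 768), ("heart", 303)]:
    print(name, split_sizes(n_rows, 0.33))
```

So TabPFN sees roughly 500 labelled rows as context on the diabetes data and about 200 on the heart data, and each accuracy figure is computed on a few hundred test rows at most.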
$ python src/ data/diabetes.csv -o data/diabetes_out.csv
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Prediction time: 0.8954763412475586 Accuracy 0.7913385826771654
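For context, the Pima data is imbalanced: 500 of the 768 patients are negative and 268 are positive. A classifier that always predicts "no diabetes" therefore already scores about 0.65. A minimal sketch of that majority-class baseline follows; `majority_baseline` is an illustrative helper.

```python
from collections import Counter


def majority_baseline(labels):
    """Accuracy obtained by always predicting the most common class."""
    counts = Counter(labels)
    return counts.most_common(1)[0][1] / len(labels)


# Pima Indians diabetes class counts: 500 negative (0) and 268 positive (1).
pima_labels = [0] * 500 + [1] * 268
print(round(majority_baseline(pima_labels), 3))  # prints: 0.651
```

Measured against that baseline, TabPFN's 0.79 (201 of 254 test rows, if the test set follows scikit-learn's round-up rule) is a real but moderate improvement.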
#' logistic_regr
#' Run logistic regression on a dataframe.
#'
#' @param df - a dataframe with numeric columns. See requirements below.
#' @param test_pct - the proportion of df used for testing.
#' @param opt_cutoff - the cutoff value for the response to be 1.
#'        If NULL, uses InformationValue::optimalCutoff.
#'
#' @return
#' a list
#'    model - the logistic regression model.
#'    prediction - predicted probabilities of the response for the test data.
#'    y_eval - predictions on a 0/1 basis. Values >= optimal_cutoff = 1.
#'    formula - the formula used for the glm model.
#'    optimal_cutoff - probability cutoff score, based on minimizing
#'                     test case misclassification.
#'    misclass_error - the proportion of the test cases misclassified.
#'    train - the training data from df.
#'    test - the test data from df.
#'
#' @requires
#'    All dataframe columns must be numeric.
#'    Target variable must be 0/1 and in the last column of df.
#'    InformationValue library from
#'    If you don't want to use this library, set opt_cutoff to some value,
#'    0 < opt_cutoff < 1.
#'
logistic_regr <- function(df, test_pct = 0.33, opt_cutoff = NULL) {
  if(is.null(opt_cutoff)) {
    require(InformationValue)
  }

  # randomly assign each row to train (TRUE) or test (FALSE);
  # replace = TRUE is required because nrow(df) draws come from only two values
  samples <- sample(c(TRUE, FALSE), nrow(df), replace = TRUE, prob = c(1 - test_pct, test_pct))
  train <- df[samples, ]
  test <- df[!samples, ]

  # create a formula from the df columns: last column ~ all other columns
  formula <- paste(names(train)[ncol(train)], '~', names(train)[1])
  for(col in 2:(ncol(train) - 1)) {
    formula <- paste(formula, '+', names(train)[col])
  }

  model <- glm(formula, data = train, family = binomial, maxit = 100)
  prediction <- predict(model, test, type = 'response')

  # [[ ]] extracts the target as a plain vector (also works for tibbles)
  if(is.null(opt_cutoff)) {
    opt_cutoff <- optimalCutoff(test[[ncol(test)]], prediction)
  }
  y_eval <- ifelse(prediction >= opt_cutoff, 1, 0)

  # do it this way because InformationValue::misClassError sometimes returns
  # count instead of proportion
  err <- sum(test[[ncol(test)]] != y_eval) / nrow(test)

  return(list(model = model, prediction = prediction, y_eval = y_eval, formula = formula,
              optimal_cutoff = opt_cutoff, misclass_error = err, train = train, test = test))
}
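`optimal_cutoff` is documented above as the probability cutoff that minimizes test-case misclassification. A minimal sketch of that idea as a plain grid search; it is a simplified stand-in, not `InformationValue::optimalCutoff` itself, and the 0.01 grid step and lowest-cutoff tie-breaking are assumptions of this sketch.

```python
def misclass_error(actuals, scores, cutoff):
    """Proportion of rows whose thresholded score disagrees with the 0/1 label."""
    wrong = sum((s >= cutoff) != bool(a) for a, s in zip(actuals, scores))
    return wrong / len(actuals)


def optimal_cutoff(actuals, scores, step=0.01):
    """Grid-search the cutoff in (0, 1) that minimizes misclassification error.

    Simplified stand-in for InformationValue::optimalCutoff; on ties the
    lowest cutoff wins because min() returns the first minimum.
    """
    n_steps = int(round(1 / step))
    grid = [i * step for i in range(1, n_steps)]
    return min(grid, key=lambda c: misclass_error(actuals, scores, c))


actuals = [0, 0, 1, 1, 1, 0]
scores = [0.1, 0.3, 0.45, 0.6, 0.9, 0.2]
c = optimal_cutoff(actuals, scores)
print(c, misclass_error(actuals, scores, c))
```

Because `logistic_regr` chooses the cutoff on the same test rows it then scores, its reported accuracy is likely somewhat optimistic compared with TabPFN's, which involves no tuning on the test set.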
> system.time(diabetes <- logistic_regr(df_diabetes))
Loading required package: InformationValue
   user  system elapsed
  0.111   0.000   0.112
>
> summary(diabetes$model)
Call:
glm(formula = formula, family = binomial, data = train, maxit = 100)
Coefficients:
             Estimate Std. Error z value Pr(>|z|)
(Intercept)  2.656721   3.255151   0.816  0.41441
age         -0.003748   0.027914  -0.134  0.89320
sex         -1.835069   0.598534  -3.066  0.00217 **
cp           0.949214   0.237278   4.000 6.32e-05 ***
trtbps      -0.014622   0.013537  -1.080  0.28009
chol        -0.002251   0.004427  -0.508  0.61115
fbs          0.560588   0.668882   0.838  0.40198
restecg      0.532114   0.432514   1.230  0.21859
thalachh     0.023373   0.013458   1.737  0.08243 .
exng        -1.088470   0.524989  -2.073  0.03814 *
oldpeak     -0.790002   0.275644  -2.866  0.00416 **
slp          0.411506   0.458375   0.898  0.36932
caa         -0.563122   0.217851  -2.585  0.00974 **
thall       -1.023453   0.376574  -2.718  0.00657 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
    Null deviance: 282.43  on 204  degrees of freedom
Residual deviance: 142.20  on 191  degrees of freedom
AIC: 170.2
Number of Fisher Scoring iterations: 6
> 1-diabetes$misclass_error
[1] 0.8041667
> diabetes$formula
[1] "diabetes ~ pregnant + glucose + pressure + triceps + insulin + mass + pedigree + age"
> diabetes$optimal_cutoff
[1] 0.7939239
Heart Attack Data Table
> glimpse(df_heart)
Rows: 303
Columns: 14
$ age      <dbl> 63, 37, 41, 56, 57, 57, 56, 44, 52, 57, 54, 48, 49, 64, 58, 50, 58, 66, 43, 69, 59, 44, 42, 61, 40, 71, 59, 51, 65, 53, 41, 65, 44, 54, 51, 46,…
$ sex      <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1,…
$ cp       <dbl> 3, 2, 1, 1, 0, 0, 1, 1, 2, 2, 0, 2, 1, 3, 3, 2, 2, 3, 0, 3, 0, 2, 0, 2, 3, 1, 2, 2, 2, 2, 1, 0, 1, 2, 3, 2, 2, 2, 2, 2, 2, 1, 0, 0, 2, 1, 2, 2,…
$ trtbps   <dbl> 145, 130, 130, 120, 120, 140, 140, 120, 172, 150, 140, 130, 130, 110, 150, 120, 120, 150, 150, 140, 135, 130, 140, 150, 140, 160, 150, 110, 140…
$ chol     <dbl> 233, 250, 204, 236, 354, 192, 294, 263, 199, 168, 239, 275, 266, 211, 283, 219, 340, 226, 247, 239, 234, 233, 226, 243, 199, 302, 212, 175, 417…
$ fbs      <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ restecg  <dbl> 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
$ thalachh <dbl> 150, 187, 172, 178, 163, 148, 153, 173, 162, 174, 160, 139, 171, 144, 162, 158, 172, 114, 171, 151, 161, 179, 178, 137, 178, 162, 157, 123, 157…
$ exng     <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,…
$ oldpeak  <dbl> 2.3, 3.5, 1.4, 0.8, 0.6, 0.4, 1.3, 0.0, 0.5, 1.6, 1.2, 0.2, 0.6, 1.8, 1.0, 1.6, 0.0, 2.6, 1.5, 1.8, 0.5, 0.4, 0.0, 1.0, 1.4, 0.4, 1.6, 0.6, 0.8…
$ slp      <dbl> 0, 0, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 0, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2,…
$ caa      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
$ thall    <dbl> 1, 2, 2, 2, 2, 1, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,…
$ output   <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
>
$ python src/ data/heart.csv -o data/heart_out.csv
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Prediction time: 0.48320674896240234 Accuracy 0.82
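With 303 rows and `test_size=0.33`, the heart test set has only about 100 rows, so an accuracy of 0.82 is roughly 82 correct predictions and carries real sampling uncertainty. A minimal sketch of the Wilson score interval, a standard confidence interval for a proportion, puts rough error bars on that number. The 95% level (z = 1.96) is an assumption, `wilson_interval` is an illustrative helper, and the interval ignores the extra variability that comes from the random train/test split.

```python
import math


def wilson_interval(correct, n, z=1.96):
    """Approximate confidence interval for a proportion (default: 95%)."""
    p = correct / n
    denom = 1 + z * z / n
    centre = (p + z * z / (2 * n)) / denom
    half_width = z * math.sqrt(p * (1 - p) / n + z * z / (4 * n * n)) / denom
    return centre - half_width, centre + half_width


# TabPFN on the heart data: 82 correct out of a 100-row test set.
lo, hi = wilson_interval(82, 100)
print(f"{lo:.3f} to {hi:.3f}")
```

The interval spans roughly 0.73 to 0.88, so differences of a few percentage points between methods on test sets this small should be read cautiously.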
> heart <- logistic_regr(df_heart)
> summary(heart$model)
Call:
glm(formula = formula, family = binomial, data = train, maxit = 100)
Coefficients:
             Estimate Std. Error z value Pr(>|z|)
(Intercept)  5.531398   3.271904   1.691 0.090918 .
age         -0.020558   0.029456  -0.698 0.485222
sex         -1.692600   0.585434  -2.891 0.003838 **
cp           0.939770   0.263902   3.561 0.000369 ***
trtbps      -0.018361   0.014361  -1.279 0.201049
chol        -0.004937   0.004639  -1.064 0.287148
fbs          0.362309   0.703695   0.515 0.606647
restecg      0.422797   0.442651   0.955 0.339503
thalachh     0.013254   0.013531   0.980 0.327322
exng        -0.677887   0.562504  -1.205 0.228156
oldpeak     -0.619189   0.277085  -2.235 0.025440 *
slp          0.844068   0.452875   1.864 0.062350 .
caa         -1.156652   0.264214  -4.378  1.2e-05 ***
thall       -1.048357   0.362148  -2.895 0.003794 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1
(Dispersion parameter for binomial family taken to be 1)
    Null deviance: 287.87  on 207  degrees of freedom
Residual deviance: 136.25  on 194  degrees of freedom
AIC: 164.25
Number of Fisher Scoring iterations: 6
>
> 1 - heart$misclass_error
[1] 0.8631579
> heart$formula
[1] "output ~ age + sex + cp + trtbps + chol + fbs + restecg + thalachh + exng + oldpeak + slp + caa + thall"
> heart$optimal_cutoff
[1] 0.2381438
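Each row of the `glm` coefficient table is a Wald test: the z value is Estimate / Std. Error, the two-sided p-value comes from the standard normal distribution, and exponentiating an estimate gives an odds ratio. A minimal sketch reproducing two rows of the heart-data summary above; `wald_test` is an illustrative helper, and the inputs are the rounded values R printed.

```python
import math


def wald_test(estimate, std_error):
    """z statistic, two-sided p-value and odds ratio for one glm coefficient."""
    z = estimate / std_error
    # Two-sided normal tail probability: P(|Z| > |z|) = erfc(|z| / sqrt(2)).
    p = math.erfc(abs(z) / math.sqrt(2))
    odds_ratio = math.exp(estimate)
    return z, p, odds_ratio


# Estimate and Std. Error copied from the heart-data summary above.
for name, est, se in [("sex", -1.692600, 0.585434), ("caa", -1.156652, 0.264214)]:
    z, p, odds = wald_test(est, se)
    print(f"{name}: z = {z:.3f}, p = {p:.2g}, odds ratio = {odds:.3f}")
```

For example, the `sex` estimate of -1.69 corresponds to an odds ratio of about 0.18: holding the other predictors fixed, the model's odds of `output = 1` for `sex = 1` are roughly a fifth of those for `sex = 0`.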
