TabPFN and Tabular Data

TabPFN is a transformer neural network for supervised classification of small tabular datasets. It's described in this paper and at this GitHub site, and it's built on PyTorch. TabPFN does not fit a new model from scratch for each dataset. Instead, it uses a large transformer that has been pre-trained on artificially generated classification tasks sampled from a prior over datasets.

Rather than going into how TabPFN works, I want to test it on some simple datasets. See the paper for the details and for how the authors evaluated the method.

Diabetes Data Table

The first dataset comes from the National Institute of Diabetes and Digestive and Kidney Diseases. It was designed to predict whether or not a patient has diabetes, based on certain measurements. All patients in the data are females at least 21 years old of Pima Indian heritage. The data table is included in the R package mlbench.

> library(tidyverse)
> library(mlbench)
> data(PimaIndiansDiabetes)
> df_diabetes <- PimaIndiansDiabetes %>% 
                     mutate(diabetes = ifelse(diabetes == 'pos', 1, 0))
>

All columns of PimaIndiansDiabetes are numeric except the last, which contains the values pos and neg to indicate the result of the diabetes test. I changed these values to 1 for pos and 0 for neg. The diabetes column is the variable to be predicted.
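The same recoding can be sketched in Python with only the standard library. This is an illustration, not the code I used: RAW holds two rows copied from the glimpse() output below (with their original pos/neg labels), and load_xy is a hypothetical helper. It also splits off the last column as the target, the convention that both the Python script and the R function in this post rely on.

```python
import csv
import io

# Two rows from the PimaIndiansDiabetes data, before the 1/0 recoding.
RAW = """pregnant,glucose,pressure,triceps,insulin,mass,pedigree,age,diabetes
6,148,72,35,0,33.6,0.627,50,pos
1,85,66,29,0,26.6,0.351,31,neg
"""


def load_xy(text):
    """Recode pos/neg to 1/0 and split each row into features and target.

    The target is taken from the last column.
    """
    reader = csv.reader(io.StringIO(text))
    header = next(reader)
    X, y = [], []
    for row in reader:
        *features, label = row
        X.append([float(v) for v in features])
        y.append(1 if label == "pos" else 0)  # same rule as the R ifelse()
    return header, X, y


header, X, y = load_xy(RAW)
print(header[-1], y)  # diabetes [1, 0]
```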

> library(tidyverse)
> glimpse(df_diabetes)
Rows: 768
Columns: 9
$ pregnant <dbl> 6, 1, 8, 1, 0, 5, 3, 10, 2, 8, 4, 10, 10, 1, 5, 7, 0, 7, 1, 1, 3, 8, 7, 9, 11, 10, 7, 1, 13, 5, 5, 3, 3, 6, 10, 4, 11, 9, 2, 4, 3, 7, 7, 9, 7, …
$ glucose  <dbl> 148, 85, 183, 89, 137, 116, 78, 115, 197, 125, 110, 168, 139, 189, 166, 100, 118, 107, 103, 115, 126, 99, 196, 119, 143, 125, 147, 97, 145, 117…
$ pressure <dbl> 72, 66, 64, 66, 40, 74, 50, 0, 70, 96, 92, 74, 80, 60, 72, 0, 84, 74, 30, 70, 88, 84, 90, 80, 94, 70, 76, 66, 82, 92, 75, 76, 58, 92, 78, 60, 7…
$ triceps  <dbl> 35, 29, 0, 23, 35, 0, 32, 0, 45, 0, 0, 0, 0, 23, 19, 0, 47, 0, 38, 30, 41, 0, 0, 35, 33, 26, 0, 15, 19, 0, 26, 36, 11, 0, 31, 33, 0, 37, 42, 47…
$ insulin  <dbl> 0, 0, 0, 94, 168, 0, 88, 0, 543, 0, 0, 0, 0, 846, 175, 0, 230, 0, 83, 96, 235, 0, 0, 0, 146, 115, 0, 140, 110, 0, 0, 245, 54, 0, 0, 192, 0, 0, …
$ mass     <dbl> 33.6, 26.6, 23.3, 28.1, 43.1, 25.6, 31.0, 35.3, 30.5, 0.0, 37.6, 38.0, 27.1, 30.1, 25.8, 30.0, 45.8, 29.6, 43.3, 34.6, 39.3, 35.4, 39.8, 29.0, …
$ pedigree <dbl> 0.627, 0.351, 0.672, 0.167, 2.288, 0.201, 0.248, 0.134, 0.158, 0.232, 0.191, 0.537, 1.441, 0.398, 0.587, 0.484, 0.551, 0.254, 0.183, 0.529, 0.7…
$ age      <dbl> 50, 31, 32, 21, 33, 30, 26, 29, 53, 54, 30, 34, 57, 59, 51, 32, 31, 31, 33, 32, 27, 50, 41, 29, 51, 41, 43, 22, 57, 38, 60, 28, 22, 28, 45, 33,…
$ diabetes <dbl> 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 0, 0,…
> write_csv(df_diabetes, 'data/diabetes.csv')

TabPFN is easy to use. Its API follows scikit-learn conventions: create a TabPFNClassifier object, call its fit method with the training data, and then call predict on the test data. predict returns the predicted class label for each test row, which for these datasets is 0 or 1.
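To make the fit/predict pattern concrete without installing TabPFN, here is a toy stand-in. ToyKNNClassifier is a hypothetical k-nearest-neighbours classifier, not TabPFN; it only shares the shape of the API. Like TabPFN, its fit step trains no weights (TabPFN's transformer is already pre-trained), and its predictions are computed from the stored training set. The return_winning_probability keyword is borrowed from the TabPFN call in the script below; in this sketch it returns the share of neighbours that voted for the winning label.

```python
import math
from collections import Counter


class ToyKNNClassifier:
    """Hypothetical k-nearest-neighbours stand-in with a scikit-learn-style API.

    This is NOT TabPFN. fit() trains no weights; it only stores the
    training set, and predict() works from that stored set.
    """

    def __init__(self, k=3):
        self.k = k

    def fit(self, X, y):
        self.X_ = [[float(v) for v in row] for row in X]
        self.y_ = list(y)
        self.classes_ = sorted(set(self.y_))
        return self  # scikit-learn convention: fit returns the estimator

    def predict(self, X, return_winning_probability=False):
        labels, probs = [], []
        for row in X:
            # indices of the k stored rows closest to `row` (Euclidean distance)
            nearest = sorted(range(len(self.X_)),
                             key=lambda i: math.dist(row, self.X_[i]))[: self.k]
            label, votes = Counter(self.y_[i] for i in nearest).most_common(1)[0]
            labels.append(label)
            probs.append(votes / len(nearest))  # share of neighbours for the winner
        return (labels, probs) if return_winning_probability else labels


# Usage: two well-separated toy clusters
clf = ToyKNNClassifier(k=3).fit([[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5]],
                                [0, 0, 0, 1, 1, 1])
labels, probs = clf.predict([[0.5, 0.5], [5.5, 5.5]], return_winning_probability=True)
print(labels, probs)  # [0, 1] [1.0, 1.0]
```

Returning self from fit is what allows the chained construct-and-fit call in the usage lines.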

The authors provide a nice Colab notebook for trying TabPFN. I copied much of the following code from the Colab.

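import argparse
import time

import matplotlib.pyplot as plt
import pandas as pd
import torch
from matplotlib.colors import ListedColormap
from sklearn.inspection import DecisionBoundaryDisplay
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from tabpfn import TabPFNClassifier


def GetArgs():
    # usage: test_tabpfn.py data_file -o out_file
    parser = argparse.ArgumentParser(description = 'Test TabPFN on a CSV file.')
    parser.add_argument('data_file', help = 'CSV file; the last column is the 0/1 target')
    parser.add_argument('-o', dest = 'out_file', required = True,
                        help = 'output CSV of test rows, predictions and probabilities')
    return parser.parse_args()


def main():
    args = GetArgs()
    data_file = args.data_file
    output_file = args.out_file

    # use the GPU if there is one
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'

    # features are all columns but the last; the target is the last column
    df = pd.read_csv(data_file)
    X = df.iloc[:, :-1].to_numpy()
    y = df.iloc[:, -1].to_numpy()
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.33)

    classifier = TabPFNClassifier(device = device, N_ensemble_configurations=4)

    # the reported time covers fit and predict together
    start = time.time()
    classifier.fit(X_train, y_train)
    y_eval, p_eval = classifier.predict(X_test, return_winning_probability = True)
    print('Prediction time: ', time.time() - start, 'Accuracy', accuracy_score(y_test, y_eval))

    # save the test rows with the predicted label and its probability
    out_table = pd.DataFrame(X_test.copy().astype(str))
    out_table['y_eval'] = y_eval
    out_table['probability'] = p_eval
    out_table.to_csv(output_file, index = False)

    # PLOTTING
    # from https://colab.research.google.com/drive/194mCs6SEPEW6C0rcP7xWzcEtt1RBc8jJ#scrollTo=Bkj2F3Q72OB0

    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111)
    cm = plt.cm.RdBu
    cm_bright = ListedColormap(["#FF0000", "#0000FF"])

    # Plot the training points, colored by whether they are in the first class
    y_train_index = y_train == classifier.classes_[0]

    ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train_index, cmap=cm_bright)

    # refit on the first two features only, so the boundary can be drawn in 2-D
    classifier.fit(X_train[:, 0:2], y_train_index)

    DecisionBoundaryDisplay.from_estimator(
        classifier, X_train[:, 0:2], alpha=0.6, ax=ax, eps=2.0,
        grid_resolution=25, response_method="predict_proba", cmap=cm)
    plt.show()


if __name__ == '__main__':
    main()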


$ python src/test_tabpfn.py data/diabetes.csv -o data/diabetes_out.csv
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Prediction time:  0.8954763412475586 Accuracy 0.7913385826771654
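accuracy_score is simply the fraction of test rows whose predicted label matches the true label. As a rough check on what 0.7913 means, the sketch below assumes scikit-learn's documented behavior of rounding a fractional test_size up to a whole number of test rows; holdout_size is a hypothetical helper.

```python
import math


def holdout_size(n_samples, test_size):
    # A fractional test_size is rounded up to a whole number of test rows.
    return math.ceil(test_size * n_samples)


n_test = holdout_size(768, 0.33)              # 768 rows in the diabetes data
correct = round(0.7913385826771654 * n_test)  # accuracy printed by the script
print(n_test, correct)                        # 254 201
print(holdout_size(303, 0.33))                # 100 rows for the heart data below
```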

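The Colab code also provides a nice plotting routine. It refits the classifier on the first two features of the training data and draws contours of the predicted probabilities over those two features, along with the training points.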
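According to the scikit-learn documentation, DecisionBoundaryDisplay.from_estimator evaluates the response method (here predict_proba) on a grid spanning each of the two features, extended by eps on both sides, with grid_resolution points per axis, and then draws filled contours of the result. The stdlib sketch below reproduces only the grid step, under that reading of the docs. boundary_grid is a hypothetical helper, and the three (pregnant, glucose) pairs come from the glimpse() output above.

```python
def boundary_grid(X, eps=2.0, grid_resolution=25):
    """Return grid_resolution**2 (x0, x1) points covering the first two columns of X.

    Each axis runs evenly from (column min - eps) to (column max + eps).
    """
    def axis(values):
        lo, hi = min(values) - eps, max(values) + eps
        return [lo + (hi - lo) * i / (grid_resolution - 1) for i in range(grid_resolution)]

    xs = axis([row[0] for row in X])
    ys = axis([row[1] for row in X])
    # A classifier's predict_proba would be evaluated at each of these points.
    return [(x, y) for y in ys for x in xs]


# (pregnant, glucose) for the first three rows of the diabetes data
points = boundary_grid([[6, 148], [1, 85], [8, 183]])
print(len(points), points[0], points[-1])  # 625 (-1.0, 83.0) (10.0, 185.0)
```

With eps=2.0 and grid_resolution=25, as in the script, the classifier is scored at 625 points, which keeps the plot fast at the cost of a coarse boundary.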

I wanted to compare TabPFN's performance with a standard technique for this type of data, logistic regression, which I looked at in a previous post.
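For reference, the model that glm fits with family = binomial (whose default link is the logit), together with the 0/1 rule applied afterwards, can be written as follows. Here x_1, ..., x_k are the predictor columns, the β's are the fitted coefficients, and c is the probability cutoff (opt_cutoff in the function below).

```latex
p(\mathbf{x}) = \Pr(y = 1 \mid \mathbf{x})
  = \frac{1}{1 + \exp\!\left[-\left(\beta_0 + \sum_{j=1}^{k} \beta_j x_j\right)\right]},
\qquad
\log\frac{p(\mathbf{x})}{1 - p(\mathbf{x})} = \beta_0 + \sum_{j=1}^{k} \beta_j x_j,
\qquad
\hat{y} = \begin{cases} 1 & \text{if } p(\mathbf{x}) \ge c \\ 0 & \text{otherwise} \end{cases}
```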

#' logistic_regr
#'  Run logistic regression on a dataframe.
#'
#' @param df - a dataframe with numeric columns. See requirements below.
#' @param test_pct - the proportion of df used for testing.
#' @param opt_cutoff - the probability cutoff for the response to be 1.
#'                     if NULL, uses InformationValue::optimalCutoff
#'
#' @return
#'  a list
#'    model - the logistic regression model.
#'    prediction - predicted probabilities of the response for the test data.
#'    y_eval - predictions on 0/1 basis. Values >= optimal_cutoff = 1
#'    formula - the formula used for the glm model.
#'    optimal_cutoff - probability cutoff score, based on minimizing 
#'                     test case misclassification.
#'    misclass_error - the proportion of the test cases misclassified.
#'    train - the training data from df.
#'    test - the test data from df.
#'    
#' @requires
#'  All dataframe columns must be numeric.
#'  Target variable must be 0/1 and in last column of df.
#'  InformationValue library from https://github.com/selva86/InformationValue.
#'  If you don't want to use this library, set opt_cutoff to some value,
#'  0 < opt_cutoff < 1.
#'
logistic_regr <- function(df, 
                          test_pct = 0.33,
                          opt_cutoff = NULL) {
  if(is.null(opt_cutoff)) {
    require(InformationValue)
  }
  
  # randomly assign each row to train (TRUE) or test (FALSE);
  # replace = TRUE is needed to draw nrow(df) values from c(TRUE, FALSE)
  samples <- sample(c(TRUE, FALSE), nrow(df), 
                    replace = TRUE, 
                    prob = c(1-test_pct, test_pct))
 
  train <- df[samples, ]
  test <- df[!samples, ]
  
  # create a formula from the df columns: last column ~ all other columns
  formula <- paste(names(train)[ncol(train)], '~',
                   paste(names(train)[-ncol(train)], collapse = ' + '))
  
  model <- glm(formula, 
               data = train, 
               family = binomial,
               maxit = 100)
  
  # predicted probabilities that the response is 1
  prediction <- predict(model, test, type = 'response')
  
  if(is.null(opt_cutoff)) {
    opt_cutoff <- optimalCutoff(test[[ncol(test)]], prediction)
  }
  
  y_eval <- ifelse(prediction >= opt_cutoff, 1, 0)
  
  # do it this way because InformationValue::misClassError sometimes returns
  # count instead of proportion
  err <- sum(test[[ncol(test)]] != y_eval) / nrow(test)
  
  return(list(model = model, 
              prediction = prediction,
              y_eval = y_eval,
              formula = formula,
              optimal_cutoff = opt_cutoff,
              misclass_error = err,
              train = train,
              test = test))
}

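The function builds a formula from the data frame's column names, with the last column as the target variable. It fits the training data with glm, uses the resulting model to predict probabilities for the test data, and converts them to 0/1 with the cutoff. If no cutoff is supplied, optimalCutoff chooses the one that minimizes misclassification on the test data. The function returns the model, the predictions, and the other items listed in the comments above.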
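The cutoff step can be sketched in Python as follows. I haven't checked exactly how InformationValue::optimalCutoff searches, so this assumes a brute-force search over the observed scores; misclass_error and best_cutoff are hypothetical helpers, and the six labels and scores are toy values made up for illustration. Because the cutoff is chosen on the same test rows used to compute the error, the reported error is optimistic.

```python
def misclass_error(actual, scores, cutoff):
    """Proportion of cases misclassified when score >= cutoff predicts 1."""
    wrong = sum((s >= cutoff) != bool(a) for a, s in zip(actual, scores))
    return wrong / len(actual)


def best_cutoff(actual, scores):
    """Brute-force search over the observed scores for the lowest error.

    Candidates are tried from highest to lowest, so ties keep the larger cutoff.
    """
    candidates = sorted(set(scores), reverse=True)
    return min(candidates, key=lambda c: misclass_error(actual, scores, c))


# Toy values: true 0/1 labels and predicted probabilities
actual = [0, 0, 1, 0, 1, 1]
scores = [0.10, 0.40, 0.35, 0.60, 0.70, 0.90]
cut = best_cutoff(actual, scores)
print(cut, round(misclass_error(actual, scores, cut), 3))  # 0.7 0.167
```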

> system.time(diabetes <- logistic_regr(df_diabetes))
Loading required package: InformationValue
   user  system elapsed 
  0.111   0.000   0.112 
> 
> 1-diabetes$misclass_error
[1] 0.8041667
> diabetes$formula
[1] "diabetes ~ pregnant + glucose + pressure + triceps + insulin + mass + pedigree + age"
> diabetes$optimal_cutoff
[1] 0.7939239

TabPFN and logistic regression perform similarly on this data: TabPFN's test accuracy was 0.791 and logistic regression's was 0.804. The optimal cutoff gives logistic regression a slight advantage, because it is tuned on the same test data used to measure accuracy. For data sets like the diabetes data, logistic regression has one big advantage: it provides a clear, understandable model of the data. TabPFN, like most NN models, tends to be opaque.

Heart Attack Data Table


I tested TabPFN on a second collection of data, a Heart Attack Analysis and Prediction dataset from here.

> glimpse(df_heart)
Rows: 303
Columns: 14
$ age      <dbl> 63, 37, 41, 56, 57, 57, 56, 44, 52, 57, 54, 48, 49, 64, 58, 50, 58, 66, 43, 69, 59, 44, 42, 61, 40, 71, 59, 51, 65, 53, 41, 65, 44, 54, 51, 46,…
$ sex      <dbl> 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1,…
$ cp       <dbl> 3, 2, 1, 1, 0, 0, 1, 1, 2, 2, 0, 2, 1, 3, 3, 2, 2, 3, 0, 3, 0, 2, 0, 2, 3, 1, 2, 2, 2, 2, 1, 0, 1, 2, 3, 2, 2, 2, 2, 2, 2, 1, 0, 0, 2, 1, 2, 2,…
$ trtbps   <dbl> 145, 130, 130, 120, 120, 140, 140, 120, 172, 150, 140, 130, 130, 110, 150, 120, 120, 150, 150, 140, 135, 130, 140, 150, 140, 160, 150, 110, 140…
$ chol     <dbl> 233, 250, 204, 236, 354, 192, 294, 263, 199, 168, 239, 275, 266, 211, 283, 219, 340, 226, 247, 239, 234, 233, 226, 243, 199, 302, 212, 175, 417…
$ fbs      <dbl> 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,…
$ restecg  <dbl> 0, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0,…
$ thalachh <dbl> 150, 187, 172, 178, 163, 148, 153, 173, 162, 174, 160, 139, 171, 144, 162, 158, 172, 114, 171, 151, 161, 179, 178, 137, 178, 162, 157, 123, 157…
$ exng     <dbl> 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0,…
$ oldpeak  <dbl> 2.3, 3.5, 1.4, 0.8, 0.6, 0.4, 1.3, 0.0, 0.5, 1.6, 1.2, 0.2, 0.6, 1.8, 1.0, 1.6, 0.0, 2.6, 1.5, 1.8, 0.5, 0.4, 0.0, 1.0, 1.4, 0.4, 1.6, 0.6, 0.8…
$ slp      <dbl> 0, 0, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 0, 2, 2, 1, 2, 2, 1, 2, 2, 2, 2, 2, 0, 2, 2, 2, 0, 2, 0, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2,…
$ caa      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0,…
$ thall    <dbl> 1, 2, 2, 2, 2, 1, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,…
$ output   <dbl> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,…
> 

Here's how TabPFN performed on this data.

$ python src/test_tabpfn.py data/heart.csv -o data/heart_out.csv
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Prediction time:  0.48320674896240234 Accuracy 0.82


And R's version.


> heart <- logistic_regr(df_heart)
> summary(heart$model)

Call:
glm(formula = formula, family = binomial, data = train, maxit = 100)

Coefficients:
             Estimate Std. Error z value Pr(>|z|)    
(Intercept)  5.531398   3.271904   1.691 0.090918 .  
age         -0.020558   0.029456  -0.698 0.485222    
sex         -1.692600   0.585434  -2.891 0.003838 ** 
cp           0.939770   0.263902   3.561 0.000369 ***
trtbps      -0.018361   0.014361  -1.279 0.201049    
chol        -0.004937   0.004639  -1.064 0.287148    
fbs          0.362309   0.703695   0.515 0.606647    
restecg      0.422797   0.442651   0.955 0.339503    
thalachh     0.013254   0.013531   0.980 0.327322    
exng        -0.677887   0.562504  -1.205 0.228156    
oldpeak     -0.619189   0.277085  -2.235 0.025440 *  
slp          0.844068   0.452875   1.864 0.062350 .  
caa         -1.156652   0.264214  -4.378  1.2e-05 ***
thall       -1.048357   0.362148  -2.895 0.003794 ** 
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

(Dispersion parameter for binomial family taken to be 1)

    Null deviance: 287.87  on 207  degrees of freedom
Residual deviance: 136.25  on 194  degrees of freedom
AIC: 164.25

Number of Fisher Scoring iterations: 6

> 
> 1 - heart$misclass_error
[1] 0.8631579
> heart$formula
[1] "output ~ age + sex + cp + trtbps + chol + fbs + restecg + thalachh + exng + oldpeak + slp + caa + thall"
> heart$optimal_cutoff
[1] 0.2381438

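Performance is again comparable: both methods classified 82 test cases correctly, 82 of 100 (0.82) for TabPFN and 82 of 95 (0.863) for logistic regression. Logistic regression's slightly better misclassification rate is likely helped by the optimal cutoff, which is tuned on the test data.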
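To make the interpretability point concrete: with a logit link, each coefficient in summary(heart$model) is a change in log odds, so exp(coefficient) is the factor by which the odds of output = 1 change for a one-unit increase in that predictor, holding the other predictors fixed. The sketch copies four coefficients from the summary above.

```python
import math

# Four coefficients (log-odds scale) copied from summary(heart$model)
coefficients = {"sex": -1.692600, "cp": 0.939770, "oldpeak": -0.619189, "caa": -1.156652}

# exp(beta): factor by which the odds of output = 1 change per one-unit
# increase in the predictor, with the other predictors held fixed
odds_ratios = {name: math.exp(beta) for name, beta in coefficients.items()}

for name, ratio in odds_ratios.items():
    print(f"{name:8s} {ratio:.2f}")
```

For example, the cp coefficient of 0.94 corresponds to the odds being multiplied by about 2.56 per unit of cp, while sex, oldpeak, and caa all have factors below 1. TabPFN offers no comparable per-feature summary.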

Source code is available on GitHub.
