En los 煤ltimos meses he estado tratando de aprender a optimizar mi c贸digo en R. En general, R funciona bastante bien, pero hay varias aplicaciones donde se puede poner muy, pero muy lento, en particular los loops. Un b煤squeda r谩pida muestra que existen muchos art铆culos que discuten el tema con bastante profundidad.

Uno de los temas que encontr茅 y llamaron mi atenci贸n fue la capacidad de integrar R con otros lenguajes de programaci贸n, en particular con algunos de m谩s bajo nivel, como C, Fortran o C++. La integraci贸n con C y Fortran est谩 ya solucionada directamente en R, a trav茅s de la funci贸n .Call, pero esta forma de llamar a c贸digo externo tiene la gran desventaja que todas las funciones deben ser declaradas como void, es decir, debemos modificar los valores entregados usando directamente punteros a los objetos en memoria. Esto es… engorroso, al menos lo fue en mis primeras experiencias programando.

La soluci贸n: usar C++. Es un lenguaje moderno, que corre eficientemente los loops, y que tiene muchas herramientas que facilitan la vida al programador. Para integrar R y C++ existe el excelente paquete Rcpp, el cual permite crear funciones de forma din谩mica y sencilla en R, y que adem谩s permite usar objetos inherentes a R (vectores num茅ricos, vectores de caract茅res, dataframes, etc.) en C++, pero a煤n este paquete tiene una falencia: no soporta operaciones matriciales eficientes. Para ello, un plugin del mismo paquete llamado RcppArmadillo ofrece una interfaz entre Rcpp y la librer铆a Armadillo que entrega herramientas para operaciones matriciales en C++. Con esto tenemos todo lo que necesitamos.

En este art铆culo programar茅 un c贸digo sencillo para obtener un cl煤ster de K-Medias y comparar茅 las implementaciones en c贸digo en R, la implementaci贸n oficial (en C) y una implementaci贸n propia en C++. El algoritmo es bastante sencillo: inicializa vectores de forma aleatoria, para luego ir ajustando los clusters en cada iteraci贸n hasta alcanzar convergencia.

Implementaci贸n (ineficiente) en R.

kmeans.R <- function(x,k,eps = 10e-6){
  # k es la cantidad de clusters, x es la matriz de datos.
  # eps es la tolerancia m谩xima entre una iteraci贸n y otra.
  require(pdist)
  cluster <- rep(0,nrow(x)) # Cluster actual del caso.
  # Inicializa centroides.
  centroids <- matrix(runif(nrow(x) * k), nrow = k, ncol = ncol(x))
  # Guarda 煤ltima iteraci贸n
  centroids.old <- matrix(0, nrow = k, ncol = ncol(x))
  eps.vec <- rep(eps,k)
  dists.centroids <- 1000
  while(any(dists.centroids > eps.vec)){
    centroids.old <- centroids
    # Matriz con las distancias entre casos y vectores.
    dists.x <- as.matrix(pdist(x, centroids)) 
    min.x <- apply(dists.x, 1, min)
    # Obtiene indice del menor valor, i.e., el cluster.
    cluster <- apply(cbind(dists.x,min.x), 1, 
                     function(x) which(x[1:k] == x[k+1])[1])     
    for(i in 1:k){
      # Calcula el centroide.
      centroids[i,] <- apply(x[cluster == i, ], 2, mean) 
    }
    dists.centroids <- diag(as.matrix(pdist(centroids, centroids.old)))
  }
  # Devuelve cluster y centroides.
  return(list(cluster = cluster, centroids = centroids)) 
}

Para probar el c贸digo, en R simularemos una matriz sencilla con dos distribuciones. La primera una normal de est谩ndar y la segunda una normal de media 5 y desviaci贸n est谩ndar 1. Posteriormente aplicamos la funci贸n y graficamos. Noten que estoy usando el paquete ‘pdist’, solo para facilitarme la vida cuando calculo las distancias entre matrices. Este paquete tira algunas warnings en la 煤ltima iteraci贸n, simplemente para avisar que convergi贸, as铆 que ign贸renlas.

library(pdist)
x <- rbind(matrix(rnorm(1000 * 2, mean = 0), nrow = 1000, ncol = 2), 
           matrix(rnorm(1000 * 2, mean = 5), nrow = 1000, ncol = 2))
x.kmeans.R <- kmeans.R(x, 2)
x.kmeans.R$centroids
# Resultado:
#            [,1]       [,2]
# [1,] 0.04097412 0.04618219
# [2,] 5.02960813 4.97157822
plot(x, col=ifelse(x.kmeans.R$cluster == 1, "red", "green"))

El resultado es el siguiente:

Datos en dos segmentos.

Implementaci贸n en C++.
Implementemos ahora la misma funci贸n en C++. Para ello usaremos Rcpp y RcppArmadillo. Primero, es necesario instalar los paquetes:

install.packages(c('Rcpp', 'RcppArmadillo'))

.

Ahora, existen varias rutas para implementar la funci贸n. Es posible escribir directamente la funci贸n en un gran string de texto y compilarla en l铆nea, o se puede guardar un archivo anexo que tenga la funci贸n programada. Seguiremos este camino. En un nuevo archivo (kmeansCpp.cpp) programamos la funci贸n de kmedias.

#include <RcppArmadillo.h>
using namespace Rcpp;

// [[Rcpp::depends(RcppArmadillo)]]

// [[Rcpp::export]]
List kmeansCpp(NumericMatrix xa, int k, double eps = 10e-3) {
  arma::mat x(xa.begin(), xa.nrow(), xa.ncol(), false);
  // Inicializa centroides de forma uniforme en [0-1]
  arma::mat centroids(k, xa.ncol(), arma::fill::randu);
  // Incializa con puros 0's.
  arma::mat centroids_old(k, xa.ncol(), arma::fill::zeros);
  // Crea vector fila tama帽o 1xk.
  arma::rowvec eps_vec = eps * arma::ones<arma::rowvec>(k);
  arma::rowvec dists_centroids = 100 * arma::ones<arma::rowvec>(k);
  // Crea matriz tama帽o casos x k.
  arma::mat dists_x(xa.nrow(), k);
  LogicalMatrix temp(xa.nrow(), k);
  arma::mat cluster(xa.nrow(), k);
  while(arma::any(dists_centroids > eps_vec)){
    centroids_old = centroids;
    // Calcula matriz de distancias
    for(int i = 0; i<k; i++){
      dists_x.col(i) = sqrt(sum((x - (arma::ones(xa.nrow(), 1) 
                                  * centroids.row(i))) % 
                      (x - (arma::ones(xa.nrow(), 1) * centroids.row(i))), 1));
    }
    // Obtiene indice del menor valor, i.e., el cluster.
    arma::colvec min_x = min(dists_x, 1);
    temp = wrap((min_x * arma::ones(1, k) == dists_x));
    cluster = as<arma::mat>(temp);
    // Calcula el centroide.
    for(int ia = 0; ia < k; ia++){
      centroids.row(ia) = mean(x.rows(find(cluster.col(ia) == 1)), 0);
    }
    dists_centroids = (sqrt(sum((centroids - centroids_old) % 
                      (centroids - centroids_old), 1))).t();
  }
  // Devuelve cluster y centroides.
  return List::create(Named("cluster") = wrap(cluster), 
                      Named("centroids") = wrap(centroids));
}

Analicemos un poco el programa:

  • La l铆nea #include <RcppArmadillo.h> le indica al compilador de C++ que importe la cabecera que nos permitir谩 usar todas las herramientas para la interfaz con R.
  • using namespaces Rcpp nos permite utilizar todas las funciones de Rcpp sin tener que colocar Rcpp::NombreFuncion cada vez que la llamamos. Tambi茅n puede ser 煤til colocar using namespaces arma para ahorrarse los arma::.
  • // [[Rcpp::depends(RcppArmadillo)]] y // [[Rcpp::export]] le dicen a Rcpp que tiene que ocupar Armadillo y que debe exportar la funci贸n a R.

Luego viene la funci贸n en si. Tanto Rcpp como Armadillo poseen clases que permiten manejar vectores, matrices, listas, etc. En este caso, es importante notar que creamos una funci贸n que recibe y entrega objetos de Rcpp. La funci贸n recibe una matriz num茅rica (NumericMatrix, objeto de Rcpp), un entero y un double y entrega una lista de R. Los objetos de Rcpp son, a grandes rasgos, matrices (Matrix), vectores (Vector), listas (List) y Dataframes (Dataframe); mientras que los dos primeros pueden ser num茅ricos (Numeric), enteros (Integer), literales (String) o booleanos (Logical). As铆, una matriz num茅rica corresponde a un objeto NumericMatrix, mientras que un vector booleano ser谩 un LogicalVector. Cada objeto tiene diversos m茅todos que se le pueden aplicar, pero en general las operaciones matem谩ticas funcionan correctamente. Para una descripci贸n detallada, este tutorial es un buen punto de partida.

Dentro de la funci贸n, las primeras l铆neas inicializan los objetos que usaremos en la iteraci贸n. Por ejemplo, arma::mat x(xa.begin(), xa.nrow(), xa.ncol(), false); indica que crearemos una matriz x, objeto de Armadillo (los objetos m谩s usados son vec, para vector num茅rico y mat para matrices num茅ricas) reutilizando la memoria que apunta a la matriz num茅rica de Rcpp xa, nuestro input. Los dem谩s objetos se inicializan a partir de los tipos que usaremos.

El siguiente paso es el loop principal, que sigue la misma estructura que el loop de R, pero ahora utilizando las funciones propias de Armadillo o Rcpp, dependiendo de cu谩les sean los objetos sobre los que se las estamos aplicando. Por ejemplo, el operador % corresponde a la multiplicaci贸n t茅rmino a t茅rmino entre matrices de Armadillo. Finalmente, devolvemos una lista de R con el cluster (ahora como una matriz con 铆ndices) y los centroides. La funci贸n wrap() permite transformar objetos de Armadillo en objetos de Rcpp f谩cilmente.

Para poder utilizar la funci贸n en R en nuestra sesi贸n corremos el comando Rcpp::sourceCpp("kmeansCpp.cpp"). Con ello la funci贸n kmeansCpp() est谩 disponible para su uso.

Comparando resultados

Comparemos los c贸digos, para ver cu谩nto ganamos por traspasar la funci贸n a C++. Utilizar茅 el paquete rbenchmark que permite r谩pidamente comparar distintas funciones en cuanto a su tiempo de ejecuci贸n. El c贸digo de comparaci贸n completo queda entonces:

library(RcppArmadillo)
library(rbenchmark)
Rcpp::sourceCpp("kmeansCpp.cpp")
x <- rbind(matrix(rnorm(1000 * 2, mean = 0), nrow = 1000, ncol = 2), 
           matrix(rnorm(1000 * 2, mean = 5), nrow = 1000, ncol = 2))
benchmark(kmeans(x,2), kmeansCpp(x,2), kmeans.R(x,2), order = "elapsed",
          columns = c("test", "replications", "elapsed", "relative", 
                      "user.self"))
# Resultados:
#              test replications elapsed relative user.self
# 2 kmeansCpp(x, 2)          100   0.040    1.000     0.036
# 1    kmeans(x, 2)          100   0.059    1.475     0.056
# 3  kmeans.R(x, 2)          100   7.775  194.375     7.752

Nuestra implementaci贸n en C++ corre en 0.036 segundos, 194 veces m谩s r谩pido (!!!) que la implementaci贸n en R puro y corre adem谩s casi un 50% m谩s r谩pido que la implementaci贸n en C que trae incorporado R. El resultado habla por si solo, vale la pena traspasar c贸digos con muchas iteraciones a C++, pues es m谩s eficiente tanto en uso de memoria como en velocidad. La diferencia con kmeans de R es enga帽osa, pues no realizamos ning煤n control de errores ni depuraciones a las entradas que probablemente s铆 realiza esta funci贸n, por lo que las ganancias se deben a eso probablemente.

As铆 que ya saben, si quieren que su c贸digo vuele, bajen un poco de nivel y utilicen c贸digo en C++. R ofrece todas las herramientas para lograrlo de forma sencilla.