Cross validation in Ruby

Recently I had some data mining problems in which I wanted to use Ruby instead of Python. One problem I ran into was that I wanted to use k-fold cross validation but couldn’t find a sufficiently simple gem (or gist or whatever) for it. I ended up writing my own version.

A review of k-fold cross validation

A common way to study the performance of a model is to partition a dataset into training and validation sets. Cross validation is a method to assess whether a model generalizes well regardless of how that split is chosen.

In k-fold cross validation, you divide the dataset into k partitions, select one at a time as the validation set, and use the other k – 1 partitions for training. You end up with k different models and their respective performance measures on the validation sets.
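To make the rotation concrete, here is a minimal sketch that takes a dataset already split into k partitions and builds the k (training, validation) pairs. The method name `folds_from_partitions` is hypothetical, chosen for illustration only.

```ruby
# Hypothetical helper: given k partitions (an array of arrays), build the k
# folds. Fold i validates on partition i and trains on the other k - 1
# partitions concatenated together.
def folds_from_partitions(partitions)
  partitions.each_index.map do |i|
    training = partitions.reject.with_index { |_, j| j == i }.flatten(1)
    [training, partitions[i]]
  end
end

folds_from_partitions([[1, 2], [3, 4], [5, 6]]).each do |training, validation|
  puts "train on #{training.inspect}, validate on #{validation.inspect}"
end
# prints:
# train on [3, 4, 5, 6], validate on [1, 2]
# train on [1, 2, 5, 6], validate on [3, 4]
# train on [1, 2, 3, 4], validate on [5, 6]
```

`flatten(1)` only removes one level of nesting, so if each element of the dataset is itself an array (a CSV row, say), the rows stay intact.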

The image below is an example of one fold: the k-th partition is left out for validation and partitions 1, …, k-1 are used for training.

[Figure: k-fold cross validation]

E.g., with k = 5, each fold uses 80% of the dataset for training and 20% for validation.

The implementation

My solution is a function that receives the dataset, the number of partitions, and a block responsible for training and evaluating the classifier in question. Most of it is straightforward; just keep in mind that the last partition (index k - 1, counting from zero) should include all the remaining elements when dataset.size isn’t divisible by k. The training set is simply every element not in the validation set.
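The partitioning rule can be sketched on its own. This assumes contiguous partitions in the dataset’s original order (no shuffling); the name `split_into_partitions` is hypothetical.

```ruby
# Hypothetical helper: split `dataset` into k contiguous partitions of
# dataset.size / k elements each (integer division). The last partition
# (index k - 1) also takes the leftover elements when dataset.size % k != 0.
def split_into_partitions(dataset, k)
  raise ArgumentError, "k must be between 1 and dataset.size" unless k.between?(1, dataset.size)

  size = dataset.size / k
  (0...k).map do |i|
    first = i * size
    last = i == k - 1 ? dataset.size : first + size
    dataset[first...last]
  end
end

p split_into_partitions((1..10).to_a, 3)
# prints [[1, 2, 3], [4, 5, 6], [7, 8, 9, 10]]
```

With 10 elements and k = 3, the integer division gives partitions of 3, so the tenth element lands in the last partition.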

The last part is to yield both sets to the given block. More on how the yield keyword works can be found in another post and in Ruby core’s documentation.
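Putting the partitioning and the yield together, a minimal sketch of the whole function might look like the following. The name `cross_validate` and the choice of contiguous, unshuffled partitions are assumptions for illustration, not necessarily what the original implementation does.

```ruby
# Hypothetical sketch of the cross validation function described above.
# It splits `dataset` into k contiguous partitions (the last one absorbing
# any remainder) and yields (training, validation) once per fold.
def cross_validate(dataset, k)
  raise ArgumentError, "k must be between 2 and dataset.size" unless k.between?(2, dataset.size)

  size = dataset.size / k
  k.times do |i|
    first = i * size
    last = i == k - 1 ? dataset.size : first + size
    validation = dataset[first...last]
    # The training set is everything outside the validation range.
    training = dataset[0...first] + dataset[last..]
    yield training, validation
  end
end

cross_validate((1..7).to_a, 3) do |training, validation|
  puts "#{training.size} for training, #{validation.size} for validation"
end
# prints:
# 5 for training, 2 for validation
# 5 for training, 2 for validation
# 4 for training, 3 for validation
```

Because the partitions follow the dataset’s order, a file sorted by label would give skewed folds; calling `dataset.shuffle` before passing it in is a simple way to avoid that.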

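Now suppose you have a CSV file called “dataset.csv” and a classifier you want to train. All it takes is loading the rows, passing them to the function along with k, and training and evaluating the classifier inside the block.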
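A minimal sketch of such a call, assuming a hypothetical `cross_validate(dataset, k)` that yields the training and validation sets (redefined compactly so the snippet runs on its own), Ruby’s `csv` library, rows whose last column is the label, and a toy majority-class predictor standing in for a real classifier. The inline rows are made-up sample data; with the real file you would use `CSV.read("dataset.csv")` instead.

```ruby
require "csv"

# Compact copy of the hypothetical cross_validate sketch: yields
# (training, validation) per fold; the last partition takes the remainder.
def cross_validate(dataset, k)
  size = dataset.size / k
  k.times do |i|
    first = i * size
    last = i == k - 1 ? dataset.size : first + size
    yield dataset[0...first] + dataset[last..], dataset[first...last]
  end
end

# Made-up sample rows (feature, label). For a real file use:
#   dataset = CSV.read("dataset.csv")
sample = <<~ROWS
  1.0,yes
  2.0,no
  3.0,yes
  4.0,yes
  5.0,no
  6.0,yes
ROWS
dataset = CSV.parse(sample)

# Placeholder "classifier": always predicts the most frequent training label.
accuracies = []
cross_validate(dataset, 3) do |training, validation|
  majority = training.map(&:last).tally.max_by { |_, count| count }.first
  correct = validation.count { |row| row.last == majority }
  accuracies << correct.fdiv(validation.size)
end

puts "accuracy per fold: #{accuracies.inspect}"
puts "mean accuracy: #{accuracies.sum / accuracies.size}"
```

Swapping in a real model only changes the body of the block; the fold bookkeeping stays in `cross_validate`.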

And your classifier’s code is totally decoupled from the cross validation function. I like it.

Conclusion

I found a gem on GitHub the other day unsurprisingly called cross validation. Its API is similar to scikit-learn’s, which I find particularly strange. Too object oriented for me.

This code isn’t a full-blown gem, and I don’t think there should be one just for cross validation. It would fit in a whole machine learning library, though, and I hope to build one based on NMatrix… eventually.
