These days I had some data mining problems in which I wanted to use Ruby instead of Python. One of the problems I faced is that I wanted to use k-fold cross validation, but couldn’t find a sufficiently simple gem (or gist or whatever) for it. I ended up creating my own version.
A review of k-fold cross validation
A common way to study the performance of a model is to partition a dataset into training and validation sets. Cross validation is a method to assess whether a model generalizes well independently of how this split is chosen.
The method of k-fold cross validation is to divide the dataset into k partitions, select one at a time for the validation set and use the other k – 1 partitions for training. So you end up with k different models, and respective performance measures against the validation sets.
The image below is an example of one fold: the k-th partition is left out for validation and partitions 1, …, k-1 are used for training.
E.g., if k = 5, you end up with 80% of the dataset for training and 20% for validation.
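That proportion is easy to verify with a quick sketch (the 100-element dataset here is just a made-up example):

```ruby
# With k = 5 folds over a hypothetical 100-element dataset, each validation
# fold holds 20 elements (20%) and training keeps the remaining 80 (80%).
k = 5
dataset = (1..100).to_a

fold_size = dataset.size / k          # 20 elements per fold
validation = dataset.first(fold_size) # one fold held out
training = dataset - validation       # the other k - 1 folds

puts "validation: #{validation.size}, training: #{training.size}"
```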
My solution is a function that receives the dataset, the number of partitions, and a block responsible for training and using the classifier in question. Most of it is straightforward; just keep in mind that the last partition (the (k-1)-th) should encompass all the remaining elements when dataset.size isn't divisible by k. Obviously, the training set is defined by the elements not in the validation set.
def cross_validate(dataset, k)
  partition_size = (dataset.size / k.to_f).floor
  partitions = (0 .. k - 1).map do |i|
    if i == k - 1
      (partition_size * i) .. dataset.size - 1
    else
      (partition_size * i) .. (partition_size * (i + 1) - 1)
    end
  end

  partitions.each do |partition|
    validation = dataset[partition]
    training = (partitions - [partition]).reduce([]) do |acc, part|
      acc + dataset[part]
    end

    # Let the classifier do its work...
    yield training, validation
  end
end
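To see the remainder handling in action, here is the range arithmetic in isolation, with a deliberately awkward size (10 elements, 3 folds, chosen just for illustration):

```ruby
# With 10 elements and k = 3, partition_size = (10 / 3.0).floor = 3, so the
# first two partitions cover 3 elements each and the last one absorbs the
# remainder, covering 4.
dataset_size = 10
k = 3
partition_size = (dataset_size / k.to_f).floor

partitions = (0 .. k - 1).map do |i|
  if i == k - 1
    (partition_size * i) .. dataset_size - 1
  else
    (partition_size * i) .. (partition_size * (i + 1) - 1)
  end
end

p partitions  # [0..2, 3..5, 6..9]
```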
Now suppose you have a CSV file called “dataset.csv” and you have a classifier you want to train. It’s as easy as:
require "csv"

DATASET = CSV.read("dataset.csv")

cross_validate(DATASET, 5) do |training, validation|
  # Train your classifier with `training`.
  # Calculate the performance of the classifier on `validation`.
end
And your classifier’s code is totally decoupled from the cross validation function. I like it.
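To show the whole thing end to end, here is a toy run with a fabricated labeled dataset and a deliberately dumb "classifier" (a majority-class baseline; both the data and the baseline are made up for illustration). The function is repeated so the snippet runs standalone:

```ruby
# Same cross_validate as above, repeated so this snippet is self-contained.
def cross_validate(dataset, k)
  partition_size = (dataset.size / k.to_f).floor
  partitions = (0 .. k - 1).map do |i|
    if i == k - 1
      (partition_size * i) .. dataset.size - 1
    else
      (partition_size * i) .. (partition_size * (i + 1) - 1)
    end
  end

  partitions.each do |partition|
    validation = dataset[partition]
    training = (partitions - [partition]).reduce([]) do |acc, part|
      acc + dataset[part]
    end
    yield training, validation
  end
end

# Fabricated labeled data: [feature, label] pairs, mostly :a, then :b.
dataset = Array.new(20) { |i| [i, i < 12 ? :a : :b] }

accuracies = []
cross_validate(dataset, 5) do |training, validation|
  # "Training" the baseline: pick the most common label in the training set.
  majority = training.map(&:last).tally.max_by { |_, count| count }.first
  correct = validation.count { |_, label| label == majority }
  accuracies << correct.to_f / validation.size
end

p accuracies
```

The block never needs to know how the folds were built, which is the decoupling mentioned above.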
This code isn’t a full-blown gem, and I don’t think there should be one just for cross validation. It fits in a whole machine learning library, though, and I hope to build one based on NMatrix… eventually.