Predicting the Future With Linear Regression in Ruby

The world is full of linear relationships. When one apple costs $1 and two apples cost $2, it's easy to figure out the price of any number of apples. But what happens when you have hundreds of data points? What if your data source is noisy? That's when it's helpful to use a technique called linear regression. In this article, Julie Kent shows us how linear regression works and walks through a practical example in Ruby.

Many choices that we make revolve around numerical relationships.

  • We eat certain foods because science says they lower our cholesterol
  • We further our education because we're likely to have an increased salary
  • We buy a house in the neighborhood we believe is going to appreciate in value the most

How do we come to these conclusions? Most likely, someone gathered a large amount of data and used it to form conclusions. One common technique is linear regression, which is a form of supervised learning. For more info on supervised learning and examples of what it is often used for, check out Part 1 of this series.

Linear Relationships

When two values — call them x and y — have a linear relationship, it means that changing x by 1 will always cause y to change by a fixed amount. It's easier to give examples:

  • 10 pizzas cost 10x the price of one pizza.
  • A 10-foot-tall wall needs twice as much paint as a 5-foot wall

Mathematically, this kind of relationship is described using the equation of a line:

y = mx + b
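Here, m is the slope (how much y changes when x goes up by 1) and b is the y-intercept (the value of y when x is 0). As a minimal sketch, the apple example from the introduction can be written in Ruby like this; the method name `line` is made up for illustration.

```ruby
# y = mx + b as a plain Ruby method.
# m is the slope, b is the y-intercept.
def line(x, m:, b:)
  m * x + b
end

# Apples cost $1 each and zero apples cost nothing, so m = 1 and b = 0.
puts line(2, m: 1, b: 0)   # price of 2 apples
puts line(50, m: 1, b: 0)  # price of 50 apples
```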

Math can be dreadfully confusing, but oftentimes it seems like magic to me. When I first learned the equation of a line, I remember thinking how beautiful it was to be able to calculate distance, slope, and other points on a line with just one formula.

But how do you get this formula, if all you have are data points? The answer is linear regression — a very popular machine learning tool.

An Example of Linear Regression

In this post, we are going to explore whether the beats per minute (BPM) in a song predicts its popularity on Spotify.

Linear regression models the relationship between two variables. One is called the "explanatory variable" and the other is called the "dependent variable."

In our example, we want to see if BPM can "explain" popularity. So BPM will be our explanatory variable. That makes popularity the dependent variable.

The model will utilize least-squares regression to find the best fitting line of the form, you guessed it, y = mx + b.

While there can be multiple explanatory variables, for this example we'll be conducting simple linear regression where there is just one.

Least-Squares What?

There are several ways to do linear regression. One of them is called "least-squares." It calculates the best fitting line by minimizing the sum of the squares of the vertical deviations from each data point to the line.

I know that sounds confusing, but it's basically just saying, "Build me a line that minimizes the amount of space between said line and the data points."

The reason for the squaring and summing is so there aren't any cancellations between positive and negative values.
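A small sketch can make the cancellation problem visible. The data points and the candidate line below are invented for illustration, and `residuals` is a hypothetical helper name.

```ruby
# Vertical deviation (residual) of each point from the line y = mx + b.
def residuals(points, m, b)
  points.map { |x, y| y - (m * x + b) }
end

# Invented data: three points and a candidate line y = 2x.
points = [[1, 3], [2, 3], [3, 6]]
errors = residuals(points, 2, 0)

puts "Residuals: #{errors.inspect}"
puts "Sum of residuals: #{errors.sum}"
puts "Sum of squared residuals: #{errors.sum { |e| e**2 }}"
```

The plain sum comes out to 0 even though the line misses two of the three points, because the +1 and -1 cancel. The squared sum comes out to 2, which does reflect the misses.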

[Image: leastsquares]

Here is an image I found on Quora that does a pretty good job of explaining it.
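With a single explanatory variable, the least-squares line has a well-known closed form: the slope is the sum of (x - mean of x)(y - mean of y) divided by the sum of (x - mean of x) squared, and the intercept is the mean of y minus the slope times the mean of x. The sketch below applies that textbook formula to invented data. It is not the code the gem runs, and `fit_line` is a hypothetical name.

```ruby
# Closed-form least-squares fit for one explanatory variable.
# Returns [m, b] for the line y = mx + b.
def fit_line(xs, ys)
  n = xs.size.to_f
  x_mean = xs.sum / n
  y_mean = ys.sum / n

  # Sum of products of deviations, and sum of squared x deviations.
  sxy = xs.zip(ys).sum { |x, y| (x - x_mean) * (y - y_mean) }
  sxx = xs.sum { |x| (x - x_mean)**2 }

  m = sxy / sxx
  b = y_mean - m * x_mean
  [m, b]
end

# Invented, slightly noisy data that roughly follows y = 2x.
m, b = fit_line([1, 2, 3, 4], [2.1, 3.9, 6.2, 7.8])
puts "slope: #{m.round(2)}, intercept: #{b.round(2)}"
```

Even though no single line passes through all four points, the fit recovers a slope close to 2.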

The Dataset

We will be using this dataset from Kaggle: https://www.kaggle.com/leonardopena/top50spotify2019 You can download it as a CSV.

The dataset has 16 columns; however, we only care about three — "Track Name," "Beats Per Minute," and "Popularity." One of the most important steps of machine learning is getting your data properly formatted, often referred to as "munging." You can delete all of the data except for the three aforementioned columns.
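If you would rather trim the columns with code than by hand, something like the following sketch could work. The helper name `keep_columns` and the sample rows are invented for illustration, and the header names are an assumption: the header row in the file you download may spell them differently, so check it first.

```ruby
require "csv"

# Keep only the named columns from CSV text and return new CSV text.
# The names in `wanted` must match the file's header row exactly.
def keep_columns(csv_text, wanted)
  table = CSV.parse(csv_text, headers: true)
  CSV.generate do |out|
    out << wanted
    table.each { |row| out << wanted.map { |name| row[name] } }
  end
end

# Invented sample rows standing in for the real file.
sample = <<~CSV
  Track Name,Artist Name,Beats Per Minute,Popularity
  Song A,Artist One,117,79
  Song B,Artist Two,105,92
CSV

puts keep_columns(sample, ["Track Name", "Beats Per Minute", "Popularity"])
```

For the real file you could pass in the result of `File.read` and write the returned text back out with `File.write`.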

Your CSV should look like this:

[Image: csv]

Using Ruby to do the Regression

In this example, we will be utilizing the ruby_linear_regression gem. To install, run:

gem install ruby_linear_regression

OK, we're ready to start coding! Create a new Ruby file and add these requires:

require "ruby_linear_regression"
require "csv"

Next, we read our CSV data and call #shift to discard the header row. Alternatively, you could just delete the first row from the CSV file.

csv = CSV.read("top50.csv")
csv.shift

Let's create two empty arrays to hold our x-data points and y-data points.

x_data = []
y_data = []

...and we iterate using the .each method to add the Beats Per Minute data to our x array and Popularity data to our y array.

If you're curious to see what is actually happening here, you can experiment by logging your row with either a puts or p. For example: puts row

csv.each do |row|
  x_data.push( [row[1].to_i] )
  y_data.push( row[2].to_i )
end
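To see the shape of the two arrays without the real file, here is a small sketch that runs the same conversion on invented rows. `split_columns` is a hypothetical helper that mirrors the loop above.

```ruby
require "csv"

# Mirrors the loop above: BPM goes into x_data, popularity into y_data.
def split_columns(rows)
  x_data = rows.map { |row| [row[1].to_i] } # one single-element array per song
  y_data = rows.map { |row| row[2].to_i }   # one plain integer per song
  [x_data, y_data]
end

# Invented rows in the same three-column layout: title, BPM, popularity.
sample = <<~CSV
  Track Name,Beats Per Minute,Popularity
  Song A,117,79
  Song B,105,92
CSV

rows = CSV.parse(sample)
rows.shift # discard the header row

x_data, y_data = split_columns(rows)
p x_data
p y_data
```

Each x value is wrapped in its own array, presumably so that a row can hold several explanatory variables at once. Discarding the header matters here because calling `to_i` on non-numeric text such as a column title returns 0, which would quietly add a bogus data point.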

Now it's time to use the ruby_linear_regression gem. We'll create a new instance of our regression model, load our data, and train our model:

linear_regression = RubyLinearRegression.new
linear_regression.load_training_data(x_data, y_data)
linear_regression.train_normal_equation
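The name train_normal_equation points at the normal equation, a closed-form least-squares solution usually written as theta = (X^T X)^-1 X^T y. The sketch below shows that general technique using Ruby's matrix library on invented data. It is an illustration of the math, not the gem's source, and `normal_equation` is a made-up name.

```ruby
require "matrix"

# Solve least squares with the normal equation: theta = (X^T X)^-1 X^T y.
# Returns the coefficients as an array: intercept first, then one per feature.
def normal_equation(x_rows, y_values)
  # Prepend a 1 to every row so the first coefficient acts as the intercept b.
  x = Matrix.rows(x_rows.map { |row| [1.0] + row.map(&:to_f) })
  y = Matrix.column_vector(y_values.map(&:to_f))

  theta = (x.transpose * x).inverse * x.transpose * y
  theta.column(0).to_a
end

# Invented data lying exactly on y = 2x + 1.
b, m = normal_equation([[1], [2], [3]], [3, 5, 7])
puts "slope: #{m.round(3)}, intercept: #{b.round(3)}"
```

The input uses the same shape as x_data and y_data above: a list of feature rows and a flat list of targets.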

Next, we'll print the mean square error (MSE) — a measure of the difference between the observed values and the predicted values. The difference is squared so that negative and positive values do not cancel each other out. We want to minimize the MSE because we do not want the distance between our predicted and actual values to be large.

puts "Trained model with the following cost fit #{linear_regression.compute_cost}"
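For reference, the textbook definition of mean squared error is short enough to write by hand. The numbers below are invented, and the gem's compute_cost may scale its result differently (for example by a constant factor), so treat this as a sketch of the idea rather than a reimplementation.

```ruby
# Mean squared error between observed values and predictions.
def mean_squared_error(actual, predicted)
  squared = actual.zip(predicted).map { |a, pred| (a - pred)**2 }
  squared.sum / squared.size.to_f
end

actual = [80, 90, 85]
puts mean_squared_error(actual, [82, 88, 85]) # predictions close to the data
puts mean_squared_error(actual, [70, 99, 95]) # predictions far from the data
```

The second set of predictions is further from the observed values, so its error is much larger.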

Finally, let's have the computer use our model to make a prediction. Specifically, how popular will a song with 250 BPM be? Feel free to play around with different values in the prediction_data array.

prediction_data = [250]
predicted_popularity = linear_regression.predict(prediction_data)
puts "Predicted popularity: #{predicted_popularity.round}"

Results

Let's run the program in our console and see what we get!

➜  ~ ruby spotify_regression.rb
Trained model with the following cost fit 9.504882197447587
Predicted popularity: 91

Cool! Let's change the "250" to "50" and see what our model predicts.

➜  ~ ruby spotify_regression.rb
Trained model with the following cost fit 9.504882197447587
Predicted popularity: 86

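The model predicts a slightly higher popularity for the faster song (91 at 250 BPM versus 86 at 50 BPM), so in this dataset songs with more beats per minute appear to be somewhat more popular.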
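The two runs above also give a rough sense of how strong the effect is. Since both predictions come from the same line, the slope can be estimated from them. The printed predictions were rounded, so this is only an approximation, and `slope_between` is a hypothetical helper.

```ruby
# Estimate the slope of a line from two points on it.
def slope_between(x1, y1, x2, y2)
  (y2 - y1) / (x2 - x1).to_f
end

# (50 BPM, 86) and (250 BPM, 91) are the rounded predictions from the runs above.
slope = slope_between(50, 86, 250, 91)
puts "Roughly #{slope} popularity points per extra BPM"
```

That works out to about 0.025 popularity points per additional beat per minute, so the relationship the model found is positive but small.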

Entire Program

Here's what my entire file looks like:

require 'csv'
require 'ruby_linear_regression'

x_data = []
y_data = []
csv = CSV.read("top50.csv")
csv.shift

# Load data from CSV file into two arrays -- one for independent variables X (x_data) and one for the dependent variable y (y_data)
# Row[0] = title
# Row[1] = BPM
# Row[2] = Popularity
csv.each do |row|
  x_data.push( [row[1].to_i] )
  y_data.push( row[2].to_i )
end

# Create regression model
linear_regression = RubyLinearRegression.new

# Load training data
linear_regression.load_training_data(x_data, y_data)

# Train the model using the normal equation
linear_regression.train_normal_equation

# Output the cost
puts "Trained model with the following cost fit #{linear_regression.compute_cost}"

# Predict the popularity of a song with 250 BPM
prediction_data = [250]
predicted_popularity = linear_regression.predict(prediction_data)
puts "Predicted popularity: #{predicted_popularity.round}"

Next Steps

This is a very simple example, but nevertheless, you've just run your first linear regression, which is a key technique used for machine learning. If you're yearning for more, here are a few other things you could do next:

  • Check out the source code for the Ruby gem we were using to see the math happening under the hood
  • Go back to the original data set and try adding additional variables to the model and run a multi-variable linear regression to see if that can reduce our MSE. For example, maybe "valence" (how positive the song is) also plays a role in popularity.
  • Try out a gradient descent model, which can also be run using the ruby_linear_regression gem.
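For a feel of what gradient descent does before trying the gem's version, here is a minimal plain-Ruby sketch. It repeatedly nudges m and b in the direction that lowers the mean squared error. The method name, the data, the learning rate, and the iteration count are all assumptions chosen for this tiny example, and none of it reflects the gem's API.

```ruby
# Fit y = mx + b by gradient descent on the mean squared error.
def gradient_descent(xs, ys, learning_rate: 0.05, iterations: 5_000)
  m = 0.0
  b = 0.0
  n = xs.size.to_f

  iterations.times do
    # Prediction error for every point under the current m and b.
    errors = xs.zip(ys).map { |x, y| (m * x + b) - y }

    # Partial derivatives of the mean squared error with respect to m and b.
    grad_m = (2 / n) * xs.zip(errors).sum { |x, e| x * e }
    grad_b = (2 / n) * errors.sum

    m -= learning_rate * grad_m
    b -= learning_rate * grad_b
  end

  [m, b]
end

# Invented data lying exactly on y = 2x + 1.
m, b = gradient_descent([1, 2, 3, 4], [3, 5, 7, 9])
puts "slope: #{m.round(3)}, intercept: #{b.round(3)}"
```

With inputs as large as BPM values, this learning rate would likely be too big and the updates would diverge, so you would need a much smaller rate or scaled features.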

    Julie Kent

    Julie is an engineer at Stitch Fix. In her free time, she likes reading, cooking, and walking her dog.
