Introduction to Machine Learning

How data from the past can build predictions for the future

What is machine learning?

Machine learning is a pretty cool thing. You feed a computer a dataset and it learns patterns from that data. Once it has learned from the dataset, it can predict the outcome for new data it has never seen. In general, more (and more representative) data tends to make those predictions more accurate.

For example, let's say I gave the computer data about 100 males and 100 females, described by their age, height, weight, and body mass index (BMI). I could then ask the computer whether a person aged 20, 5’3" tall, weighing 112 lbs, with a BMI of 20 is male or female. Based on the dataset it learned from, it would probably predict that this new person is female.

I’m going to walk you through an in-depth example of machine learning using the R programming language. You can download R and RStudio for free and follow along if you'd like.

Let’s Get Some Data

First, I’m going to download the Iris dataset from Kaggle. It contains features (descriptors, or columns) that describe individual iris plants, along with the species of each plant. We are going to use machine learning to let the computer predict a plant's species from those features.

Once the dataset is downloaded, open RStudio, change your working directory to the folder where you saved the CSV file, and read the file into R.

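# Check the current working directory, then change it to the folder holding Iris.csv
getwd()
setwd("/Users/SamuelCuster/Desktop/R")  # replace with your own download folder

# List files to verify that Iris.csv is in the current directory
list.files()

# Read the CSV file; stringsAsFactors=TRUE makes Species a factor (the default changed in R 4.0)
iris <- read.csv("Iris.csv", stringsAsFactors=TRUE)

# Drop the Id column: it is only a row number, ordered by species, and would leak the answer into the model
iris$Id <- NULL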
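If you'd rather not download anything, R also ships with its own copy of the iris measurements as `datasets::iris`. Its column names (`Petal.Width` rather than `PetalWidthCm`) and species labels (`setosa` rather than `Iris-setosa`) differ from the Kaggle CSV, so this minimal sketch renames them to match. It assumes the CSV columns used in this article and the same species order; it leaves out the CSV's `Id` column (only a row number), and the Kaggle copy may differ from R's in a few individual values.

```r
# Sketch: rebuild a data frame shaped like Kaggle's Iris.csv from R's built-in copy.
# Assumption: the CSV's measurement columns are SepalLengthCm, SepalWidthCm,
# PetalLengthCm, PetalWidthCm, and species are labelled "Iris-setosa", etc.
builtin <- datasets::iris

iris <- data.frame(
  SepalLengthCm = builtin$Sepal.Length,
  SepalWidthCm  = builtin$Sepal.Width,
  PetalLengthCm = builtin$Petal.Length,
  PetalWidthCm  = builtin$Petal.Width,
  # "setosa" -> "Iris-setosa", etc., stored as a factor
  Species       = factor(paste0("Iris-", builtin$Species))
)

str(iris)
```

Run it in place of the `setwd()` and `read.csv()` steps; the plotting and modeling code that follows uses the same column names.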

Next, I'm going to load the ggplot2 graphing library, which gives us easy-to-use plotting functions in R.

# Advanced Plotting Library (install.packages() only needs to run once)
install.packages("ggplot2")
library(ggplot2)

Now we want to learn the basics of our data. Which features are present? What are the mean, median, and range of each feature? How many plants of each species are there?

# Understand Structure of Data
head(iris,5)
str(iris)
summary(iris)
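`summary()` describes each column across the whole dataset. To answer the per-species questions directly, a minimal sketch using base R's `table()` and `aggregate()` counts the plants in each species and averages every measurement within each species. It runs on R's built-in `datasets::iris` so it is self-contained; the same two calls work on the Kaggle data frame, whose columns end in `Cm`.

```r
# Self-contained: uses R's built-in copy of the iris data
flowers <- datasets::iris

# How many plants of each species?
species_counts <- table(flowers$Species)
print(species_counts)

# Mean of every other column, computed separately for each species
species_means <- aggregate(. ~ Species, data = flowers, FUN = mean)
print(species_means)
```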

Now we are going to build a pie chart to visualize how many plants of each species the dataset contains.

# Visualize Data
## Pie
slices <- c(
  nrow(subset(iris,(Species == "Iris-setosa"))),
  nrow(subset(iris,(Species == "Iris-versicolor"))),
  nrow(subset(iris,(Species == "Iris-virginica"))))

slice.labels <- c("Iris-setosa","Iris-versicolor","Iris-virginica")
pie(slices,labels=slice.labels,main="Species of Iris")
Pie Chart Visualizing Plant Species Count Ratio
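The three `subset()` calls can be collapsed into a single `table()` call, which counts every species at once. A sketch on R's built-in copy of the data (species labels there are `setosa` rather than `Iris-setosa`) that also adds each slice's percentage to its label:

```r
# Count plants per species in one call instead of three subset() calls
flowers <- datasets::iris
slices <- table(flowers$Species)

# Percentage share of each species, rounded to one decimal place
pct <- round(100 * as.numeric(slices) / sum(slices), 1)
slice_labels <- paste0(names(slices), " (", pct, "%)")

pie(as.numeric(slices), labels = slice_labels, main = "Species of Iris")
```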

Now we are going to visualize our data and see whether we can discover any patterns by species. A scatter plot works well for this: we can color-code each species and visually inspect the points for groupings.

# Scatter Plot 1
p <- ggplot(iris, aes(x=PetalWidthCm, y=PetalLengthCm, color=Species)) +
  geom_point(size=2) +
  theme_light(base_size=16) +
  ggtitle("Petal Width vs. Petal Length")

ggsave("petalWidthvsLength.png", p, height=6, width=10, units="in")
Scatter Plot: Petal Length vs Petal Width, Colored by Species

We can clearly see that Iris-setosa (red) sits in its own group, well away from Iris-versicolor (green) and Iris-virginica (blue), while Iris-versicolor and Iris-virginica overlap slightly.
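To back up what the eye sees with numbers, a small sketch compares each species' smallest and largest petal length (on R's built-in copy of the data, where the column is `Petal.Length`). If one species' range ends before another's begins, a single threshold on that measurement separates them; overlapping ranges mean it cannot.

```r
flowers <- datasets::iris

# Smallest and largest petal length observed for each species
min_len <- tapply(flowers$Petal.Length, flowers$Species, min)
max_len <- tapply(flowers$Petal.Length, flowers$Species, max)
print(rbind(min = min_len, max = max_len))

# Setosa is isolated if its longest petal is shorter than every versicolor petal
setosa_isolated <- max_len["setosa"] < min_len["versicolor"]
# Versicolor and virginica overlap if their ranges intersect
versicolor_virginica_overlap <- max_len["versicolor"] >= min_len["virginica"]
```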

We can gather new insights by comparing different features of our dataset.

# Scatter Plot 2
p <- ggplot(iris, aes(x=SepalWidthCm, y=SepalLengthCm, color=Species)) +
  geom_point(size=2) +
  theme_light(base_size=16) +
  ggtitle("Sepal Width vs. Sepal Length")

ggsave("sepalWidthvsLength.png", p, height=6, width=10, units="in")
Scatter Plot: Sepal Length vs Sepal Width, Colored by Species

# Scatter Plot 3
p <- ggplot(iris, aes(x=PetalWidthCm, y=SepalWidthCm, color=Species)) +
  geom_point(size=2) +
  theme_light(base_size=16) +
  ggtitle("Petal Width vs. Sepal Width")

ggsave("petalWidthvsSepalWidth.png", p, height=6, width=10, units="in")
Scatter Plot: Petal Width vs Sepal Width, Colored by Species

Plotting other pairs of features did not reveal much more, but it does show that if we had started with the second visualization (sepal width vs. sepal length), we would not have seen the full picture: in that plot, Iris-versicolor and Iris-virginica overlap far more than they do in the petal plot.
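Rather than plotting feature pairs one at a time, base R's `pairs()` draws every pairwise scatter plot in a single grid, which makes it harder to miss an informative pair. A minimal sketch using the built-in `datasets::iris`; the output file name is just an example, and the colors are picked by hand rather than matching ggplot2's defaults.

```r
flowers <- datasets::iris

# One color per species, in factor-level order: setosa, versicolor, virginica
species_colors <- c("red", "darkgreen", "blue")[as.integer(flowers$Species)]

# Write every pairwise scatter plot of the four measurements to one PDF
pdf("allFeaturePairs.pdf", width = 9, height = 9)
pairs(flowers[, 1:4], col = species_colors, pch = 19,
      main = "Iris measurements, colored by species")
dev.off()
```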

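Now we are going to split the data into training data and testing data. The training data is what the computer learns from; the testing data is held back so we can check how well the model predicts plants it has never seen. In general, more training data can lead to more accurate predictions. Because the CSV lists the species in order (rows 1-50, 51-100, and 101-150), taking the first 25 rows of each species for training and the last 25 for testing gives both sets an equal mix of species.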

# Isolate Training and Test Data
trainer <- rbind(iris[1:25,],iris[51:75,],iris[101:125,])
test <- rbind(iris[26:50,],iris[76:100,],iris[126:150,])
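The fixed row ranges only work because the CSV happens to be sorted by species. A sketch of a more robust alternative, assuming you want a random 50/50 split that still keeps 25 plants of each species in each set (a stratified split). It uses the built-in `datasets::iris`; `set.seed()` makes the random draw repeatable, and the seed value itself is arbitrary.

```r
flowers <- datasets::iris
set.seed(42)  # arbitrary seed so the split is reproducible

# Group row numbers by species, then draw half of each group for training
rows_by_species <- split(seq_len(nrow(flowers)), flowers$Species)
train_idx <- unlist(lapply(rows_by_species,
                           function(rows) sample(rows, length(rows) / 2)))

trainer <- flowers[train_idx, ]
test    <- flowers[-train_idx, ]

table(trainer$Species)  # 25 of each species
```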

Now we train the model using the training data. The MASS library saves us some time: its lda() function implements linear discriminant analysis (LDA), which builds a model from the training data that we can then use to classify the test data.

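# Train Classifier
library(MASS)  # ships with R; provides lda()
# Fit a linear discriminant analysis model that predicts Species from all other columns
fit <- lda(Species~., data=trainer)
# Print the fit: prior probabilities, group means, and discriminant coefficients
fit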
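LDA treats each species as a cloud of points sharing a common covariance and finds linear combinations of the features (discriminants) that best separate the species' means relative to the spread within each species. A sketch that fits the same kind of model on the built-in data, using the article's first-25-rows-per-species split, and pulls out the pieces `lda()` stores in the fitted object:

```r
library(MASS)  # provides lda()

flowers <- datasets::iris
train_rows <- c(1:25, 51:75, 101:125)  # first 25 plants of each species
fit <- lda(Species ~ ., data = flowers[train_rows, ])

fit$prior    # class proportions in the training data (1/3 each here)
fit$means    # average of each measurement within each species
fit$scaling  # coefficients of the linear discriminants LD1 and LD2

# Project the training plants onto the two discriminant axes
projected <- predict(fit, newdata = flowers[train_rows, ])$x
```

With three species, LDA yields at most two discriminants, which is why `scaling` has two columns.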

Now we can predict the test data and then compare our predictions to the true results.

# Predict Test Data (columns the model doesn't use, such as Species, are ignored)
predictions <- predict(fit, newdata=test)

# Compare Predictions with Real Answers in a Table
table(predictions$class, test$Species)

# Compare Predictions with Real Answers Individually
test$Species
predictions$class

#install.packages("compare")
library(compare)
# Compare Predictions with Real Answers as True or False (Entirely)
compare(test$Species,predictions$class)
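The prediction table is a confusion matrix: correct predictions sit on its diagonal. A sketch that turns it into a single accuracy figure, reproducing the article's split and LDA fit on the built-in copy of the data so it runs on its own (results on the Kaggle CSV may differ slightly):

```r
library(MASS)

flowers <- datasets::iris
train_rows <- c(1:25, 51:75, 101:125)
trainer <- flowers[train_rows, ]
test    <- flowers[-train_rows, ]

fit <- lda(Species ~ ., data = trainer)
predictions <- predict(fit, newdata = test)

# Rows: predicted species; columns: actual species
confusion <- table(Predicted = predictions$class, Actual = test$Species)
print(confusion)

# Share of test plants on the diagonal, i.e. classified correctly
accuracy <- sum(diag(confusion)) / sum(confusion)
print(accuracy)
```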

How Does This Apply to UX?

This means we can make predictions about our users' behavior based on past behavior. We may be able to tell when a user is frustrated with our product without them having to tell us, and when we should offer them assistance. We can suggest things users might be interested in based on what other users have done in the past. We can also use machine learning to estimate the best time to make an offer to a customer. Companies like eBay, Walmart, and Amazon use machine learning in their web and mobile services, and it is also used in services like Optimizely for A/B testing.

If you're interested in applying machine learning to your online services, check out the Limdu JavaScript library!