This notebook was created by Jean de Dieu Nyandwi for the love of the machine learning community. For any feedback, errors, or suggestions, he can be reached by email (johnjw7084 at gmail dot com), Twitter, or LinkedIn.
Data Visualization with Seaborn¶
Seaborn is a fantastic and easy-to-use Python visualization library built on top of Matplotlib.
For a quick look, check out the gallery page.
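To be covered:
- Relational Plots
- Distribution Plots
- Categorical Plots
- Regression Plots
- Multiplots
- Matrix Plots: Heat Maps and Cluster Maps
- Styles, Themes and Colors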
In this lab, we will use real world datasets, which are already part of Seaborn.
1. Relational Plots¶
These kinds of plots are used to analyze the relationship between features.
- Scatter Plots
- Line Plots
Imports¶
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Loading the datasets to be used in this lab
titanic = sns.load_dataset('titanic')
fmri = sns.load_dataset('fmri')
tips = sns.load_dataset('tips')
flights = sns.load_dataset('flights')
titanic.head()
survived | pclass | sex | age | sibsp | parch | fare | embarked | class | who | adult_male | deck | embark_town | alive | alone | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | Third | man | True | NaN | Southampton | no | False |
1 | 1 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | First | woman | False | C | Cherbourg | yes | False |
2 | 1 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | Third | woman | False | NaN | Southampton | yes | True |
3 | 1 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | First | woman | False | C | Southampton | yes | False |
4 | 0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | Third | man | True | NaN | Southampton | no | True |
fmri.head(3)
subject | timepoint | event | region | signal | |
---|---|---|---|---|---|
0 | s13 | 18 | stim | parietal | -0.017552 |
1 | s5 | 14 | stim | parietal | -0.080883 |
2 | s12 | 18 | stim | parietal | -0.081033 |
tips.head(3)
total_bill | tip | sex | smoker | day | time | size | |
---|---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | Female | No | Sun | Dinner | 2 |
1 | 10.34 | 1.66 | Male | No | Sun | Dinner | 3 |
2 | 21.01 | 3.50 | Male | No | Sun | Dinner | 3 |
flights.head(3)
year | month | passengers | |
---|---|---|---|
0 | 1949 | Jan | 112 |
1 | 1949 | Feb | 118 |
2 | 1949 | Mar | 132 |
Scatter Plots¶
To visualize the relationship between two numeric features, a scatter plot is usually the go-to choice.
We will use sns.scatterplot(data, x, y, hue, style, palette, size, sizes, legend, markers...) and also sns.relplot().
sns.scatterplot(data=titanic, x='age', y='fare')
<AxesSubplot:xlabel='age', ylabel='fare'>
With the hue parameter, we can map another feature to the plot.
sns.scatterplot(data=titanic, x='age', y='fare', hue='sex')
<AxesSubplot:xlabel='age', ylabel='fare'>
You can see that it makes the plot clearer. In this titanic dataset, you can directly see that women tended to pay higher fares than men.
To further highlight the difference between the hue classes, we can add a marker style as follows.
sns.scatterplot(data=titanic, x='age', y='fare', hue='sex', style='sex')
<AxesSubplot:xlabel='age', ylabel='fare'>
Increasing the figure size...
plt.figure(figsize=(8,6))
sns.scatterplot(data=titanic, x='age', y='fare', hue='sex', style='pclass')
<AxesSubplot:xlabel='age', ylabel='fare'>
You can also use specific markers with the style and markers parameters.
plt.figure(figsize=(8,6))
markers = {1:'P', 2:'X', 3:'D'} # P, X, and D are markers
sns.scatterplot(data=titanic, x='age', y='fare', hue='sex', style='pclass', markers=markers)
<AxesSubplot:xlabel='age', ylabel='fare'>
Mapping the same feature to both hue and size makes the plot more meaningful. sizes controls the range of marker areas used for size.
sns.scatterplot(data=titanic, x='age', y='fare', hue='sex', size='sex', sizes=(20,200))
<AxesSubplot:xlabel='age', ylabel='fare'>
sns.scatterplot(data=titanic, x='age', y='fare', hue='sex', size='pclass', sizes=(20,200))
<AxesSubplot:xlabel='age', ylabel='fare'>
Line Plots¶
Line plots are used to analyze how one variable changes with respect to another variable, such as time. For example, say you want to see how revenue changes from day to day.
Data visualization is an art. We will see more visualization types, but no single type fits every case. It is always good to consider the audience and the goal of the analysis.
To make a line plot in Seaborn, we will use sns.relplot() with the kind parameter set to line. You can make a scatter plot by changing kind to scatter.
sns.relplot(data=flights, x='month', y='passengers', kind='line')
<seaborn.axisgrid.FacetGrid at 0x7f9b3d01c2d0>
sns.relplot(data=flights, x='year', y='passengers', kind='line')
<seaborn.axisgrid.FacetGrid at 0x7f9b1ea225d0>
You can remove the confidence intervals by setting the parameter ci=None. Setting ci='sd' will instead show the standard deviation, which can sometimes be useful.
sns.set_theme(style="darkgrid") #setting the grid for all next plots
sns.relplot(data=flights, x='month', y='passengers', ci=None, kind='line')
<seaborn.axisgrid.FacetGrid at 0x7f9b3d11bed0>
sns.relplot(data=flights, x='month', y='passengers', ci='sd', kind='line')
<seaborn.axisgrid.FacetGrid at 0x7f9b3d1ab790>
By default, relplot with kind='line' aggregates the y values that share the same x value (the mean is the default estimator). You can turn this off by setting estimator=None.
sns.relplot(data=flights, x='month', y='passengers', ci=None, estimator=None, kind='line')
<seaborn.axisgrid.FacetGrid at 0x7f9b3d1f7910>
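To see what this aggregation amounts to, here is a minimal pandas sketch. It uses a tiny hand-made stand-in for the flights data (`flights_small` and its numbers are illustrative, not the real dataset). Each month appears once per year, so a line plot over month has to reduce several values to one; the mean and a standard-deviation band are computed by hand below. Note that in Seaborn 0.12 and later the ci parameter is deprecated in favor of errorbar (for example errorbar='sd' or errorbar=None), and the exact standard-deviation convention may differ between versions.

```python
import pandas as pd

# Tiny illustrative stand-in for the flights data (not the real dataset).
flights_small = pd.DataFrame({
    "year": [1949, 1949, 1949, 1950, 1950, 1950],
    "month": ["Jan", "Feb", "Mar", "Jan", "Feb", "Mar"],
    "passengers": [112, 118, 132, 116, 126, 142],
})

# Several years share each month, so the values are reduced to one
# number per month; the mean plays the role of the default estimator.
summary = (
    flights_small
    .groupby("month", sort=False)["passengers"]
    .agg(["mean", "std"])
)

# A standard-deviation band around the mean, similar in spirit to ci='sd'.
summary["lower"] = summary["mean"] - summary["std"]
summary["upper"] = summary["mean"] + summary["std"]
print(summary)
```

With estimator=None nothing is reduced: every (month, passengers) observation is drawn as is.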
You can add multiple lines....
sns.relplot(data=fmri, x='timepoint', y='signal',
hue='event', kind='line')
<seaborn.axisgrid.FacetGrid at 0x7f9b1eb0b110>
sns.relplot(data=fmri, x='timepoint', y='signal',
hue='event', kind='line', style='event')
<seaborn.axisgrid.FacetGrid at 0x7f9adb4bd110>
sns.relplot(data=fmri, x='timepoint', y='signal', hue='event', kind='line',
style='event', dashes=False, markers=True)
<seaborn.axisgrid.FacetGrid at 0x7f9ac927e0d0>
Showing multiple relationships
sns.relplot(x='total_bill', y='tip', hue='smoker',
col='time', data=tips);
sns.relplot(x='timepoint', y='signal', hue='subject',
col='region', row='event', height=4,
kind='line', estimator=None, data=fmri)
<seaborn.axisgrid.FacetGrid at 0x7f9ac9285ed0>
2. Distribution Plots¶
These kinds of plots are used to visualize the distribution of the features. Understanding how your data is distributed can help you determine the range of values, their central tendency, whether they are skewed in one direction, and whether there are outliers.
Distribution plot functions:
- displot()
- jointplot()
- pairplot()
- rugplot()
- kdeplot()
Plotting Histograms with displot() and histplot()¶
A histogram is a bar plot whose x-axis is a variable or feature and whose y-axis is the count of values of that variable falling in each bin. By default, displot() plots a histogram.
sns.displot(titanic['age'])
<seaborn.axisgrid.FacetGrid at 0x7f9adb9d1d50>
By default, displot() does not overlay a Kernel Density Estimate (KDE) curve (kde=False), so passing kde=False below is explicit rather than required; set kde=True to add the curve. A histogram displays data by grouping it into bins. We can set bins to a value of our choice.
sns.displot(titanic['age'], kde=False, bins=10)
<seaborn.axisgrid.FacetGrid at 0x7f9adbde86d0>
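To make the idea of bins concrete, the sketch below bins a handful of ages with NumPy. The `ages` array is an illustrative sample (the first five values match the titanic rows shown earlier; the rest are made up for the example), and np.histogram is used here only to show the counting, not as a claim about how Seaborn computes its bins internally.

```python
import numpy as np

# Illustrative ages, including one missing value.
ages = np.array([22.0, 38.0, 26.0, 35.0, 35.0, np.nan, 54.0, 2.0])
ages = ages[~np.isnan(ages)]  # drop missing ages before binning

# bins=10 splits the range [min, max] into 10 equal-width intervals.
counts, edges = np.histogram(ages, bins=10)

# Each bin is half-open [left, right), except the last one, which also
# includes its right edge.
for left, right, count in zip(edges[:-1], edges[1:], counts):
    print(f"{left:5.1f} to {right:5.1f}: {count}")
```

The bar heights in a count histogram are exactly these per-bin counts.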
We can also plot categorical data on histograms.
sns.histplot(tips, x='day', shrink=.8)
<AxesSubplot:xlabel='day', ylabel='Count'>
sns.displot(titanic, x='age',hue='pclass')
<seaborn.axisgrid.FacetGrid at 0x7f9adb77fcd0>
sns.histplot(titanic, x="age",hue="sex")
<AxesSubplot:xlabel='age', ylabel='Count'>
sns.histplot(titanic, x="age",hue="survived")
<AxesSubplot:xlabel='age', ylabel='Count'>
We can also make stacked bars by setting the parameter multiple='stack'.
sns.histplot(titanic, x='age',hue='pclass', multiple='stack')
<AxesSubplot:xlabel='age', ylabel='Count'>
Plotting Bivariate Data with Jointplot()¶
We use jointplot() to plot two variables with bivariate and univariate graphs. The kind parameter accepts the following options: scatter, kde, hist, hex, reg, resid. The default kind is scatter.
sns.jointplot(data=tips, x='total_bill', y='tip')
<seaborn.axisgrid.JointGrid at 0x7f9ac9bab850>
sns.jointplot(data=tips, x='total_bill', y='tip', kind='hex')
<seaborn.axisgrid.JointGrid at 0x7f9b1f29a490>
sns.jointplot(data=tips, x='total_bill', y='tip', kind='kde', hue='day')
<seaborn.axisgrid.JointGrid at 0x7f9b1f2a9210>
sns.jointplot(data=tips, x='total_bill', y='tip', kind='reg')
<seaborn.axisgrid.JointGrid at 0x7f9ac855d250>
sns.jointplot(data=tips, x='total_bill', y='tip', kind='resid')
<seaborn.axisgrid.JointGrid at 0x7f9adaf10dd0>
Plotting Many Distribution with pairplot()¶
pairplot() visualizes the pairwise relationships between the numeric variables in a dataset, with each variable's own distribution on the diagonal. It is a handy plot: you can immediately see relationships between features.
sns.pairplot(titanic)
<string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility.
<seaborn.axisgrid.PairGrid at 0x7f9b3d38cfd0>
sns.pairplot(tips)
<seaborn.axisgrid.PairGrid at 0x7f9aca0fc410>
Plotting Distributions with rugplot()¶
rugplot() plots marginal distributions by drawing ticks along the x and y axes, one tick per observation.
sns.rugplot(data=tips, x='total_bill', y='tip')
<AxesSubplot:xlabel='total_bill', ylabel='tip'>
Combining relplot() and rugplot() in one figure...
sns.relplot(data=tips, x='total_bill', y='tip')
sns.rugplot(data=tips, x='total_bill', y='tip')
<AxesSubplot:xlabel='total_bill', ylabel='tip'>
Kernel Density Estimation (KDE) Plot with kdeplot() and displot()¶
We can visualize the probability density of a variable. Unlike a histogram, which shows counts in discrete bins, a KDE plot produces a smooth estimate by using a Gaussian kernel.
sns.kdeplot(data=tips, x='tip')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.kdeplot(data=tips, x='tip', hue='day')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.displot(titanic,x='age', kind='kde', hue='sex')
<seaborn.axisgrid.FacetGrid at 0x7f9b3df93b10>
sns.kdeplot(data=tips, x='tip', hue='day', multiple='stack')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.displot(data=tips, x='tip', hue='day', kind='kde', fill=True)
<seaborn.axisgrid.FacetGrid at 0x7f9b1e106dd0>
Whether you use displot() or kdeplot(), you can see they are very handy for visualizing density distributions.
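For intuition about what a KDE curve is, here is a minimal NumPy sketch of a one-dimensional Gaussian kernel density estimate. The helper `gaussian_kde_1d`, the sample values, and the bandwidth rule (Scott's rule) are assumptions made for this illustration; Seaborn has its own defaults and tuning parameters such as bw_adjust.

```python
import numpy as np

def gaussian_kde_1d(samples, grid, bandwidth=None):
    """Average of Gaussian bumps centered on each sample (hypothetical helper)."""
    samples = np.asarray(samples, dtype=float)
    n = samples.size
    if bandwidth is None:
        # Scott's rule of thumb: an assumption chosen for this sketch.
        bandwidth = samples.std(ddof=1) * n ** (-1 / 5)
    # Standardized distance from every grid point to every sample.
    z = (grid[:, None] - samples[None, :]) / bandwidth
    kernels = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    return kernels.sum(axis=1) / (n * bandwidth)

# Illustrative tip amounts (not the full tips dataset).
tip_values = np.array([1.01, 1.66, 3.50, 3.31, 3.61, 4.71, 2.00, 3.12])
grid = np.linspace(-5, 12, 2001)
density = gaussian_kde_1d(tip_values, grid)

# A density should integrate to roughly 1 over a wide enough grid.
area = density.sum() * (grid[1] - grid[0])
print(round(area, 3))
```

A smaller bandwidth gives a wigglier curve that follows individual points; a larger one gives a smoother curve.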
Cumulative Distributions¶
There are times when we would like to visualize cumulative distributions. By setting the kind parameter in displot() to ecdf, we can plot the cumulative (non-decreasing) curve of a univariate variable.
ecdf stands for empirical cumulative distribution function.
sns.displot(titanic, x='age', kind='ecdf')
<seaborn.axisgrid.FacetGrid at 0x7f9b3dfb4b50>
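The empirical cumulative distribution function is simple enough to compute by hand: sort the values, then give each one the proportion of observations less than or equal to it. The sketch below does this with NumPy; the helper name `ecdf` and the small `ages` sample are illustrative, and missing values are dropped first as an assumption of this example.

```python
import numpy as np

def ecdf(values):
    """Return sorted values and their cumulative proportions (hypothetical helper)."""
    x = np.asarray(values, dtype=float)
    x = np.sort(x[~np.isnan(x)])           # drop missing values, then sort
    y = np.arange(1, x.size + 1) / x.size  # proportion of points <= x
    return x, y

ages = np.array([22.0, 38.0, 26.0, 35.0, 35.0, np.nan])
x, y = ecdf(ages)
print(x)
print(y)
```

Plotting y against x as a step curve gives the same kind of picture as kind='ecdf'.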
3. Categorical Plots¶
Categorical plots are used to visualize the categorical data.
In Seaborn, there are various plot functions that we are going to see:
- Categorical estimate plots
- barplot()
- countplot()
- pointplot()
- Categorical distribution plots
- boxplot()
- boxenplot()
- violinplot()
- Categorical scatter plots
- stripplot()
- swarmplot()
Just like we saw with distribution plots, Seaborn also provides a high-level function, catplot(), to plot all of the types above. You just have to pass the kind parameter. In the order in which the functions are listed above, the kind values are: bar, count, point, box, boxen, violin, strip, swarm.
We will use both the high-level function catplot() and the specific categorical functions along the way.
Categorical estimate plots¶
Barplot() and Countplot()¶
A bar plot is used to visualize aggregated categorical data based on an estimation function (mean being the default). We can either use barplot() or catplot(..., kind='bar').
A count plot is used to visualize the number of observations in each category. It is like a histogram for categorical data.
sns.catplot(data=titanic, x='sex', y='survived', hue='pclass', kind='bar')
<seaborn.axisgrid.FacetGrid at 0x7f9b3e12df10>
sns.barplot(data=tips, x='sex', y='total_bill', palette='rocket', hue='day')
<AxesSubplot:xlabel='sex', ylabel='total_bill'>
sns.countplot(data=tips, x='day', palette='coolwarm')
#sns.catplot(data=tips, x='day', kind='count', palette='coolwarm') will do the same
<AxesSubplot:xlabel='day', ylabel='count'>
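The difference between the two plots can be sketched with plain pandas: bar heights are an aggregate (the mean by default) of a numeric column per category, while count heights are just the number of rows per category. `tips_small` below is a hand-made stand-in with illustrative values, not the real tips dataset.

```python
import pandas as pd

# Small illustrative stand-in for the tips data.
tips_small = pd.DataFrame({
    "sex": ["Female", "Male", "Male", "Male", "Female"],
    "day": ["Sun", "Sun", "Sun", "Sat", "Sat"],
    "total_bill": [16.99, 10.34, 21.01, 23.68, 24.59],
})

# What a bar plot of x='sex', y='total_bill' summarizes: mean bill per sex.
bar_heights = tips_small.groupby("sex")["total_bill"].mean()

# What a count plot of x='day' shows: number of rows per day.
count_heights = tips_small["day"].value_counts()

print(bar_heights)
print(count_heights)
```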
pointplot()¶
Rather than plotting bars, pointplot() plots a point estimate for each category. You may also notice that it connects the points that belong to the same level of the variable specified in hue.
sns.pointplot(data=titanic, x='sex', y='survived', hue='class', palette='dark')
<AxesSubplot:xlabel='sex', ylabel='survived'>
Categorical distribution plots¶
boxplot() and boxenplot(), violinplot()¶
Box, boxen, and violin plots are all used to plot the distribution of a numeric variable across categories.
"A box plot (or box-and-whisker plot) shows the distribution of quantitative data in a way that facilitates comparisons between variables or across levels of a categorical variable. The box shows the quartiles of the dataset while the whiskers extend to show the rest of the distribution, except for points that are determined to be “outliers” using a method that is a function of the inter-quartile range."
sns.boxplot(data=tips, x='day', y='total_bill')
<AxesSubplot:xlabel='day', ylabel='total_bill'>
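The quoted description can be turned into a few lines of NumPy. The sketch below computes the quartiles, the inter-quartile range (IQR), and the whisker limits using the common 1.5 × IQR rule, which is the default whis value in Matplotlib and Seaborn box plots. The `bills` values are made up for illustration.

```python
import numpy as np

# Made-up bill amounts with one obviously extreme value.
bills = np.array([10.0, 12.0, 14.0, 15.0, 16.0, 18.0, 20.0, 50.0])

# The box spans the first to the third quartile, with a line at the median.
q1, median, q3 = np.percentile(bills, [25, 50, 75])
iqr = q3 - q1

# Points beyond 1.5 * IQR from the box are treated as outliers.
lower_fence = q1 - 1.5 * iqr
upper_fence = q3 + 1.5 * iqr

# Whiskers reach the most extreme points that are still inside the fences.
inside = bills[(bills >= lower_fence) & (bills <= upper_fence)]
whisker_low, whisker_high = inside.min(), inside.max()
outliers = bills[(bills < lower_fence) | (bills > upper_fence)]

print(q1, median, q3, whisker_low, whisker_high, outliers)
```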
sns.catplot(data=tips, x='day', y='total_bill', hue='smoker', kind='box')
<seaborn.axisgrid.FacetGrid at 0x7f9b1fe1d810>
sns.catplot(data=tips, x='day', y='total_bill', hue='smoker', kind='boxen')
<seaborn.axisgrid.FacetGrid at 0x7f9addbe12d0>
A violin plot also shows the distribution of the data across categories, but it features a kernel density estimate of the underlying distribution.
sns.violinplot(data=tips, x='day', y='total_bill', palette='autumn')
<AxesSubplot:xlabel='day', ylabel='total_bill'>
# The split parameter saves space when hue has two levels
sns.catplot(data=tips, x='total_bill', y='day',
hue='sex', kind='violin', split=True)
<seaborn.axisgrid.FacetGrid at 0x7f9b1fe26750>
# strip is the default kind parameter when using catplot()
sns.catplot(data=tips, x='day', y='total_bill')
<seaborn.axisgrid.FacetGrid at 0x7f9b3e491350>
sns.stripplot(data=tips, x='day', y='total_bill', jitter=False)
<AxesSubplot:xlabel='day', ylabel='total_bill'>
sns.swarmplot(data=tips, x='day', y='total_bill')
<AxesSubplot:xlabel='day', ylabel='total_bill'>
sns.catplot(data=tips, x='total_bill', y='day', hue='sex', kind='swarm')
<seaborn.axisgrid.FacetGrid at 0x7f9b1fedb310>
Plotting Multiple Categorical plots¶
Using FacetGrid (a multi-plot grid for plotting conditional relationships), we can draw multiple plots with catplot(....)
sns.catplot(data=tips, x='day', y='total_bill', hue='smoker',
col='time', kind='swarm')
<seaborn.axisgrid.FacetGrid at 0x7f9aca8d36d0>
That's it for categorical plots.
4. Regression Plots¶
Seaborn takes data visualization a step further: not only can you plot features, but you can also plot the linear relationship between two variables (a linear model).
Although regplot() can be used for this purpose, we will use lmplot().
sns.lmplot(data=tips, x='total_bill', y='tip')
<seaborn.axisgrid.FacetGrid at 0x7f9aca912b50>
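The straight line in this plot is an ordinary least-squares fit of tip on total_bill. As a rough illustration of what is being fitted, the sketch below computes the same kind of line with np.polyfit on the first five tips rows shown earlier; with so few points the numbers will of course differ from the fit on the full dataset.

```python
import numpy as np

# First five rows of the tips data, as displayed by tips.head() above.
total_bill = np.array([16.99, 10.34, 21.01, 23.68, 24.59])
tip = np.array([1.01, 1.66, 3.50, 3.31, 3.61])

# Degree-1 polynomial fit = least-squares straight line.
slope, intercept = np.polyfit(total_bill, tip, deg=1)

predicted = slope * total_bill + intercept
residuals = tip - predicted  # what kind='resid' in jointplot() visualizes

print(f"tip is approximately {slope:.3f} * total_bill + {intercept:.3f}")
```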
sns.lmplot(data=tips, x='total_bill', y='tip', hue='sex')
<seaborn.axisgrid.FacetGrid at 0x7f9b3e667a10>
tips['big_tip'] = (tips.tip / tips.total_bill) >.15
tips.head()
total_bill | tip | sex | smoker | day | time | size | big_tip | |
---|---|---|---|---|---|---|---|---|
0 | 16.99 | 1.01 | Female | No | Sun | Dinner | 2 | False |
1 | 10.34 | 1.66 | Male | No | Sun | Dinner | 3 | True |
2 | 21.01 | 3.50 | Male | No | Sun | Dinner | 3 | True |
3 | 23.68 | 3.31 | Male | No | Sun | Dinner | 2 | False |
4 | 24.59 | 3.61 | Female | No | Sun | Dinner | 4 | False |
If we have one binary variable, we can also do logistic regression.
sns.lmplot(x='age', y='survived', data=titanic, logistic=True, y_jitter=0.03)
<seaborn.axisgrid.FacetGrid at 0x7f9b3e76db10>
sns.lmplot(x='age', y='alone', data=titanic, logistic=True, y_jitter=0.03, hue='sex')
<seaborn.axisgrid.FacetGrid at 0x7f9aca9ca050>
We can also use markers to make the plot clearer.
sns.lmplot(x='age', y='survived', data=titanic, hue='sex', logistic=True,
y_jitter=0.03, markers=['s','D'])
<seaborn.axisgrid.FacetGrid at 0x7f9acaa124d0>
Multiple Plots¶
Like we did before, we can make multiple plots by providing another variable to col and row.
sns.lmplot(data=tips, x='total_bill', y='tip', hue='smoker', col='time')
<seaborn.axisgrid.FacetGrid at 0x7f9acaa00850>
sns.lmplot(data=tips, x='total_bill', y='tip', hue='smoker', col='time', row='sex')
<seaborn.axisgrid.FacetGrid at 0x7f9ade5e18d0>
One last thing about regression plots in Seaborn: we can also use jointplot() and pairplot() with the kind parameter set to reg.
sns.jointplot(data=tips, x='total_bill', y='tip', kind='reg')
<seaborn.axisgrid.JointGrid at 0x7f9b3e84cd90>
sns.pairplot(tips, x_vars=['total_bill', 'size'], y_vars=['tip'],
hue='time', kind='reg', height=5)
<seaborn.axisgrid.PairGrid at 0x7f9acad1c0d0>
That's it for regression plots!!
5. Multiplots¶
Multi-plot functions are used to visualize multiple features on multiple axes.
- Facet Grid
- PairGrid
- Pair Plot
We have already plotted multiple features in previous sections, but this section focuses on these grid plots.
FacetGrid()¶
FacetGrid() is used to create a grid of multiple plots. It allows us to plot the variables on row and column axes, and we can also use the hue parameter to make the visual clearer based on a given feature.
What's interesting about FacetGrid is that you can choose whether you plot on the row axis, the column axis, or both.
sns.FacetGrid() only creates the grid. To add visualizations, we need to map a given plot type (scatter, histogram, bar....) onto it.
plot = sns.FacetGrid(tips, col='sex', hue='smoker')
plot.map(sns.scatterplot,'total_bill', 'tip')
plot.add_legend()
<seaborn.axisgrid.FacetGrid at 0x7f9acad22c90>
plot = sns.FacetGrid(tips, col='day', height=5, aspect=.5)
plot.map(sns.barplot, 'sex', 'total_bill');
/Users/jean/opt/miniconda3/envs/tensor/lib/python3.7/site-packages/seaborn/axisgrid.py:643: UserWarning: Using the barplot function without specifying `order` is likely to produce an incorrect plot. warnings.warn(warning)
PairGrid()¶
We can use pair grid to get the higher level overview of the dataset. It will plot the pairwise relationship in the dataset.
As you are going to see, it is much easier to use than FacetGrid.
plot = sns.PairGrid(tips)
plot.map(sns.scatterplot)
<seaborn.axisgrid.PairGrid at 0x7f9ade96b690>
It is also possible to be selective on the plot type you want at the diagonals.
plot = sns.PairGrid(tips, hue='sex')
plot.map_diag(sns.histplot)
plot.map_offdiag(sns.scatterplot)
plot.add_legend()
<string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility.
<seaborn.axisgrid.PairGrid at 0x7f9ab85bcb90>
plot = sns.PairGrid(tips)
plot.map_diag(sns.histplot)
plot.map_upper(sns.scatterplot)
plot.map_lower(sns.kdeplot)
plot.add_legend()
<string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility.
<seaborn.axisgrid.PairGrid at 0x7f9ab86b9390>
If you want, you can select only the features you are interested in instead of letting PairGrid plot everything.
plot = sns.PairGrid(tips, vars=['total_bill', 'tip'], hue='day')
plot.map(sns.scatterplot)
plot.add_legend()
<seaborn.axisgrid.PairGrid at 0x7f9adfd0f450>
Pairplot()¶
Pairplot is a simple, flexible, and quick way to visualize the entire dataset. It allows you to quickly understand the relationships between different features.
sns.pairplot(tips)
<string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility. <string>:6: RuntimeWarning: Converting input from bool to <class 'numpy.uint8'> for compatibility.
<seaborn.axisgrid.PairGrid at 0x7f9b3fa80c50>
sns.pairplot(tips, hue='sex', height=5);
6. Matrix Plots: Heat Maps and Cluster Maps¶
In data analysis, it is sometimes handy to visualize the data as color-encoded matrices, which can also be used to find clusters within the data.
Heat Maps¶
heatmap() will color the matrix.
Let's first see it for a numpy array and we will apply it to a real world dataset.
sns.set_theme()
data = np.random.randn(10,15)
sns.heatmap(data)
<AxesSubplot:>
titanic.head()
survived | pclass | sex | age | sibsp | parch | fare | embarked | class | who | adult_male | deck | embark_town | alive | alone | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0 | 3 | male | 22.0 | 1 | 0 | 7.2500 | S | Third | man | True | NaN | Southampton | no | False |
1 | 1 | 1 | female | 38.0 | 1 | 0 | 71.2833 | C | First | woman | False | C | Cherbourg | yes | False |
2 | 1 | 3 | female | 26.0 | 0 | 0 | 7.9250 | S | Third | woman | False | NaN | Southampton | yes | True |
3 | 1 | 1 | female | 35.0 | 1 | 0 | 53.1000 | S | First | woman | False | C | Southampton | yes | False |
4 | 0 | 3 | male | 35.0 | 0 | 0 | 8.0500 | S | Third | man | True | NaN | Southampton | no | True |
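# Finding the correlation of the numeric (and boolean) features in the titanic dataset
# numeric_only=True (pandas 1.5 or later) skips text and categorical columns,
# which recent pandas versions would otherwise raise an error on
correlation = titanic.corr(numeric_only=True)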
correlation
survived | pclass | age | sibsp | parch | fare | adult_male | alone | |
---|---|---|---|---|---|---|---|---|
survived | 1.000000 | -0.338481 | -0.077221 | -0.035322 | 0.081629 | 0.257307 | -0.557080 | -0.203367 |
pclass | -0.338481 | 1.000000 | -0.369226 | 0.083081 | 0.018443 | -0.549500 | 0.094035 | 0.135207 |
age | -0.077221 | -0.369226 | 1.000000 | -0.308247 | -0.189119 | 0.096067 | 0.280328 | 0.198270 |
sibsp | -0.035322 | 0.083081 | -0.308247 | 1.000000 | 0.414838 | 0.159651 | -0.253586 | -0.584471 |
parch | 0.081629 | 0.018443 | -0.189119 | 0.414838 | 1.000000 | 0.216225 | -0.349943 | -0.583398 |
fare | 0.257307 | -0.549500 | 0.096067 | 0.159651 | 0.216225 | 1.000000 | -0.182024 | -0.271832 |
adult_male | -0.557080 | 0.094035 | 0.280328 | -0.253586 | -0.349943 | -0.182024 | 1.000000 | 0.404744 |
alone | -0.203367 | 0.135207 | 0.198270 | -0.584471 | -0.583398 | -0.271832 | 0.404744 | 1.000000 |
sns.heatmap(correlation)
<AxesSubplot:>
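DataFrame.corr() only makes sense for numeric columns, and depending on the pandas version it may either silently skip text columns or raise an error. One version-independent approach, sketched below on the first five titanic rows shown earlier, is to select the numeric and boolean columns explicitly before correlating. The name `titanic_small` is just for this example.

```python
import pandas as pd

# First five titanic rows (a subset of columns) as displayed above.
titanic_small = pd.DataFrame({
    "survived": [0, 1, 1, 1, 0],
    "age": [22.0, 38.0, 26.0, 35.0, 35.0],
    "fare": [7.25, 71.2833, 7.925, 53.1, 8.05],
    "sex": ["male", "female", "female", "female", "male"],
    "alone": [False, False, True, False, True],
})

# Keep numeric and boolean columns only, then treat booleans as 0.0 / 1.0.
numeric = titanic_small.select_dtypes(include=["number", "bool"]).astype(float)
correlation_small = numeric.corr()
print(correlation_small.round(2))
```

The resulting square matrix is what heatmap() colors cell by cell.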
If you want to show the values in addition to the colors, you can set the parameter annot to True.
sns.heatmap(correlation, annot=True)
<AxesSubplot:>
sns.clustermap(correlation)
<seaborn.matrix.ClusterGrid at 0x7f9ab9aaee50>
So far, you have seen how flexible Seaborn is in visualizing data with different kinds of plots.
By default, Seaborn plots are clear and good looking. But there are times when you will need more customized visualizations, and that is what is coming up.
7. Styles, Themes and Colors¶
Seaborn allows us to customize the visualizations depending on our needs. We may want control over the plot styles and colors. Let's see how that works.
Styles and Themes¶
There are five styles in Seaborn: darkgrid(default), whitegrid, dark, white, and ticks.
sns.set_style('whitegrid')
sns.catplot(data=tips, x='day', y='total_bill', kind='bar')
<seaborn.axisgrid.FacetGrid at 0x7f9aa873af90>
sns.set_style('dark')
sns.catplot(data=tips, x='day', y='total_bill', kind='box')
<seaborn.axisgrid.FacetGrid at 0x7f9aa8b62a10>
sns.set_style('white')
sns.catplot(data=tips, x='day', y='total_bill', kind='boxen')
<seaborn.axisgrid.FacetGrid at 0x7f9aa8c5bcd0>
sns.set_style('ticks')
sns.catplot(data=tips, x='day', y='total_bill', kind='box')
<seaborn.axisgrid.FacetGrid at 0x7f9aa8c5b650>
Removing the Axes Spines¶
We can also use despine() to remove the top and right axes spines.
sns.catplot(data=tips, x='day', y='total_bill', kind='box')
sns.despine()
We can also move the spines away from the data by setting offset, the distance in points that the spines should move away from the axes.
sns.catplot(data=tips, x='day', y='total_bill', kind='violin')
sns.despine(offset=10, trim=True)
Size and Aspect¶
We can use the figsize parameter of Matplotlib's plt.figure() to change the size of axes-level Seaborn plots.
plt.figure(figsize=(10,8))
sns.histplot(data=tips, x='total_bill', hue='sex')
<AxesSubplot:xlabel='total_bill', ylabel='Count'>
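For figure-level functions such as relplot(), displot(), catplot(), and lmplot(), you can instead set the size with the height and aspect parameters (height was called size in older Seaborn versions); axes-level functions such as histplot() do not accept them.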
Scaling plot elements with the context¶
Context is used to control the scale of the elements of the plot. This can be really helpful depending on where you want to use the visualizations.
We use sns.set_context() to achieve that. There are four contexts: paper, notebook (default), talk, and poster.
But we also have to reset the style first.
sns.set_theme()
sns.set_context('paper')
sns.kdeplot(data=tips, x='tip', hue='day')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.set_context('talk')
sns.kdeplot(data=tips, x='tip', hue='day')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.set_context('poster')
sns.kdeplot(data=tips, x='tip', hue='day')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.set_context('notebook')
sns.kdeplot(data=tips, x='tip', hue='day')
<AxesSubplot:xlabel='tip', ylabel='Density'>
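If you are curious about what a context actually changes, sns.plotting_context() returns the underlying parameters as a dictionary without applying them. The short sketch below prints the font size and line width for each context; the `font_sizes` dictionary is just a helper for this example.

```python
import seaborn as sns

font_sizes = {}
for name in ["paper", "notebook", "talk", "poster"]:
    # Look up the parameters of a context without changing global settings.
    params = sns.plotting_context(name)
    font_sizes[name] = params["font.size"]
    print(name, params["font.size"], params["lines.linewidth"])
```

The values grow from paper to poster, which is why the same plot looks progressively bolder in the outputs above.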
Colors¶
Seaborn allows us to choose colors that make the visuals attractive.
With a whole range of color palettes, there are many options to choose from.
We can either call sns.set_palette() before plotting, or set palette inside the plot call.
sns.set_palette('rocket')
sns.kdeplot(data=tips, x='tip', hue='day', multiple='stack')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.set_palette('viridis')
sns.kdeplot(data=tips, x='tip', hue='day', multiple='stack')
<AxesSubplot:xlabel='tip', ylabel='Density'>
sns.kdeplot(data=tips, x='tip', hue='day', multiple='stack', palette='icefire')
<AxesSubplot:xlabel='tip', ylabel='Density'>
As you can see, it makes the plots more appealing. There are so many palettes and here is the list. You can play with them to see what color matches your purpose.
To see the available color palettes, check out the documentation or this cheat sheet.
# To see what the palette looks like
sns.color_palette('tab10')
sns.color_palette('dark')
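In a notebook, sns.color_palette() displays as a row of color swatches, but the returned object is simply a list of RGB tuples that you can inspect or pass to a plot's palette parameter. A small sketch (the variable names are just for this example):

```python
import seaborn as sns

# A qualitative palette: a list of (r, g, b) tuples with values in [0, 1].
tab10 = sns.color_palette("tab10")
print(len(tab10))
print(tab10.as_hex()[:3])  # the same colors as hex strings

# Ask for a specific number of colors sampled from a continuous palette.
viridis4 = sns.color_palette("viridis", n_colors=4)
print(viridis4.as_hex())
```

The list can be used directly, for example palette=viridis4 in a plot with four hue levels.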
Lastly, you can use sns.set_theme() to directly set the style, palette, and context.
sns.set_theme(style='white', context='talk', palette='viridis')
sns.kdeplot(data=tips, x='tip', hue='day', multiple='stack')
<AxesSubplot:xlabel='tip', ylabel='Density'>
This is the end of the lab!!¶