Jerky For Dogs Solution
Here we build a first version of the plot, then improve it step by step in subsequent "drafts".
import matplotlib.pyplot as plt
# Instantiate a figure with 4 Axes (2 rows, 2 columns)
# Use tuple unpacking to get each Axes object
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(8,4))
fig.suptitle('Sales by date')
ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')
ax2.plot(dates[flavors == 'salsa'], sales[flavors == 'salsa'])
ax2.set_title('salsa')
ax3.plot(dates[flavors == 'sweet orange'], sales[flavors == 'sweet orange'])
ax3.set_title('sweet orange')
ax4.plot(dates[flavors == 'smokey'], sales[flavors == 'smokey'])
ax4.set_title('smokey')
Instantiate a figure with a 2x2 grid of subplots. (See pyplot.subplots.)
import matplotlib.pyplot as plt
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(8,4))
plt.subplots(2, 2, figsize=(8,4)) returns two things:
- a figure object
- a 2x2 NumPy array of Axes objects.
We "unpack" all five of these objects (one figure plus four Axes) using fig and a tuple of tuples of ax variables.
Subplot arrangement
The Axes in this example are arranged in a grid like this:
ax1  ax2
ax3  ax4
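To see what plt.subplots() actually returns, and which Axes each unpacked name ends up bound to, a minimal sketch can inspect the returned array directly. It builds an empty 2x2 figure; the matplotlib.use("Agg") line is an assumption added only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 4))
print(type(axs).__name__, axs.shape)  # ndarray (2, 2)

# Unpacking walks the array row by row: row 0 gives (ax1, ax2), row 1 gives (ax3, ax4)
(ax1, ax2), (ax3, ax4) = axs
print(ax1 is axs[0, 0], ax2 is axs[0, 1], ax3 is axs[1, 0], ax4 is axs[1, 1])  # True True True True

plt.close(fig)  # free the figure; it was only inspected
```

So the tuple-of-tuples pattern in the solution is just ordinary Python unpacking applied to the rows of that 2x2 array.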
Add a title to the entire figure.
fig.suptitle('Sales by date')
Plot the data on each Axes.
ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')
- Subset dates and sales by flavors == 'traditional', then make a simple line plot from this data using Axes.plot().
- Give the Axes a title with Axes.set_title().
Repeat this for each (flavor, Axes) pair.
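The subsetting in the first bullet is ordinary NumPy boolean indexing. A minimal sketch with small hypothetical arrays (made up for illustration, not the challenge's data) shows how one mask selects the matching rows from dates and sales:

```python
import numpy as np

# Hypothetical toy records (NOT the challenge's data): one sales total per (date, flavor) row
dates = np.array(['2021-01-01', '2021-01-01', '2021-01-02', '2021-01-02'], dtype='datetime64[D]')
flavors = np.array(['traditional', 'salsa', 'traditional', 'salsa'])
sales = np.array([120, 1400, 95, 1350])

mask = flavors == 'traditional'  # element-wise comparison gives a boolean array
print(mask)         # [ True False  True False]
print(dates[mask])  # ['2021-01-01' '2021-01-02']
print(sales[mask])  # [120  95]
```

Because the same mask indexes both arrays, the selected dates and sales stay aligned row for row, which is what lets Axes.plot() pair them up correctly.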
Issue
The y axis for traditional jerky ranges from 0 to 150, but the y axis for salsa jerky ranges from 0 to 1500! It's
unfair to compare these flavors until they're on a consistent y scale. (The x axes suffer the same issue.) We'll fix this
in the next draft.
import matplotlib.pyplot as plt
# Instantiate a figure with 4 Axes (2 rows, 2 columns)
# Use tuple unpacking to get each Axes object
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8,4), sharex=True, sharey=True)
fig.suptitle('Sales by date')
ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')
ax2.plot(dates[flavors == 'salsa'], sales[flavors == 'salsa'])
ax2.set_title('salsa')
ax3.plot(dates[flavors == 'sweet orange'], sales[flavors == 'sweet orange'])
ax3.set_title('sweet orange')
ax4.plot(dates[flavors == 'smokey'], sales[flavors == 'smokey'])
ax4.set_title('smokey')
- We add sharex=True, sharey=True to plt.subplots(). This tells Matplotlib to link the x axes and the y axes of every Axes object, so all four plots use the same x scale and the same y scale.
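One way to see what sharing changes is to compare the y limits of two Axes with and without it. This sketch uses two hypothetical series whose scales loosely mimic the 0 to 150 vs 0 to 1500 mismatch, plus the headless Agg backend (an assumption so it runs anywhere):

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt

small = [0, 50, 150]    # hypothetical series on a small scale
large = [0, 500, 1500]  # hypothetical series on a large scale

same_ylim = {}
for share in (False, True):
    fig, (left, right) = plt.subplots(1, 2, sharex=share, sharey=share)
    left.plot(small)
    right.plot(large)
    fig.canvas.draw()  # render once so autoscaling has settled the limits
    same_ylim[share] = left.get_ylim() == right.get_ylim()
    plt.close(fig)

print(same_ylim)  # {False: False, True: True}
```

Without sharing, each Axes autoscales to its own data; with sharey=True, the small series is drawn on the same 0 to 1500 range as the large one.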
Issue
The layout's a bit messy and the x tick labels are unreadable. We'll fix this in the next draft.
import matplotlib.pyplot as plt
# Instantiate a figure with 4 Axes (2 rows, 2 columns)
# Use tuple unpacking to get each Axes object
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8,4), sharex=True, sharey=True, layout='tight')
fig.suptitle('Sales by date')
ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')
ax2.plot(dates[flavors == 'salsa'], sales[flavors == 'salsa'])
ax2.set_title('salsa')
ax3.plot(dates[flavors == 'sweet orange'], sales[flavors == 'sweet orange'])
ax3.set_title('sweet orange')
ax4.plot(dates[flavors == 'smokey'], sales[flavors == 'smokey'])
ax4.set_title('smokey')
# Adjust the x label formats
fig.autofmt_xdate()
- We add layout='tight' to plt.subplots() to reduce the whitespace around the grid of Axes.
- We add fig.autofmt_xdate() to automatically rotate the x tick labels so they don't overlap.
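To check what fig.autofmt_xdate() does to the tick labels, the sketch below plots hypothetical daily dates on a shared 2x2 grid and reads back the rotation and alignment of the bottom row's labels. The backend and data are assumptions for illustration; rotation=30 and ha='right' are the method's default arguments.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical daily dates and values, just to produce date tick labels
dates = np.arange('2021-01-01', '2021-03-01', dtype='datetime64[D]')
values = np.arange(dates.size)

fig, axs = plt.subplots(2, 2, figsize=(8, 4), sharex=True, sharey=True)
for ax in axs.ravel():
    ax.plot(dates, values)

fig.canvas.draw()    # render once so the tick labels exist
fig.autofmt_xdate()  # defaults: rotation=30, ha='right'

# With a shared x axis only the bottom row shows x tick labels; those are the rotated ones
bottom_labels = axs[1, 0].get_xticklabels()
print({lbl.get_rotation() for lbl in bottom_labels})
print({lbl.get_horizontalalignment() for lbl in bottom_labels})
plt.close(fig)
```

Right-aligning the rotated labels anchors each label's end under its tick, which is why slanted date labels stop colliding.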
Issue
There's still a lot of manual code repetition. We should follow the
Don't Repeat Yourself (DRY) principle of coding.
import matplotlib.pyplot as plt
import numpy as np
# Instantiate a figure with 4 Axes (2 rows, 2 columns)
fig, axs = plt.subplots(2, 2, figsize=(8,4), sharex=True, sharey=True, layout='tight')
fig.suptitle('Sales by date')
# Loop over (flavor, Axes) pairs
unique_flavors = np.unique(flavors)
for flav, ax in zip(unique_flavors, axs.ravel()):
    ax.plot(dates[flavors == flav], sales[flavors == flav])
    ax.set_title(flav)
# Adjust the x label formats
fig.autofmt_xdate()
Instead of unpacking the Axes objects into four variables (ax1, ax2, ax3, ax4), we store all four Axes in a single variable (axs). Keep in mind, axs is a NumPy ndarray.
type(axs)
# <class 'numpy.ndarray'>
axs
# array([[<AxesSubplot:>, <AxesSubplot:>],
#        [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
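Next we get the unique flavors from the flavors array.
unique_flavors = np.unique(flavors)
print(unique_flavors)
# ['salsa' 'smokey' 'sweet orange' 'traditional']
Note that np.unique() returns the unique values in sorted order, so the flavors (and therefore the subplots) appear alphabetically.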
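Because np.unique() sorts its output, the subplot order follows the alphabet rather than the order in which flavors first appear in the data. A small sketch with a hypothetical flavors array shows both orders; the return_index variant is only an illustration of an alternative, not something the solution uses:

```python
import numpy as np

# Hypothetical flavor column, in order of first appearance (not alphabetical)
flavors = np.array(['traditional', 'salsa', 'traditional', 'smokey', 'sweet orange', 'salsa'])

unique_flavors = np.unique(flavors)
print(unique_flavors)  # ['salsa' 'smokey' 'sweet orange' 'traditional']

# return_index gives the position of each unique value's first occurrence;
# sorting those positions recovers first-appearance order
_, first_idx = np.unique(flavors, return_index=True)
in_appearance_order = flavors[np.sort(first_idx)]
print(in_appearance_order)  # ['traditional' 'salsa' 'smokey' 'sweet orange']
```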
Then we iterate through each (flavor, Axes) pair, drawing the line plot for that flavor and setting the Axes title.
for flav, ax in zip(unique_flavors, axs.ravel()):
    ax.plot(dates[flavors == flav], sales[flavors == flav])
    ax.set_title(flav)
axs.ravel() flattens the (2,2) array into a (4,) array.
print(axs)
# [[<AxesSubplot:> <AxesSubplot:>]
#  [<AxesSubplot:> <AxesSubplot:>]]
print(axs.ravel())
# [<AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>]
We could also use axs.flatten(), but flatten() always returns a copy of the array, whereas ravel() returns a view of the array whenever possible.
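The view-versus-copy difference is easy to check with np.shares_memory(). This sketch uses a small integer array in place of axs, plus a hypothetical 2x2 array of slot names standing in for the Axes, to show the row-major order in which zip() pairs flavors with subplots:

```python
import numpy as np

a = np.arange(4).reshape(2, 2)  # C-contiguous: [[0, 1], [2, 3]]
r = a.ravel()    # a view here, because a is already contiguous
f = a.flatten()  # always a fresh copy
print(np.shares_memory(a, r), np.shares_memory(a, f))  # True False

r[0] = 99  # writing through the view changes a[0, 0] ...
f[1] = -1  # ... while writing to the copy leaves a[0, 1] alone
print(a[0, 0], a[0, 1])  # 99 1

# "Whenever possible" matters: a transposed array isn't C-contiguous, so ravel() must copy
print(np.shares_memory(a, a.T.ravel()))  # False

# Both methods yield elements in row-major order; that order decides which flavor lands in which slot
flavs = np.array(['salsa', 'smokey', 'sweet orange', 'traditional'])
slots = np.array([['top-left', 'top-right'], ['bottom-left', 'bottom-right']], dtype=object)  # stand-ins for Axes
pairs = list(zip(flavs, slots.ravel()))
for flav, slot in pairs:
    print(flav, '->', slot)
```

For the Axes array itself the distinction barely matters: flatten() would copy only the object references, so the loop would still draw on the same four Axes; ravel() simply skips the unnecessary copy. Also note that zip() stops at the shorter input, so a fifth flavor would be silently skipped with only four Axes; on Python 3.10+ zip(..., strict=True) raises a ValueError instead.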
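To run the final draft end to end without the challenge's data, one option is to swap in hypothetical toy arrays carrying the same kind of information (one row per date and flavor). The plotting part mirrors the final draft; the Agg backend, the random sales values, and rendering to an in-memory PNG are assumptions for testing, and layout='tight' assumes a reasonably recent Matplotlib.

```python
import io

import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical toy data (NOT the challenge's dataset): 30 days x 4 flavors, one row per (date, flavor)
rng = np.random.default_rng(0)
days = np.arange('2021-01-01', '2021-01-31', dtype='datetime64[D]')
flavor_names = np.array(['traditional', 'salsa', 'sweet orange', 'smokey'])
dates = np.repeat(days, flavor_names.size)      # each day repeated once per flavor
flavors = np.tile(flavor_names, days.size)      # flavor names cycled across the days
sales = rng.integers(0, 1500, size=dates.size)  # made-up sales figures

# Final draft, unchanged apart from the toy data above
fig, axs = plt.subplots(2, 2, figsize=(8, 4), sharex=True, sharey=True, layout='tight')
fig.suptitle('Sales by date')
unique_flavors = np.unique(flavors)
for flav, ax in zip(unique_flavors, axs.ravel()):
    ax.plot(dates[flavors == flav], sales[flavors == flav])
    ax.set_title(flav)
fig.autofmt_xdate()

# Render into memory instead of a file to confirm that everything draws
buf = io.BytesIO()
fig.savefig(buf, format='png')
plt.close(fig)
```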