
Jerky For Dogs Solution


Here we build a first version of the plot, then improve it step by step over several "drafts".

import matplotlib.pyplot as plt

# Assumes dates, sales, and flavors are the NumPy arrays from the problem setup

# Instantiate a figure with 4 Axes (2 rows, 2 columns)
# Use tuple unpacking to get each Axes object
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(8,4))
fig.suptitle('Sales by date')

ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')

ax2.plot(dates[flavors == 'salsa'], sales[flavors == 'salsa'])
ax2.set_title('salsa')

ax3.plot(dates[flavors == 'sweet orange'], sales[flavors == 'sweet orange'])
ax3.set_title('sweet orange')

ax4.plot(dates[flavors == 'smokey'], sales[flavors == 'smokey'])
ax4.set_title('smokey')

  1. Instantiate a figure with a 2x2 grid of subplots. (See pyplot.subplots.)

    import matplotlib.pyplot as plt
    
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(nrows=2, ncols=2, figsize=(8,4))
    

    plt.subplots(2, 2, figsize=(8,4)) returns two things:

    • a figure object
    • a 2x2 NumPy array of Axes objects.

    We "unpack" all five of these objects (the figure plus the four Axes) by assigning them to fig and a tuple of tuples of Axes variables.

    Subplot arrangement

    The axes in this example are arranged in a grid like this:

    ax1 ax2
    ax3 ax4
    
  2. Add a title to the entire figure.

    fig.suptitle('Sales by date')
    
  3. Plot the data on each Axes.

    ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional']) # (1)
    ax1.set_title('traditional') # (2)
    
    1. Select the dates and sales where flavors == 'traditional' (a boolean mask), then make a simple line plot of that data using Axes.plot().
    2. Give the Axes a title with Axes.set_title().

    Repeat this for each (flavor, Axes) pair.
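To see what the tuple unpacking and the boolean subsetting do in isolation, here is a minimal sketch. The dates, sales, and flavors arrays below are small hypothetical stand-ins for the arrays from the problem setup (not the real data), and the non-interactive Agg backend is selected so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is needed
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-in data (NOT the problem's real data)
dates = np.array(['2020-01-01', '2020-01-02', '2020-01-01', '2020-01-02'],
                 dtype='datetime64[D]')
sales = np.array([10, 12, 100, 120])
flavors = np.array(['traditional', 'traditional', 'salsa', 'salsa'])

fig, axs = plt.subplots(nrows=2, ncols=2, figsize=(8, 4))
print(type(axs), axs.shape)  # <class 'numpy.ndarray'> (2, 2)

# Tuple-of-tuples unpacking: each row of axs unpacks into two Axes
(ax1, ax2), (ax3, ax4) = axs
print(ax1 is axs[0, 0], ax4 is axs[1, 1])  # True True

# Boolean mask: True wherever the flavor is 'traditional'
mask = flavors == 'traditional'
print(mask)         # [ True  True False False]
print(sales[mask])  # [10 12]

# Plot only the 'traditional' rows on the top-left Axes
ax1.plot(dates[mask], sales[mask])
ax1.set_title('traditional')
```

The same mask selects matching positions from both dates and sales, which is why the two arrays stay aligned in the plot.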

Issue
The y axis for traditional jerky ranges from 0 to 150, but the y axis for salsa jerky ranges from 0 to 1500! It's unfair to compare these flavors until they're on a consistent y scale. (The x axes suffer from the same issue.) We'll fix this in the next draft.

import matplotlib.pyplot as plt

# Instantiate a figure with 4 Axes (2 rows, 2 columns)
# Use tuple unpacking to get each Axes object
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8,4), sharex=True, sharey=True)
fig.suptitle('Sales by date')

ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')

ax2.plot(dates[flavors == 'salsa'], sales[flavors == 'salsa'])
ax2.set_title('salsa')

ax3.plot(dates[flavors == 'sweet orange'], sales[flavors == 'sweet orange'])
ax3.set_title('sweet orange')

ax4.plot(dates[flavors == 'smokey'], sales[flavors == 'smokey'])
ax4.set_title('smokey')

  1. We pass sharex=True and sharey=True to plt.subplots(). This tells Matplotlib to use the same x scale and y scale for every Axes object, so the four flavors can be compared directly.
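A quick way to confirm what sharing does is to plot data on very different scales and compare the resulting axis limits. This is a minimal sketch with hypothetical numbers chosen to mimic the 150 vs 1500 mismatch, using the Agg backend so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

# Two Axes that share both x and y
fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)

# Hypothetical series on very different scales
ax1.plot([0, 1, 2], [0, 100, 150])
ax2.plot([0, 1, 2], [0, 1000, 1500])

# Autoscaling considers the data of every shared Axes,
# so both report identical limits
print(ax1.get_ylim() == ax2.get_ylim())  # True
print(ax1.get_xlim() == ax2.get_xlim())  # True
```

Without sharey=True, ax1 would autoscale to its own data and its y axis would top out near 150 instead.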

Issue
The layout is a bit messy, and the x tick labels overlap so much that they're unreadable. We'll fix this in the next draft.

import matplotlib.pyplot as plt

# Instantiate a figure with 4 Axes (2 rows, 2 columns)
# Use tuple unpacking to get each Axes object
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(8,4), sharex=True, sharey=True, layout='tight')
fig.suptitle('Sales by date')

ax1.plot(dates[flavors == 'traditional'], sales[flavors == 'traditional'])
ax1.set_title('traditional')

ax2.plot(dates[flavors == 'salsa'], sales[flavors == 'salsa'])
ax2.set_title('salsa')

ax3.plot(dates[flavors == 'sweet orange'], sales[flavors == 'sweet orange'])
ax3.set_title('sweet orange')

ax4.plot(dates[flavors == 'smokey'], sales[flavors == 'smokey'])
ax4.set_title('smokey')

# Rotate the x tick labels (dates) so they don't overlap
fig.autofmt_xdate()

  1. We add layout='tight' to plt.subplots() so Matplotlib automatically adjusts the padding between and around the Axes, keeping titles and tick labels from colliding.
  2. We add fig.autofmt_xdate() to rotate and right-align the x tick labels so they don't overlap.
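Here is a small sketch of what fig.autofmt_xdate() does to the tick labels, using its defaults (rotation=30, ha='right'). The daily dates and values are hypothetical, the Agg backend is used so it runs without a display, and the canvas is drawn once before the call so the tick labels exist when we inspect them.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical daily dates, just to get date tick labels
dates = np.arange('2020-01-01', '2020-03-01', dtype='datetime64[D]')
values = np.arange(len(dates))

fig, axs = plt.subplots(2, 2, figsize=(8, 4), sharex=True, sharey=True, layout='tight')
for ax in axs.ravel():
    ax.plot(dates, values)

fig.canvas.draw()    # generate the tick labels
fig.autofmt_xdate()  # rotate/right-align the bottom row's labels

# Only the bottom row keeps visible x tick labels
bottom_labels = [lbl for ax in axs[1] for lbl in ax.get_xticklabels()]
print(sorted({lbl.get_rotation() for lbl in bottom_labels}))  # [30.0]
```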

Issue
There's still a lot of repeated code: the same plot and title lines are copied once per flavor. We should follow the Don't Repeat Yourself (DRY) principle of coding.

import matplotlib.pyplot as plt
import numpy as np

# Instantiate a figure with 4 Axes (2 rows, 2 columns)
fig, axs = plt.subplots(2, 2, figsize=(8,4), sharex=True, sharey=True, layout='tight')
fig.suptitle('Sales by date')

# Loop over (flavor, Axes) pairs
unique_flavors = np.unique(flavors)
for flav, ax in zip(unique_flavors, axs.ravel()):
    ax.plot(dates[flavors == flav], sales[flavors == flav])
    ax.set_title(flav)

# Rotate the x tick labels (dates) so they don't overlap
fig.autofmt_xdate()

  1. Instead of unpacking the Axes objects into four variables (ax1, ax2, ax3, ax4), we store all four Axes in a single variable (axs). Keep in mind that axs is a NumPy ndarray.

    type(axs)
    # <class 'numpy.ndarray'>
    
    axs
    # array([[<AxesSubplot:>, <AxesSubplot:>],
    #        [<AxesSubplot:>, <AxesSubplot:>]], dtype=object)
    
  2. Next we get the unique flavors from the flavors array.

    unique_flavors = np.unique(flavors)
    
    print(unique_flavors)
    # ['salsa' 'smokey' 'sweet orange' 'traditional']
    
  3. Then we iterate through each (flavor, Axes) pair, drawing the line plot for that flavor and setting the Axes title.

    for flav, ax in zip(unique_flavors, axs.ravel()): # (1)
        ax.plot(dates[flavors == flav], sales[flavors == flav])
        ax.set_title(flav)
    
    1. axs.ravel() flattens the (2,2) array into a (4,) array.

      print(axs)
      # [[<AxesSubplot:> <AxesSubplot:>]
      #  [<AxesSubplot:> <AxesSubplot:>]]
      
      print(axs.ravel())
      # [<AxesSubplot:> <AxesSubplot:> <AxesSubplot:> <AxesSubplot:>]
      

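      We could also use axs.flatten(). The difference is that flatten() always returns a copy of the array, whereas ravel() returns a view of the array whenever possible. Either works here, since both contain the same four Axes objects.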
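The view-versus-copy distinction can be checked directly with np.shares_memory. This is a minimal sketch (Agg backend, empty Axes) showing that ravel() shares the underlying buffer of axs while flatten() does not, yet the copy still refers to the very same Axes objects.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(2, 2)

flat_view = axs.ravel()    # a view, possible because axs is contiguous
flat_copy = axs.flatten()  # always a new array

print(np.shares_memory(axs, flat_view))  # True
print(np.shares_memory(axs, flat_copy))  # False

# The copy holds references to the SAME Axes objects,
# so looping over either one draws on the same subplots
print(flat_copy[0] is axs[0, 0])  # True
```

Because axs has an object dtype, "copy" here means copying the references, not the Axes themselves, which is why the loop would behave identically with either method.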
