
Koala Speeds Solution


import matplotlib.pyplot as plt
import numpy as np

# weights, speeds, ages, m and b are assumed to be defined by the exercise

# Instantiate a figure with a single Axes
fig, ax = plt.subplots(figsize=(8,6), layout='tight')

# Make a scatter plot
ax.scatter(x=weights, y=speeds, s=ages**2, marker='x')

# Plot regression line
xlims = np.array(ax.get_xlim())
ax.plot(xlims, m * xlims + b, c='r')

# Set title & labels
ax.set_title('Speed vs weight of adult koalas')
ax.set_xlabel('weight')
ax.set_ylabel('speed')

# Insert equation in the upper right corner
ax.text(
    x=0.98, 
    y=0.98, 
    s=f'speed = {m}*weight + {b}', 
    ha='right', 
    va='top', 
    transform=ax.transAxes
)
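The solution above depends on weights, speeds, ages, m and b from the exercise, which are not shown on this page. Below is a self-contained sketch that runs end to end. It generates hypothetical stand-in data with NumPy. It also computes m and b with np.polyfit (an ordinary least-squares fit), which is an assumption: the exercise may supply m and b some other way. The label formats m and b to three decimals, a cosmetic change from the original.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the exercise's data (NOT the real koala data)
rng = np.random.default_rng(seed=0)
weights = rng.uniform(10, 20, size=30)
speeds = 8.0 - 0.15 * weights + rng.normal(0, 0.3, size=30)
ages = rng.integers(3, 15, size=30)

# Assumed: a least-squares line; for deg=1, np.polyfit returns [slope, intercept]
m, b = np.polyfit(weights, speeds, deg=1)

fig, ax = plt.subplots(figsize=(8, 6), layout='tight')
ax.scatter(x=weights, y=speeds, s=ages**2, marker='x')

xlims = np.array(ax.get_xlim())
ax.plot(xlims, m * xlims + b, c='r')

ax.set_title('Speed vs weight of adult koalas')
ax.set_xlabel('weight')
ax.set_ylabel('speed')
ax.text(
    x=0.98,
    y=0.98,
    s=f'speed = {m:.3f}*weight + {b:.3f}',  # rounded for readability
    ha='right',
    va='top',
    transform=ax.transAxes,
)

fig.canvas.draw()  # render once so the tight layout is actually applied
```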

Explanation

  1. Make a figure with a single Axes.

    import matplotlib.pyplot as plt
    
    fig, ax = plt.subplots(figsize=(8,6), layout='tight')
    
    • plt.subplots() creates a figure and, with the default nrows=1 and ncols=1, a single Axes, returning both as (fig, ax). (See matplotlib.pyplot.subplots.)
    • figsize=(8,6) sets the figure size to 8 inches wide by 6 inches tall. (See the figsize parameter of matplotlib.figure.Figure.)
    • layout='tight' "adjusts the subplot parameters so that decorations like tick labels, axis labels and titles have enough space". (See the layout parameter of matplotlib.figure.Figure.)
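    You can confirm what step 1 produces by reading the size back with Figure.get_size_inches(). The minimal sketch below assumes the Agg backend, which lets it run without a display. It shows that figsize is (width, height) in inches. The pixel size also depends on fig.dpi, which comes from your Matplotlib settings, so the sketch prints that value rather than assuming it.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6), layout='tight')

# figsize is (width, height) in inches
width_in, height_in = fig.get_size_inches()

# Pixel dimensions = inches * dots-per-inch of this figure
width_px, height_px = width_in * fig.dpi, height_in * fig.dpi

print(f'{width_in} x {height_in} in -> {width_px:.0f} x {height_px:.0f} px at {fig.dpi} dpi')
print('number of Axes:', len(fig.axes))
```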
  2. Make a scatter plot.

    ax.scatter(
        x=weights,   # set x as the weights
        y=speeds,    # set y as the speeds
        s=ages**2,   # set the marker areas (in points**2) to ages**2
        marker='x'   # set the marker style as "x"
    )
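    In Axes.scatter, s is a marker area in points**2, not a width. A marker's width therefore grows like sqrt(s), so s=ages**2 makes each marker's width proportional to the koala's age. The short sketch below reads the sizes back from the returned collection to show this. The three data points are hypothetical placeholders, not the exercise's data.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical values; the exercise supplies its own weights, speeds and ages
weights = np.array([10.0, 15.0, 20.0])
speeds = np.array([6.5, 6.0, 5.5])
ages = np.array([2, 4, 8])

fig, ax = plt.subplots()
points = ax.scatter(x=weights, y=speeds, s=ages**2, marker='x')

# s is an area in points**2, so width ~ sqrt(s); with s = ages**2, width ~ age
areas = points.get_sizes()
widths = np.sqrt(areas)
print('areas (points**2):', areas)
print('widths (points):  ', widths)
```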
    
  3. Plot the regression line.

    xlims = np.array(ax.get_xlim())
    ax.plot(xlims, m * xlims + b, c='r')
    

    The strategy here is to draw a straight line segment between two points, (x1, y1) and (x2, y2). We want the x values to span the x limits of ax, which we can fetch with

    xlims = np.array(ax.get_xlim())
    print(xlims)  # [ 9.55113043 20.10128969]
    

    Then we calculate the corresponding y values.

    print(m * xlims + b)
    # [6.76733044 5.18480655]
    

    By default, ax.plot() connects consecutive points with straight line segments, so passing these two points draws the regression line between them.

    ax.plot(xlims, m * xlims + b, c='r')  # (1)
    
    1. c='r' makes the line color red.
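    One subtlety with this approach is that ax.plot() asks the Axes to autoscale again. With Matplotlib's default margins, the x limits will typically widen a little once the line is added, so the red segment may stop just short of the left and right edges. The sketch below shows two ways around that, using hypothetical values for m, b and the data. Option A keeps the solution's two-point segment and pins the limits with ax.set_xlim(xlims). Option B swaps in a different technique, Axes.axline(), which draws an infinite line through one point with a given slope, clipped to the Axes.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical slope, intercept and data; the exercise provides its own
m, b = -0.15, 8.2
weights = np.array([10.0, 12.5, 15.0, 17.5, 20.0])
speeds = m * weights + b

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Option A: the two-point segment from the solution, then pin the x limits
# so autoscaling cannot move the edges beyond the segment's endpoints.
ax1.scatter(weights, speeds, marker='x')
xlims = np.array(ax1.get_xlim())
ax1.plot(xlims, m * xlims + b, color='r')
ax1.set_xlim(xlims)

# Option B: an infinite line through one point with slope m.
# Anchor it at a point inside the data range so the anchor can't stretch the view.
ax2.scatter(weights, speeds, marker='x')
x0 = weights.mean()
ax2.axline((x0, m * x0 + b), slope=m, color='r')
```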
  4. Set title & labels.

    ax.set_title('Speed vs weight of adult koalas')
    ax.set_xlabel('weight')
    ax.set_ylabel('speed')
    
  5. Insert the equation in the upper right corner.

    ax.text(
        x=0.98,                        # (1)
        y=0.98,                        # (2)
        s=f'speed = {m}*weight + {b}', # (3)
        ha='right',                    # (4)
        va='top',                      # (5)
        transform=ax.transAxes         # (6)
    )
    
    1. Anchor the text at x = 0.98 (just inside the right edge).
    2. Anchor the text at y = 0.98 (just inside the top edge).
    3. Use an f-string to build the text of the equation "speed = m*weight + b", substituting the values of m and b.
    4. Horizontally align the text to the right, so its right edge lands at x = 0.98.
    5. Vertically align the text to the top, so its top edge lands at y = 0.98.
    6. Interpret x and y in Axes coordinates, where values in [0, 1] run from the left/bottom edge to the right/top edge.

    Here we use ax.text() to add text to the Axes. The trick for placing it in the top right corner is transform=ax.transAxes, which lets us specify x and y in Axes coordinates: 0 is the left/bottom edge and 1 is the right/top edge, regardless of the data limits.
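    To see what transform=ax.transAxes does, you can map a point from Axes coordinates to data coordinates. First convert Axes to display (pixel) coordinates with ax.transAxes, then convert back with the inverse of ax.transData. The helper axes_to_data below is a hypothetical name for illustration, and the limits (10 to 20 on x, 5 to 7 on y) are arbitrary example values. With these limits, (0, 0) maps to (10, 5), (1, 1) maps to (20, 7), and the text anchor (0.98, 0.98) maps to (19.8, 6.96).

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.set_xlim(10, 20)  # hypothetical data limits for illustration
ax.set_ylim(5, 7)

# Same placement technique as the solution, with placeholder text
ax.text(0.98, 0.98, 'top right', ha='right', va='top', transform=ax.transAxes)

def axes_to_data(ax, point):
    """Hypothetical helper: map (x, y) from Axes coordinates to data coordinates."""
    display = ax.transAxes.transform(point)            # Axes -> display (pixels)
    return ax.transData.inverted().transform(display)  # display -> data

bottom_left = axes_to_data(ax, (0, 0))    # lower-left corner: (xmin, ymin)
top_right = axes_to_data(ax, (1, 1))      # upper-right corner: (xmax, ymax)
anchor = axes_to_data(ax, (0.98, 0.98))   # where the text anchor lands in data units
print(bottom_left, top_right, anchor)
```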