Skip to content

Commit

Permalink
Add a callback system
Browse files Browse the repository at this point in the history
  • Loading branch information
ianhi committed Oct 27, 2021
1 parent 142b0c2 commit 9fa8daa
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 3 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pip install mpl-image-labeller
- Displays images with correct aspect ratio
- Easily configurable keymap
- Smart interactions with default Matplotlib keymap
- Callback System (see `examples/callbacks.py`)

![gif of usage for labelling images of cats and dogs](example.gif)

Expand Down
21 changes: 21 additions & 0 deletions examples/callbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import matplotlib.pyplot as plt
import numpy as np

from mpl_image_labeller import image_labeller

images = np.random.randn(5, 10, 10)
labeller = image_labeller(images, classes=["good", "bad", "blarg"])


def image_changed_callback(index, image):
print(index)
print(image.sum())


def label_assigned(index, label):
print(f"label {label} assigned to image {index}")


labeller.on_image_changed(image_changed_callback)
labeller.on_label_assigned(image_changed_callback)
plt.show()
47 changes: 44 additions & 3 deletions mpl_image_labeller/_labeller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
from matplotlib.backend_bases import key_press_handler
from matplotlib.cbook import CallbackRegistry
from matplotlib.figure import Figure


Expand Down Expand Up @@ -155,6 +156,7 @@ def _get_image(i):
)

self._fig.canvas.mpl_connect("key_press_event", self._key_press)
self._observers = CallbackRegistry()

@property
def ax(self):
Expand Down Expand Up @@ -205,6 +207,7 @@ def _update_displayed(self):
self._im.set_data(image)
self._im.set_extent((-0.5, image.shape[1] - 0.5, image.shape[0] - 0.5, -0.5))
self._update_title()
self._observers.process("image-changed", self._image_index, image)
self._fig.canvas.draw_idle()

def _key_press(self, event):
Expand All @@ -213,9 +216,9 @@ def _key_press(self, event):
elif event.key == "right":
self.image_index += 1
elif event.key in self._label_keymap:
self._labels[self._image_index] = self._classes[
self._label_keymap[event.key]
]
klass = self._classes[self._label_keymap[event.key]]
self._labels[self._image_index] = klass
self._observers.process("label-assigned", self._image_index, klass)
if self._label_advances:
if self.image_index == self._N_images - 1:
# make sure we update the title we are on the last image
Expand All @@ -228,3 +231,41 @@ def _key_press(self, event):
self._update_title()
# TODO: blit just the text here
self._fig.canvas.draw_idle()

def on_label_assigned(self, func):
"""
Connect *func* as a callback function for when a label is assigned
to an image. *func* will receive the index of the image and the
new class.
Parameters
----------
func : callable
Function to call when a point is added.
Returns
-------
int
Connection id (which can be used to disconnect *func*).
"""
return self._observers.connect("label-assigned", lambda *args: func(*args))

def on_image_changed(self, func):
"""
Connect *func* as a callback function for when the displayed image
is changed. *func* will receive the index of the new image and the
image. `fig.canvas.draw_idle` will be called after the callback is
executed so if you are modifying the figure then you do not need to
explicitly call *draw* yourself.
Parameters
----------
func : callable
Function to call when a point is added.
Returns
-------
int
Connection id (which can be used to disconnect *func*).
"""
return self._observers.connect("image-changed", lambda *args: func(*args))

0 comments on commit 9fa8daa

Please sign in to comment.