Extensions

class blocks.extensions.CallbackName[source]

Bases: str

A name of a TrainingExtension callback.

Raises:
  • TypeError – On comparison with a string which is not the name of a TrainingExtension callback.
class blocks.extensions.CompositeExtension(sub_extensions, run_before_children=True, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

An extension that manages several other extensions.

Parameters:
  • sub_extensions (iterable) – An iterable collection of sub-extensions to manage.
  • run_before_children (bool, optional) – Whether the container extension’s own logic should be dispatched before that of the sub-extensions. If False, the containing extension is dispatched last. Defaults to True.

Notes

The main use case for this class is bundling together groups of extensions that are most commonly used in tandem, configured so as to interact with one another. Encapsulating this pattern in a single extension reduces boilerplate.

Sub-extensions are dispatched in the order specified in sub_extensions, on whatever triggers they are individually configured to respect.

Sub-extensions may be run on different triggers than the containing extension; the trigger keywords passed to the constructor for this class only affect the outer extension’s logic, and sub-extensions should be configured independently (possibly in a constructor for a subclass of CompositeExtension).
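
For example, a subclass can bundle a timing extension with an early-finish extension. The following is a minimal sketch; the class name and trigger configuration are illustrative, not part of the library:

    from blocks.extensions import CompositeExtension, FinishAfter, Timing

    class TimedBudget(CompositeExtension):
        """Hypothetical bundle: report timing and stop after a budget."""
        def __init__(self, n_epochs, **kwargs):
            sub_extensions = [Timing(),
                              FinishAfter(after_n_epochs=n_epochs)]
            super(TimedBudget, self).__init__(sub_extensions, **kwargs)

        def do(self, which_callback, *args):
            pass  # the container itself carries no logic of its own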

dispatch(callback_invoked, *from_main_loop)[source]

Check conditions and call the do() method.

Also adds additional arguments if specified for a condition.

Todo

Add a check for a situation when several conditions are met at the same time and do something.

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

main_loop
class blocks.extensions.FinishAfter(**kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Finishes the training process when triggered.
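
A typical use is to bound the length of a run. A minimal sketch:

    from blocks.extensions import FinishAfter

    # Ask the main loop to stop once twenty epochs are done.
    finish = FinishAfter(after_n_epochs=20)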

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

class blocks.extensions.Predicate(condition, num)[source]

Bases: object

class blocks.extensions.Printing(**kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Prints log messages to the screen.
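
As a SimpleExtension, it accepts the standard trigger keywords; the frequency below is illustrative:

    from blocks.extensions import Printing

    # Print the log every 100 batches, in addition to the default
    # triggers (e.g. after each epoch).
    printing = Printing(every_n_batches=100)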

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

class blocks.extensions.ProgressBar(**kwargs)[source]

Bases: blocks.extensions.TrainingExtension

Display a progress bar during training.

This extension tries to infer the number of iterations per epoch by querying the num_batches, num_examples and batch_size attributes from the IterationScheme. When this information is not available it will display a simplified progress bar that does not include the estimated time until the end of this epoch.

Notes

This extension should be run before other extensions that print to the screen at the end or at the beginning of the epoch (e.g. the Printing extension). Placing ProgressBar before these extensions ensures that you won’t get intermingled output on your terminal.
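
For example, when assembling the extension list for a main loop (a sketch of the ordering described above):

    from blocks.extensions import Printing, ProgressBar

    # ProgressBar precedes Printing so the bar is finalized before the
    # log is written to the terminal.
    extensions = [ProgressBar(), Printing()]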

after_epoch()[source]

The callback invoked after an epoch is finished.

before_batch(batch)[source]

The callback invoked before a batch is processed.

Parameters:batch (object) – The data batch to be processed.
before_epoch()[source]

The callback invoked before starting an epoch.

create_bar()[source]

Create a new progress bar.

Calls self.get_iter_per_epoch(), selects an appropriate set of widgets and creates a ProgressBar.

get_iter_per_epoch()[source]

Try to infer the number of iterations per epoch.

class blocks.extensions.SimpleExtension(**kwargs)[source]

Bases: blocks.extensions.TrainingExtension

A base class for simple extensions.

All logic of simple extensions is concentrated in the method do(). This method is called when certain conditions are fulfilled. The user can manage the conditions by calling the add_condition method and by passing arguments to the constructor. In addition to specifying when do() is called, it is possible to specify additional arguments passed to do() under different conditions.

Parameters:
  • before_training (bool) – If True, do() is invoked before training.
  • before_first_epoch (bool) – If True, do() is invoked before the first epoch.
  • before_epoch (bool) – If True, do() is invoked before every epoch.
  • on_resumption (bool, optional) – If True, do() is invoked when training is resumed.
  • on_interrupt (bool, optional) – If True, do() is invoked when training is interrupted.
  • after_epoch (bool) – If True, do() is invoked after every epoch.
  • after_batch (bool) – If True, do() is invoked after every batch.
  • after_training (bool) – If True, do() is invoked after training.
  • after_n_epochs (int, optional) – If not None, do() is invoked when after_n_epochs epochs are done.
  • every_n_epochs (int, optional) – If not None, do() is invoked after every n-th epoch.
  • after_n_batches (int, optional) – If not None, do() is invoked when after_n_batches batches are processed.
  • every_n_batches (int, optional) – If not None, do() is invoked after every n-th batch.
BOOLEAN_TRIGGERS = frozenset(['before_batch', 'after_batch', 'after_training', 'before_epoch', 'before_training', 'on_error', 'before_first_epoch', 'after_epoch', 'on_interrupt', 'on_resumption'])
INTEGER_TRIGGERS = frozenset(['every_n_batches', 'after_n_epochs', 'every_n_epochs', 'after_n_batches'])
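
As an illustration, a minimal subclass only needs to implement do(); trigger keywords are passed to the constructor. A sketch (the class name is hypothetical):

    from blocks.extensions import SimpleExtension

    class PrintIterations(SimpleExtension):
        """Hypothetical extension printing the iteration counter."""
        def do(self, which_callback, *args):
            print(which_callback,
                  self.main_loop.log.status['iterations_done'])

    # Run after every epoch and additionally every 500 batches.
    ext = PrintIterations(after_epoch=True, every_n_batches=500)
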
add_condition(callbacks_names, predicate=None, arguments=None)[source]

Adds a condition under which do() is called.

Parameters:
  • callbacks_names (list of str) – The names of the callbacks in which the method should be called.
  • predicate (function) – A predicate function taking the main loop’s log as its single parameter and returning True when the method should be called and False when it should not. If None, an always-true predicate is used.
  • arguments (iterable) – Additional arguments to be passed to do(). They will be concatenated with the ones passed from the main loop (e.g. the batch in the case of the after_batch callback).
Returns:
  The extension object (to allow chaining calls).
Return type:
  SimpleExtension
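
For example, to run only on every fifth epoch (a sketch; the predicate inspects the log’s status):

    from blocks.extensions.saveload import Checkpoint

    checkpoint = Checkpoint('model.tar')  # illustrative path
    checkpoint.add_condition(
        ['after_epoch'],
        predicate=lambda log: log.status['epochs_done'] % 5 == 0)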

dispatch(callback_invoked, *from_main_loop)[source]

Check conditions and call the do() method.

Also adds additional arguments if specified for a condition.

Todo

Add a check for a situation when several conditions are met at the same time and do something.

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

static parse_args(which_callback, args)[source]

Separates do() arguments coming from different sources.

When a do() method receives arguments from both the main loop (e.g. a batch) and the user, it often has to separate them. This method is the right tool to use.

Parameters:
  • which_callback (str) – The name of the callback.
  • args (iterable) – The arguments.
Returns:

  • from_main_loop (tuple)
  • from_user (tuple)
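
A sketch of its use inside a custom do(), assuming an after_batch callback in which the main loop supplies the batch as its single argument:

    from blocks.extensions import SimpleExtension

    class MyExtension(SimpleExtension):  # hypothetical subclass
        def do(self, which_callback, *args):
            # For 'after_batch', from_main_loop is (batch,); from_user
            # holds anything registered via add_condition(arguments=...).
            from_main_loop, from_user = self.parse_args(which_callback,
                                                        args)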

set_conditions(**kwargs)[source]

Set the conditions for which this extension should be run.

See the SimpleExtension docstring for a list of possible parameters.

class blocks.extensions.Timestamp(log_record='timestamp', separator=' ', **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Adds a human readable (ISO 8601) timestamp to the log.

Parameters:
  • log_record (str, optional) – The record name to use. Defaults to ‘timestamp’.
  • separator (str, optional) – Separator between the date and time. ISO 8601 specifies ‘T’. Here, we default to ‘ ’ (a blank space) for human readability.

Notes

By default, triggers after every epoch as well as before training starts, after training finishes, when an error occurs or when training is interrupted or resumed, as these are all generally useful circumstances for which to have a timestamp. These can be disabled by passing False as the appropriate keyword argument; see SimpleExtension.

DEFAULT_LOG_RECORD = 'timestamp'
do(*args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

get_timestamp()[source]
class blocks.extensions.Timing(prefix='', **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Add timing information to the log.

This adds data about the time spent in the algorithm’s process_batch() method as well as the time spent reading data per batch or epoch. It also reports the time spent initializing the algorithm.

Parameters:prefix (str) – Prefix to be added to the log record. Defaults to the empty string.

Notes

Add this extension before the Printing extension.

When created with recurring callbacks like every_n_batches, this extension averages the time over the corresponding interval.

This extension does not enable full profiling information. To see a full profile of the main loop at the end of training, use the profile configuration (e.g. by setting BLOCKS_PROFILE=true).
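
In an extension list this typically looks as follows (a sketch; the averaging interval is illustrative):

    from blocks.extensions import Printing, Timing

    # Timing precedes Printing so its records appear in the printed
    # log; with every_n_batches the reported times are averages over
    # the interval.
    extensions = [Timing(every_n_batches=10), Printing()]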

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

class blocks.extensions.TrainingExtension(name=None)[source]

Bases: object

The base class for training extensions.

An extension is a set of callbacks sharing a joint context that are invoked at certain stages of the training procedure. These callbacks typically add a certain functionality to the training procedure, e.g. running validation on auxiliary datasets or early stopping.

Parameters:name (str, optional) – The name of the extension. Names are useful for distinguishing between several extensions of the same type that belong to the same main loop. By default the name is set to the name of the class.
main_loop

MainLoop – The main loop to which the extension belongs.

name

str – The name of the extension.
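
A custom extension overrides the callbacks it is interested in. A minimal sketch (the class name is hypothetical):

    from blocks.extensions import TrainingExtension

    class BatchCounter(TrainingExtension):
        """Hypothetical extension counting processed batches."""
        def __init__(self, **kwargs):
            super(BatchCounter, self).__init__(**kwargs)
            self.count = 0

        def after_batch(self, batch):
            self.count += 1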

after_batch(batch)[source]

The callback invoked after a batch is processed.

Parameters:batch (object) – The data batch just processed.
after_epoch()[source]

The callback invoked after an epoch is finished.

after_training()[source]

The callback invoked after training is finished.

before_batch(batch)[source]

The callback invoked before a batch is processed.

Parameters:batch (object) – The data batch to be processed.
before_epoch()[source]

The callback invoked before starting an epoch.

before_training()[source]

The callback invoked before training is started.

dispatch(callback_name, *args)[source]

Runs callback with the given name.

The reason for having this method is to allow descendants of TrainingExtension to intercept callback invocations and do something with them, e.g. block the invocation when a certain condition does not hold. The default implementation simply invokes the callback by its name.

main_loop
on_error(exception)[source]

The callback invoked when an error occurs.

Parameters:exception (object) – Exception occurred during the main loop run.
on_interrupt()[source]

The callback invoked when training is interrupted.

on_resumption()[source]

The callback invoked after training is resumed.

blocks.extensions.always_true(log)[source]
blocks.extensions.callback(func)[source]
blocks.extensions.has_done_epochs(log)[source]

Monitoring extensions

class blocks.extensions.monitoring.DataStreamMonitoring(variables, data_stream, updates=None, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension, blocks.extensions.monitoring.MonitoringExtension

Monitors Theano variables and monitored-quantities on a data stream.

By default monitoring is done before the first and after every epoch.

Parameters:
  • variables (list of TensorVariable and MonitoredQuantity) – The variables to monitor. The variable names are used as record names in the logs.
  • data_stream (instance of DataStream) – The data stream to monitor on. A data epoch is requested each time monitoring is done.
  • updates (list of tuples or OrderedDict or None) – TensorSharedVariable updates to be performed during evaluation. This parameter is only for Theano variables. Be careful not to update any model parameters as this is not intended to alter your model in any meaningful way. A typical use case of this option arises when the Theano function used for evaluation contains a call to scan() which might have returned shared variable updates.
do(callback_name, *args)[source]

Write the values of monitored variables to the log.
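
Typical usage, assuming a Theano cost variable and a validation data stream defined elsewhere (a sketch):

    from blocks.extensions.monitoring import DataStreamMonitoring

    # `cost` (TensorVariable) and `valid_stream` (DataStream) are
    # assumed to exist; `prefix` comes from MonitoringExtension.
    valid_monitoring = DataStreamMonitoring(
        variables=[cost], data_stream=valid_stream, prefix='valid')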

class blocks.extensions.monitoring.MonitoringExtension(prefix=None, suffix=None, **kwargs)[source]

Bases: blocks.extensions.TrainingExtension

A mixin with logic shared by monitoring extensions.

Parameters:
  • prefix (str, optional) – The prefix for the log records done by the extension. It is prepended to the variable names with an underscore as a separator. If not given, no prefix is added to the names of the observed variables.
  • suffix (str, optional) – The suffix for the log records done by the extension. It is appended to the end of variable names with an underscore as a separator. If not given, no suffix is added to the names of the observed variables.
SEPARATOR = '_'
add_records(log, record_tuples)[source]

Helper function to add monitoring records to the log.

record_name(variable)[source]

The record name for a variable.

class blocks.extensions.monitoring.TrainingDataMonitoring(variables, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension, blocks.extensions.monitoring.MonitoringExtension

Monitors values of Theano variables on training batches.

Use this extension to monitor a quantity on every training batch cheaply. It integrates with the training algorithm in order to avoid recomputing the same things several times. For instance, if you are training a network and you want to log the norm of the gradient on every batch, the backpropagation will only be done once. By controlling the frequency with which the do() method is called, you can aggregate the monitored variables, e.g. only log the gradient norm average over an epoch.

Parameters:variables (list of TensorVariable or MonitoredQuantity) – The variables or non-Theano quantities to monitor. The variable names are used as record names in the logs.

Notes

All the monitored variables are evaluated _before_ the parameter update.

Requires the training algorithm to be an instance of UpdatesAlgorithm.
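
Typical usage, assuming a Theano cost variable defined elsewhere (a sketch):

    from blocks.extensions.monitoring import TrainingDataMonitoring

    # `cost` (TensorVariable) is assumed to exist. With after_epoch=True
    # the values committed to the log once per epoch are aggregated
    # according to each variable's aggregation scheme.
    train_monitoring = TrainingDataMonitoring(
        [cost], prefix='train', after_epoch=True)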

do(callback_name, *args)[source]

Initializes the buffer or commits the values to the log.

What this method does depends on which callback it is called from and with which arguments. When called within before_training, it initializes the aggregation buffer and instructs the training algorithm which additional computations should be carried out at each step, by adding the corresponding updates to it. In most other cases it writes aggregated values of the monitored variables to the log. An exception is when an argument just_aggregate is given: in this case it updates the values of monitored non-Theano quantities, but does not write anything to the log.

blocks.extensions.monitoring.take_last(variable)

Training

class blocks.extensions.training.SharedVariableModifier(parameter, function, num_args=None, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Adjusts shared variable parameter using some function.

Applies a function to compute the new value of a shared parameter each iteration.

This class can be used to adapt parameters such as the learning rate or momentum over the course of training.

Parameters:
  • parameter (TensorSharedVariable) – Shared variable to be adjusted
  • function (callable) –

    A function which outputs a numeric value to which the given shared variable will be set and may take one or two arguments.

    In the first case, it is a function that takes the total number of iterations done (int) as input.

    In the second case, it is a function that takes the number of iterations done (int) and the old value of the shared variable (with the same dtype as parameter).

  • num_args (int, optional) – The number of arguments to pass to the function. If unspecified, it will be inferred. This is useful if you are using function-like objects for which the arity of the function cannot be inferred.

Notes

This class includes a method function that calls the function passed in the constructor and a num_args property which computes the number of arguments to use by inspecting the function object. Subclasses may override a method called function and/or the num_args property and instead pass None to the superclass constructor. This can be used to bypass certain serialization issues on Legacy Python regarding the unpicklability of instance method objects.
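
For example, a 1/t learning rate decay (a sketch; learning_rate is assumed to be the float32 shared variable used by the training algorithm):

    import numpy
    from blocks.extensions.training import SharedVariableModifier

    # One-argument form: receives the number of iterations done.
    decay = SharedVariableModifier(
        learning_rate, lambda n: numpy.float32(0.1 / (1 + n)))

    # Two-argument form: also receives the current value.
    halve = SharedVariableModifier(
        learning_rate, lambda n, value: numpy.float32(0.5) * value)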

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

function(*args)[source]
num_args
class blocks.extensions.training.TrackTheBest(record_name, notification_name=None, choose_best=<built-in function min>, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Check if a log quantity has the minimum/maximum value so far.

Parameters:
  • record_name (str) – The name of the record to track.
  • notification_name (str, optional) – The name for the record to be made in the log when the current value of the tracked quantity is the best so far. If not given, record_name with a “_best_so_far” suffix is used.
  • choose_best (callable, optional) – A function that takes the current value and the best so far and returns the better of the two. By default min(), which corresponds to tracking the minimum value.
best_name

str – The name of the status record to keep the best value so far.

notification_name

str – The name of the record written to the log when the current value of the tracked quantity is the best so far.

Notes

In the likely case that you are relying on another extension to add the tracked quantity to the log, make sure to place this extension after the extension that writes the quantity to the log in the extensions argument to blocks.main_loop.MainLoop.
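
For example, assuming a ‘valid_cost’ record is written by a monitoring extension placed earlier in the list (a sketch):

    from blocks.extensions.training import TrackTheBest

    # Writes a 'valid_cost_best_so_far' notification to the log
    # whenever the tracked record reaches a new minimum.
    tracker = TrackTheBest('valid_cost')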

do(which_callback, *args)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

Serialization

class blocks.extensions.saveload.Checkpoint(path, parameters=None, save_separately=None, save_main_loop=True, use_cpickle=False, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Saves a pickled version of the main loop to the disk.

The pickled main loop can be later reloaded and training can be resumed.

Makes a SAVED_TO record in the log with the serialization destination in the case of success and None in the case of failure. The value of the record is a tuple of paths to which saving was done (there can be more than one if the user added a condition with an argument, see do() docs).

Parameters:
  • path (str) – The destination path for pickling.
  • parameters (list, optional) – The parameters to save separately. If None, the parameters from the model (main_loop.model.parameters) are saved.
  • save_separately (list of str, optional) – The list of the main loop’s attributes to be saved (copied) in a separate file in the tar archive. It may be used, for example, to save the log separately. The name of the attribute will be used as the name in the tar file.
  • save_main_loop (bool) – Choose whether to save the main loop or not. This can be useful for example if you are only interested in saving the parameters, but not the whole main loop. Defaults to True.
  • use_cpickle (bool) – See documentation of dump().

Notes

Using pickling for saving the whole main loop object comes with certain limitations:

  • Theano computation graphs built in GPU mode (theano.config.device == “gpu”) cannot be used in the usual mode (and vice versa). Therefore using this extension binds you to using only one kind of device.
do(callback_name, *args)[source]

Pickle the main loop object to the disk.

If *args contains an argument from the user, it is treated as the saving path to be used instead of the one given at construction time.
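
Typical usage (a sketch; the path and trigger are illustrative):

    from blocks.extensions.saveload import Checkpoint

    # Save the main loop to 'model.tar' after every epoch; the log is
    # additionally copied into the archive as a separate file.
    checkpoint = Checkpoint('model.tar', save_separately=['log'],
                            after_epoch=True)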

class blocks.extensions.saveload.Load(path, load_iteration_state=False, load_log=False, **kwargs)[source]

Bases: blocks.extensions.SimpleExtension

Loads a saved checkpoint into the main loop.

Makes a LOADED_FROM record in the log with the dump path.

Parameters:
  • path (str) – The path to the folder with the dump.
  • load_iteration_state (bool) – If True, load the iteration state. This can be useful when your model has very long epochs, and you want to resume when you were in the middle of one. Defaults to False.
  • load_log (bool) – If True, load the old log and continue logging from there. Convenient because you end up with a single log of the entire training history. Defaults to False.

Notes

Requires the model to be created entirely using bricks, with a unique path/name for each brick, so that the parameters can be matched to their values.

In order to load the iteration state and the log, the saved model needs to be unpickled. Note that resuming training this way is still not entirely seamless because e.g. extensions will not be reloaded.
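
Typical usage, paired with a Checkpoint writing to the same path (a sketch):

    from blocks.extensions.saveload import Load

    # Restore the parameters, the iteration state and the log from a
    # previous Checkpoint archive.
    load = Load('model.tar', load_iteration_state=True, load_log=True)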

do(*args, **kwargs)[source]

Does the job of the training extension.

Parameters:
  • which_callback (str) – The name of the callback in the context of which do() is run.
  • *args (tuple) – The arguments from the main loop concatenated with additional arguments from the user.

Notes

Subclasses must accept additional positional arguments in their call signature for this method, even if they are unused.

load_to(main_loop)[source]