twodlearn.bayesnet.losses module

class twodlearn.bayesnet.losses.KLDivergence(distribution_a, distribution_b, name=None)

Bases: twodlearn.core.common.TdlModel

Computes the KL divergence KL(distribution_a || distribution_b). If no KL method is registered specifically for the pair type(distribution_a) and type(distribution_b), the class hierarchies (MROs) of both types are searched.

If exactly one KL method is registered for a pair of classes drawn from these two hierarchies, that method is used.

If more than one such method is registered, the one whose registered classes have the smallest combined MRO distance to the input types (the distance to type(distribution_a) plus the distance to type(distribution_b)) is used.

If several methods tie for the smallest combined distance, the first one found by the search is used; the search favors a shorter MRO distance to type(distribution_a).

distribution_a
distribution_b
value

class twodlearn.bayesnet.losses.RegisterKL(dist_cls_a, dist_cls_b)

Bases: object

Decorator to register a KL divergence implementation function.

Usage

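import tensorflow as tf  # assumes TF 1.x, where tf.distributions exists
from twodlearn.bayesnet.losses import RegisterKL

@RegisterKL(tf.distributions.Normal, tf.distributions.Normal)
def _kl_normal_normal(norm_a, norm_b):
    raise NotImplementedError("compute and return KL(norm_a || norm_b) here")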