Developer guide

In the following, we introduce best practices for the implementation workflow before going into detail about how to provide custom implementations.

Workflow

Before you push to the main branch, please test the code and the documentation locally.

Unit testing

Run tests locally with the unittest package.

python -m venv venv
venv/bin/pip install -e .[tests]
venv/bin/python -m unittest

As soon as you push to the main branch, GitHub Actions will run these unit tests as well.

Documentation

Build the documentation locally and open docs/build/html/index.html in your browser.

venv/bin/pip install -e .[docs]
venv/bin/sphinx-build -M html docs/source docs/build

As soon as you push to the main branch, GitHub Actions will build the documentation, push it to the gh-pages branch, and publish the result on GitHub Pages: https://mirkobunse.github.io/qunfold

Custom implementations

Custom losses and data representations can be used in any instance of LinearMethod. Use the existing implementations as examples.

Losses

The most convenient way of implementing a custom loss is to create a JAX-powered function (p, q, M, N) -> loss_value. From this function, you can create a FunctionLoss object to be used in any instance of LinearMethod.

class qunfold.methods.linear.losses.FunctionLoss(loss_function: Callable)

Create a loss object from a JAX function (p, q, M, N) -> loss_value.

Using this class is likely more convenient than subtyping AbstractLoss. In both cases, the loss_value has to be the result of a JAX expression. The JAX requirement ensures that the loss function can be auto-differentiated. Hence, no derivatives of the loss function have to be provided manually. JAX expressions are easy to implement. Just import the numpy wrapper

>>> import jax.numpy as jnp

and use jnp just as you would use numpy.

Note

p is a vector of class-wise probabilities. This vector will already be the result of our soft-max trick, so that you don’t have to worry about constraints or latent parameters.

Parameters:

loss_function – A JAX function (p, q, M, N) -> loss_value.

Examples

The least squares loss, (q - M*p)’ * (q - M*p), is simply

>>> def least_squares(p, q, M, N):
...     return jnp.dot(q - jnp.dot(M, p), q - jnp.dot(M, p))

and thereby ready to be used in a FunctionLoss object:

>>> least_squares_loss = FunctionLoss(least_squares)
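
Such a loss object can then be passed to a quantification method. The following line is only a sketch: it assumes that LinearMethod is exported at the top level of the qunfold package and that it takes the loss and a data representation as its first two arguments; representation is a placeholder for any existing or custom data representation (see below).

>>> import qunfold
>>> # sketch under the assumptions stated above; `representation` is a placeholder
>>> method = qunfold.LinearMethod(least_squares_loss, representation)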

If you require more freedom in implementing a custom loss, you can also create a sub-class of AbstractLoss.

class qunfold.methods.linear.losses.AbstractLoss

Abstract base class for loss functions and for regularization terms.

abstract instantiate(q, M, N)

This abstract method has to create a lambda expression p -> loss with JAX.

In particular, your implementation of this abstract method should return a lambda expression

>>> return lambda p: loss_value(q, M, p, N)

where loss_value has to return the result of a JAX expression. The JAX requirement ensures that the loss function can be auto-differentiated. Hence, no derivatives of the loss function have to be provided manually. JAX expressions are easy to implement. Just import the numpy wrapper

>>> import jax.numpy as jnp

and use jnp just as you would use numpy.

Note

p is a vector of class-wise probabilities. This vector will already be the result of our soft-max trick, so that you don’t have to worry about constraints or latent parameters.

Parameters:
  • q – A numpy array.

  • M – A numpy matrix.

  • N – The number of data items that q represents.

Returns:

A lambda expression p -> loss, implemented in JAX.

Examples

The least squares loss, (q - M*p)’ * (q - M*p), is simply

>>> jnp.dot(q - jnp.dot(M, p), q - jnp.dot(M, p))
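
For instance, the least squares loss could also be implemented as a complete sub-class. The following is a minimal sketch; the class name LeastSquaresLoss is made up for this illustration.

>>> import jax.numpy as jnp
>>> from qunfold.methods.linear.losses import AbstractLoss
>>>
>>> class LeastSquaresLoss(AbstractLoss):
...     """Sketch: the least squares loss as an AbstractLoss sub-class."""
...     def instantiate(self, q, M, N):
...         # return a JAX lambda expression p -> loss
...         return lambda p: jnp.dot(q - jnp.dot(M, p), q - jnp.dot(M, p))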

Data representations

To implement a custom data representation, you have to create a sub-class of AbstractRepresentation.

class qunfold.methods.linear.representations.AbstractRepresentation

Abstract base class for representations.

abstract fit_transform(X, y, average=True, n_classes=None)

This abstract method has to fit the representation and to return the transformed input data.

Note

Implementations of this abstract method should check the sanity of labels by calling check_y(y, n_classes) and they must set the property self.p_trn = class_prevalences(y, n_classes).

Parameters:
  • X – The feature matrix to which this representation will be fitted.

  • y – The labels to which this representation will be fitted.

  • average (optional) – Whether to return a transfer matrix M or a transformation (f(X), y). Defaults to True.

  • n_classes (optional) – The number of expected classes. Defaults to None.

Returns:

A transfer matrix M if average==True or a transformation (f(X), y) if average==False.

abstract transform(X, average=True)

This abstract method has to transform the data X.

Parameters:
  • X – The feature matrix that will be transformed.

  • average (optional) – Whether to return a vector q or a transformation f(X). Defaults to True.

Returns:

A vector q = f(X).mean(axis=0) if average==True or a transformation f(X) if average==False.
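
As an illustration, the following minimal sketch implements a representation that uses the raw features, f(X) = X. The class name is made up, the import path of check_y and class_prevalences is an assumption (adjust it to your version of qunfold), and it is assumed that the transfer matrix M stacks the class-wise averages of f(X) as its columns, in line with q ≈ M p.

>>> import numpy as np
>>> from qunfold.methods.linear.representations import AbstractRepresentation
>>> from qunfold.methods import check_y, class_prevalences  # assumption: import path
>>>
>>> class RawFeatureRepresentation(AbstractRepresentation):
...     """Sketch: use the raw features as the representation, f(X) = X."""
...     def fit_transform(self, X, y, average=True, n_classes=None):
...         check_y(y, n_classes)  # sanity-check the labels
...         self.p_trn = class_prevalences(y, n_classes)  # required property
...         if not average:
...             return X, y  # the transformation (f(X), y)
...         # assumption: M stacks the class-wise averages of f(X) as columns
...         return np.stack(
...             [X[y == i].mean(axis=0) for i in range(len(self.p_trn))],
...             axis=1,
...         )
...     def transform(self, X, average=True):
...         return X.mean(axis=0) if average else X  # the vector q or f(X)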

For representations that are based on a kernel embedding, you can simply provide your kernel function to a KernelRepresentation.

class qunfold.KernelRepresentation

A general kernel-based data representation, as it is used in KMM. If you intend to use a Gaussian kernel or energy kernel, prefer their dedicated and more efficient implementations over this class.

Note

The methods of this representation do not support setting average=False.

Parameters:

kernel – A callable that will be used as the kernel. Must follow the signature (X[y==i], X[y==j]) -> scalar.
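
For example, a simple kernel that averages the inner products between two sets of feature vectors could be provided as follows. This is a sketch: the function name is made up, and it assumes that the kernel is passed as the first constructor argument, as the parameter list above suggests.

>>> import numpy as np
>>> import qunfold
>>>
>>> def mean_linear_kernel(X_i, X_j):
...     """Average inner product between all pairs from X_i and X_j."""
...     return np.mean(X_i @ X_j.T)
>>>
>>> representation = qunfold.KernelRepresentation(mean_linear_kernel)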