wandb

Classes

WandbLogger

Log using Weights and Biases.

Weights and Biases Logger

class lightning.pytorch.loggers.wandb.WandbLogger(name=None, save_dir='.', version=None, offline=False, dir=None, id=None, anonymous=None, project=None, log_model=False, experiment=None, prefix='', checkpoint_name=None, add_file_policy='mutable', **kwargs)

Bases: Logger

Log using Weights and Biases.

Installation and set-up

Install with pip:

pip install wandb

Create a WandbLogger instance:

from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="MNIST")

Pass the logger instance to the Trainer:

trainer = Trainer(logger=wandb_logger)

A new W&B run will be created when training starts if you have not created one manually before with wandb.init().

Log metrics

Log from LightningModule:

class LitModule(LightningModule):
    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.log("train/loss", loss)
        return loss

Use the wandb module directly:

wandb.log({"train/loss": loss})

Log hyper-parameters

Save LightningModule parameters:

class LitModule(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

Add other config parameters:

# add one parameter
wandb_logger.experiment.config["key"] = value

# add multiple parameters
wandb_logger.experiment.config.update({"key1": val1, "key2": val2})

# use the wandb module directly
wandb.config["key"] = value
wandb.config.update({"key1": val1, "key2": val2})

Log gradients, parameters and model topology

Call the watch method to track gradients automatically:

# log gradients and model topology
wandb_logger.watch(model)

# log gradients, parameter histogram and model topology
wandb_logger.watch(model, log="all")

# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)

# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)

The watch method adds hooks to the model which can be removed at the end of training:

wandb_logger.experiment.unwatch(model)

Log model checkpoints

Log model checkpoints at the end of training:

wandb_logger = WandbLogger(log_model=True)

Log model checkpoints as they get created during training:

wandb_logger = WandbLogger(log_model="all")

Custom checkpointing can be set up through ModelCheckpoint:

# log model only if `val_accuracy` increases
wandb_logger = WandbLogger(log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])

The latest and best aliases are set automatically, making it easy to retrieve a model checkpoint:

from pathlib import Path
import wandb

# the reference can be retrieved in the Artifacts panel
# "VERSION" can be a version (ex: "v2") or an alias ("latest" or "best")
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"

# download checkpoint locally (if not already cached)
run = wandb.init(project="MNIST")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()

# load checkpoint
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
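The placeholder reference above follows the pattern ENTITY/PROJECT/ARTIFACT_NAME:VERSION. A minimal sketch of assembling it, assuming the checkpoint artifact carries Lightning's default name model-<run id> when checkpoint_name is not set (worth confirming in the Artifacts panel); checkpoint_reference is a hypothetical helper, not part of Lightning or wandb:

```python
from typing import Optional


def checkpoint_reference(
    entity: str,
    project: str,
    run_id: str,
    version: str = "latest",
    checkpoint_name: Optional[str] = None,
) -> str:
    """Build an artifact reference such as "USER/MNIST/model-<run id>:best".

    Hypothetical helper. Assumes the default artifact name "model-<run id>"
    unless a custom `checkpoint_name` was passed to WandbLogger.
    """
    name = checkpoint_name or f"model-{run_id}"
    # `version` may be an explicit version ("v2") or an alias ("latest", "best").
    return f"{entity}/{project}/{name}:{version}"


reference = checkpoint_reference("USER", "MNIST", "1a2b3c4d", "best")
```

The resulting string can be passed to run.use_artifact(reference, type="model") in place of the placeholder.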

Log media

Log text with:

# using columns and data
columns = ["input", "label", "prediction"]
data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
wandb_logger.log_text(key="samples", columns=columns, data=data)

# using a pandas DataFrame
wandb_logger.log_text(key="samples", dataframe=my_dataframe)
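In practice the rows for log_text usually come from parallel lists of inputs, labels and predictions. A minimal sketch, assuming those three lists are already available; text_table is a hypothetical helper that only assembles the columns and data arguments (reproducing the example above) and checks that the lists line up:

```python
from typing import List, Sequence, Tuple


def text_table(
    inputs: Sequence[str], labels: Sequence[str], predictions: Sequence[str]
) -> Tuple[List[str], List[List[str]]]:
    """Assemble `columns` and `data` for WandbLogger.log_text (hypothetical helper)."""
    if not len(inputs) == len(labels) == len(predictions):
        raise ValueError("inputs, labels and predictions must have the same length")
    columns = ["input", "label", "prediction"]
    # One row per sample, with values in the same order as `columns`.
    data = [[x, y, p] for x, y, p in zip(inputs, labels, predictions)]
    return columns, data


columns, data = text_table(["cheese", "fromage"], ["english", "french"], ["english", "spanish"])
# wandb_logger.log_text(key="samples", columns=columns, data=data)
```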

Log images with:

# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])

# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

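More arguments can be passed for logging segmentation masks and bounding boxes. Refer to the Image Overlays section of the W&B documentation.

Log tables

W&B Tables can hold any type of media (text, images, audio, video, molecules, HTML, etc.). Log one with: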

columns = ["caption", "image", "sound"]
data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]]
wandb_logger.log_table(key="samples", columns=columns, data=data)

Downloading and Using Artifacts

To download an artifact without starting a run, call the download_artifact function on the class:

from lightning.pytorch.loggers import WandbLogger

artifact_dir = WandbLogger.download_artifact(artifact="path/to/artifact")

To download an artifact and link it to an ongoing run, call the download_artifact function on the logger instance:

class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        self.logger.download_artifact(artifact="path/to/artifact")

To link an artifact from a previous run, use the use_artifact function:

from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="my_project", name="my_run")
wandb_logger.use_artifact(artifact="path/to/artifact")

Parameters:
  • name (Optional[str]) – Display name for the run.

  • save_dir (Union[str, Path]) – Path where data is saved.

  • version (Optional[str]) – Sets the version, mainly used to resume a previous run.

  • offline (bool) – Run offline (data can be streamed later to wandb servers).

  • dir (Union[str, Path, None]) – Same as save_dir.

  • id (Optional[str]) – Same as version.

  • anonymous (Optional[bool]) – Enables or explicitly disables anonymous logging.

  • project (Optional[str]) – The name of the project to which this run will belong. If not set, the environment variable WANDB_PROJECT is used as a fallback. If neither is set, it defaults to 'lightning_logs'.

  • log_model (Union[Literal['all'], bool]) –

    Log checkpoints created by ModelCheckpoint as W&B artifacts. latest and best aliases are automatically set.

    • if log_model == 'all', checkpoints are logged during training.

    • if log_model == True, checkpoints are logged at the end of training, except when save_top_k == -1, in which case every checkpoint is logged during training.

    • if log_model == False (default), no checkpoint is logged.

  • prefix (str) – A string to put at the beginning of metric keys.

  • experiment (Union[Run, RunDisabled, None]) – WandB experiment object. Automatically set when creating a run.

  • checkpoint_name (Optional[str]) – Name of the model checkpoint artifact being logged.

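  • add_file_policy (Literal['mutable', 'immutable']) – If “mutable”, the file is copied to a temporary directory before upload. If “immutable”, no copy is made, so the file must not be changed or deleted while it is being uploaded.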

  • **kwargs (Any) – Arguments passed to wandb.init() like entity, group, tags, etc.

Raises:
  • ModuleNotFoundError – If the wandb package is not installed.

  • MisconfigurationException – If offline is True while log_model is enabled (True or 'all'), since checkpoints cannot be uploaded in offline mode.
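The parameter descriptions above imply a few fallback rules (project from WANDB_PROJECT or 'lightning_logs', version and id being aliases, save_dir and dir being aliases) plus one invalid combination. The sketch below restates them as a plain function that builds the keyword arguments one might hand to wandb.init(). It is a simplified, hypothetical approximation (resume and anonymous handling are omitted), not Lightning's implementation; the precedence of version over id and of save_dir over dir reflects current Lightning releases as we understand them, and ValueError stands in for MisconfigurationException to keep the example dependency-free:

```python
import os
from typing import Any, Dict, Mapping, Optional, Union


def resolve_init_args(
    name: Optional[str] = None,
    save_dir: Union[str, os.PathLike, None] = ".",
    version: Optional[str] = None,
    offline: bool = False,
    dir: Union[str, os.PathLike, None] = None,  # mirrors the logger's argument name
    id: Optional[str] = None,  # mirrors the logger's argument name
    project: Optional[str] = None,
    log_model: Union[str, bool] = False,
    env: Optional[Mapping[str, str]] = None,
    **kwargs: Any,
) -> Dict[str, Any]:
    """Approximate the wandb.init() arguments derived from WandbLogger's arguments.

    Hypothetical, simplified sketch; not Lightning's actual code.
    """
    env = os.environ if env is None else env
    if offline and log_model:
        # Lightning raises MisconfigurationException; ValueError stands in here.
        raise ValueError("log_model cannot be used with offline=True")
    save_dir = None if save_dir is None else os.fspath(save_dir)
    dir = None if dir is None else os.fspath(dir)
    init_args: Dict[str, Any] = {
        "name": name,
        # Explicit argument first, then WANDB_PROJECT, then the default.
        "project": project or env.get("WANDB_PROJECT", "lightning_logs"),
        # save_dir defaults to ".", so `dir` only matters when save_dir=None.
        "dir": save_dir or dir,
        # `version` and `id` are aliases; `version` wins if both are given.
        "id": version or id,
    }
    init_args.update(kwargs)  # e.g. entity, group, tags
    return init_args
```

For example, resolve_init_args(project="MNIST", entity="my-team") yields a dict whose "project" is "MNIST" and whose extra "entity" key is forwarded unchanged, matching the **kwargs description above.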

after_save_checkpoint(checkpoint_callback)

Called after model checkpoint callback saves a new checkpoint.

Parameters:

checkpoint_callback (ModelCheckpoint) – the model checkpoint callback instance

Return type:

None

static download_artifact(artifact, save_dir=None, artifact_type=None, use_artifact=True)

Downloads an artifact from the wandb server.

Parameters:
  • artifact (str) – The path of the artifact to download.

  • save_dir (Union[str, Path, None]) – The directory to save the artifact to.

  • artifact_type (Optional[str]) – The type of artifact to download.

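  • use_artifact (Optional[bool]) – Whether to link the artifact to the active run, recording it as an input of that run in the W&B artifact graph. If no run is active, or this is False, the artifact is downloaded without being linked.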

Return type:

str

Returns:

The path to the downloaded artifact.

finalize(status)

Do any processing that is necessary to finalize an experiment.

Parameters:

status (str) – Status that the experiment finished with (e.g. success, failed, aborted)

Return type:

None
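Together, after_save_checkpoint and finalize decide when checkpoints reach W&B. The sketch below restates that decision as two plain functions following the log_model rules listed under Parameters. The function names are hypothetical, and the success-only upload at finalize is an assumption based on current Lightning behavior as we understand it, so check it against your installed version:

```python
from typing import Literal, Union

LogModel = Union[Literal["all"], bool]


def uploads_on_save(log_model: LogModel, save_top_k: int) -> bool:
    """True if a checkpoint would be uploaded right after ModelCheckpoint saves it."""
    return log_model == "all" or (log_model is True and save_top_k == -1)


def uploads_on_finalize(log_model: LogModel, save_top_k: int, status: str) -> bool:
    """True if checkpoints would be uploaded when the logger is finalized."""
    # With log_model=True (and save_top_k != -1) the upload is deferred to finalize.
    deferred = log_model is True and save_top_k != -1
    # Assumption: deferred uploads only happen when the run finished with "success".
    return deferred and status == "success"
```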

log_audio(key, audios, step=None, **kwargs)

Log audio (numpy arrays or file paths).

Parameters:
  • key (str) – The key to be used for logging the audio files

  • audios (list[Any]) – The list of audio file paths, or numpy arrays to be logged

  • step (Optional[int]) – The step number to be used for logging the audio files

  • **kwargs (Any) – Optional keyword arguments passed to each wandb.Audio instance (ex: caption, sample_rate). Each value must be a list with one entry per audio clip.

Return type:

None

log_hyperparams(params)

Record hyperparameters.

Parameters:
  • params (Union[dict[str, Any], Namespace]) – Namespace or dict containing the hyperparameters

Return type:

None
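Because params may be either a dict or an argparse.Namespace, the hyperparameters have to be turned into a plain mapping before they can be stored in the run config. A minimal sketch of that normalization step; to_config_dict is a hypothetical name, and the real method may also sanitize values before updating wandb_logger.experiment.config:

```python
from argparse import Namespace
from typing import Any, Dict, Mapping, Union


def to_config_dict(params: Union[Mapping[str, Any], Namespace, None]) -> Dict[str, Any]:
    """Normalize hyperparameters into a plain dict (hypothetical helper)."""
    if isinstance(params, Namespace):
        # argparse results expose their attributes through vars().
        params = vars(params)
    return dict(params or {})


config = to_config_dict(Namespace(lr=1e-3, batch_size=32))
```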

log_image(key, images, step=None, **kwargs)

Log images (tensors, numpy arrays, PIL Images or file paths).

Optional kwargs are lists passed to each image (ex: caption, masks, boxes).

Return type:

None
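The phrase "kwargs are lists passed to each image" means each keyword argument holds one value per item, and item i receives the i-th value of every list; log_audio and log_video follow the same convention. A minimal sketch of that splitting, where per_item_kwargs is a hypothetical helper and the length checks are an assumption modeled on the documented contract:

```python
from typing import Any, Dict, List, Sequence


def per_item_kwargs(items: List[Any], **kwargs: Sequence[Any]) -> List[Dict[str, Any]]:
    """Split list-valued kwargs into one kwargs dict per media item (hypothetical helper)."""
    if not isinstance(items, list):
        raise TypeError(f"expected a list of items, found {type(items)}")
    n = len(items)
    for key, values in kwargs.items():
        if len(values) != n:
            raise ValueError(f"expected {n} values for {key!r}, found {len(values)}")
    # Item i receives the i-th entry of every kwarg list.
    return [{key: values[i] for key, values in kwargs.items()} for i in range(n)]


per_image = per_item_kwargs(["img_1.jpg", "img_2.jpg"], caption=["tree", "person"])
# Each dict would then be unpacked into the media constructor, e.g. wandb.Image(path, **kw).
```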

log_metrics(metrics, step=None)

Records metrics. This method logs metrics as soon as it receives them.

Parameters:
  • metrics (Mapping[str, float]) – Dictionary with metric names as keys and measured quantities as values

  • step (Optional[int]) – Step number at which the metrics should be recorded

Return type:

None
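The prefix constructor argument and the step argument both shape the dictionary that finally reaches W&B. A minimal sketch of that payload, assuming (as in current Lightning releases, to our understanding) that the prefix is joined to each key with "-" and that the step is recorded as an extra "trainer/global_step" value rather than as wandb's own step; build_log_payload is a hypothetical name:

```python
from typing import Any, Dict, Mapping, Optional


def build_log_payload(
    metrics: Mapping[str, float], prefix: str = "", step: Optional[int] = None
) -> Dict[str, Any]:
    """Approximate the dict WandbLogger.log_metrics passes to the run's log()."""
    if prefix:
        # Assumption: prefix and key are joined with "-" (the logger's join character).
        metrics = {f"{prefix}-{key}": value for key, value in metrics.items()}
    payload = dict(metrics)
    if step is not None:
        # Assumption: the step is logged as its own "trainer/global_step" metric.
        payload["trainer/global_step"] = step
    return payload
```

For instance, with prefix="train" and step=10, a metric named "loss" would appear as "train-loss" next to a "trainer/global_step" value of 10.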

log_table(key, columns=None, data=None, dataframe=None, step=None)

Log a Table containing any object type (text, image, audio, video, molecule, html, etc.).

Can be defined either with columns and data or with dataframe.

Return type:

None

log_text(key, columns=None, data=None, dataframe=None, step=None)

Log text as a Table.

Can be defined either with columns and data or with dataframe.

Return type:

None

log_video(key, videos, step=None, **kwargs)

Log videos (numpy arrays or file paths).

Parameters:
  • key (str) – The key to be used for logging the video files

  • videos (list[Any]) – The list of video file paths, or numpy arrays to be logged

  • step (Optional[int]) – The step number to be used for logging the video files

  • **kwargs (Any) – Optional keyword arguments passed to each wandb.Video instance (ex: caption, fps, format). Each value must be a list with one entry per video.

Return type:

None

use_artifact(artifact, artifact_type=None)

Marks the given artifact as used by the current run, so the dependency is recorded and shown in the wandb dashboard.

Parameters:
  • artifact (str) – The path of the artifact.

  • artifact_type (Optional[str]) – The type of artifact being used.

Return type:

Artifact

Returns:

wandb Artifact object for the artifact.

property experiment: Union[wandb.wandb_run.Run, wandb.sdk.lib.RunDisabled]

The underlying wandb run object. To use wandb features in your LightningModule, access it as follows:

Example:

self.logger.experiment.some_wandb_function()

property name: Optional[str]

The project name of this experiment.

Returns:

The name of the project the current experiment belongs to. This name is not the same as wandb.Run’s name. To access wandb’s internal experiment name, use logger.experiment.name instead.

property save_dir: Optional[str]

Gets the save directory.

Returns:

The path to the save directory.

property version: Optional[str]

Gets the id of the experiment.

Returns:

The id of the experiment if the experiment exists else the id given to the constructor.