WandbLogger
- class lightning.pytorch.loggers.WandbLogger(name=None, save_dir='.', version=None, offline=False, dir=None, id=None, anonymous=None, project=None, log_model=False, experiment=None, prefix='', checkpoint_name=None, add_file_policy='mutable', **kwargs)
Bases: Logger
Log using Weights and Biases.
Installation and set-up
Install with pip:
pip install wandb
Create a WandbLogger instance:
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="MNIST")
Pass the logger instance to the Trainer:
from lightning.pytorch import Trainer

trainer = Trainer(logger=wandb_logger)
A new W&B run will be created when training starts if you have not created one manually before with wandb.init().
Log metrics
Log from LightningModule:
class LitModule(LightningModule):
    def training_step(self, batch, batch_idx):
        self.log("train/loss", loss)
Use the wandb module directly:
wandb.log({"train/loss": loss})
Log hyper-parameters
Save LightningModule parameters:
class LitModule(LightningModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        # record the constructor arguments as hyper-parameters
        self.save_hyperparameters()
Add other config parameters:
# add one parameter
wandb_logger.experiment.config["key"] = value
# add multiple parameters
wandb_logger.experiment.config.update({key1: val1, key2: val2})
# use the wandb module directly
wandb.config["key"] = value
wandb.config.update()
Log gradients, parameters and model topology
Call the watch method to track gradients automatically:
# log gradients and model topology
wandb_logger.watch(model)
# log gradients, parameter histogram and model topology
wandb_logger.watch(model, log="all")
# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)
# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)
The watch method adds hooks to the model which can be removed at the end of training:
wandb_logger.experiment.unwatch(model)
Log model checkpoints
Log model checkpoints at the end of training:
wandb_logger = WandbLogger(log_model=True)
Log model checkpoints as they get created during training:
wandb_logger = WandbLogger(log_model="all")
Custom checkpointing can be set up through ModelCheckpoint:
from lightning.pytorch.callbacks import ModelCheckpoint

# log model only if `val_accuracy` increases
wandb_logger = WandbLogger(log_model="all")
checkpoint_callback = ModelCheckpoint(monitor="val_accuracy", mode="max")
trainer = Trainer(logger=wandb_logger, callbacks=[checkpoint_callback])
The latest and best aliases are set automatically so that a model checkpoint is easy to retrieve:
from pathlib import Path

import wandb

# reference can be retrieved in artifacts panel
# "VERSION" can be a version (ex: "v2") or an alias ("latest" or "best")
checkpoint_reference = "USER/PROJECT/MODEL-RUN_ID:VERSION"
# download checkpoint locally (if not already cached)
run = wandb.init(project="MNIST")
artifact = run.use_artifact(checkpoint_reference, type="model")
artifact_dir = artifact.download()
# load checkpoint
model = LitModule.load_from_checkpoint(Path(artifact_dir) / "model.ckpt")
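The reference string used above has four parts. As a rough sketch, the helper below assembles one from its pieces. `checkpoint_reference` is a hypothetical helper, not part of Lightning or wandb, and it assumes the artifact keeps a default name of the form `model-<run id>` (that is, `checkpoint_name` was not set). The user, project and run id values are placeholders.

```python
def checkpoint_reference(user: str, project: str, run_id: str, version: str = "latest") -> str:
    """Build a reference of the form USER/PROJECT/model-RUN_ID:VERSION (assumed naming)."""
    parts = {"user": user, "project": project, "run_id": run_id, "version": version}
    for label, value in parts.items():
        if not value:
            raise ValueError(f"{label} must be a non-empty string")
    return f"{user}/{project}/model-{run_id}:{version}"


# Placeholder values, for illustration only.
reference = checkpoint_reference("my-team", "MNIST", "1a2b3c4d", version="best")
print(reference)  # my-team/MNIST/model-1a2b3c4d:best
```

The resulting string can then be passed wherever an artifact reference is expected, for example to `run.use_artifact(...)`.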
Log media
Log text with:
# using columns and data
columns = ["input", "label", "prediction"]
data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]
wandb_logger.log_text(key="samples", columns=columns, data=data)
# using a pandas DataFrame
wandb_logger.log_text(key="samples", dataframe=my_dataframe)
Log images with:
# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])
# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])
# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])
More arguments can be passed for logging segmentation masks and bounding boxes. Refer to the Image Overlays documentation.
Log tables with:
columns = ["caption", "image", "sound"]
data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]]
wandb_logger.log_table(key="samples", columns=columns, data=data)
Downloading and Using Artifacts
To download an artifact without starting a run, call the download_artifact function on the class:
from lightning.pytorch.loggers import WandbLogger

artifact_dir = WandbLogger.download_artifact(artifact="path/to/artifact")
To download an artifact and link it to an ongoing run, call the download_artifact function on the logger instance:
class MyModule(LightningModule):
    def any_lightning_module_function_or_hook(self):
        self.logger.download_artifact(artifact="path/to/artifact")
To link an artifact from a previous run, use the use_artifact function:
from lightning.pytorch.loggers import WandbLogger

wandb_logger = WandbLogger(project="my_project", name="my_run")
wandb_logger.use_artifact(artifact="path/to/artifact")
See also
Demo in Google Colab with hyperparameter search and model logging
- Parameters:
version (Optional[str]) – Sets the version, mainly used to resume a previous run.
offline (bool) – Run offline (data can be streamed later to wandb servers).
anonymous (Optional[bool]) – Enables or explicitly disables anonymous logging.
project (Optional[str]) – The name of the project to which this run will belong. If not set, the environment variable WANDB_PROJECT is used as a fallback. If neither is set, it defaults to 'lightning_logs'.
log_model (Union[Literal['all'], bool]) – Log checkpoints created by ModelCheckpoint as W&B artifacts. The latest and best aliases are set automatically.
  if log_model == 'all', checkpoints are logged during training.
  if log_model == True, checkpoints are logged at the end of training, except when save_top_k == -1, which also logs every checkpoint during training.
  if log_model == False (default), no checkpoint is logged.
prefix (str) – A string to put at the beginning of metric keys.
experiment (Union[Run, RunDisabled, None]) – WandB experiment object. Automatically set when creating a run.
checkpoint_name (Optional[str]) – Name of the model checkpoint artifact being logged.
add_file_policy (Literal['mutable', 'immutable']) – If "mutable", copies the file to a temporary directory before upload.
**kwargs (Any) – Arguments passed to wandb.init() like entity, group, tags, etc.
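The log_model rules listed under Parameters can be summarised as a small decision function. This is only a sketch of the documented behaviour, not the logger's implementation: `checkpoint_logging_schedule` is a hypothetical name, and the default `save_top_k=1` is assumed to mirror the ModelCheckpoint default.

```python
from typing import Literal, Union


def checkpoint_logging_schedule(log_model: Union[Literal["all"], bool], save_top_k: int = 1) -> str:
    """Return when checkpoints are logged as W&B artifacts under the documented rules."""
    if log_model == "all":
        return "during training"
    if log_model is True:
        # Documented exception: save_top_k == -1 also logs every checkpoint during training.
        return "during training" if save_top_k == -1 else "end of training"
    return "never"


for setting, top_k in [("all", 1), (True, 1), (True, -1), (False, 1)]:
    print(f"log_model={setting!r}, save_top_k={top_k}: {checkpoint_logging_schedule(setting, top_k)}")
```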
- Raises:
ModuleNotFoundError – If the required WandB package is not installed on the device.
MisconfigurationException – If both log_model and offline are set to True.
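The project fallback described under Parameters reads as a three-step lookup: the explicit argument, then the WANDB_PROJECT environment variable, then 'lightning_logs'. A minimal sketch of that order follows. `resolve_project` is a hypothetical helper, and the `env` argument exists only so the lookup can be tried without touching the real environment.

```python
import os
from typing import Mapping, Optional


def resolve_project(project: Optional[str] = None, env: Mapping[str, str] = os.environ) -> str:
    """Pick the project name: explicit argument, then WANDB_PROJECT, then 'lightning_logs'."""
    if project is not None:
        return project
    return env.get("WANDB_PROJECT", "lightning_logs")


print(resolve_project("MNIST", env={}))                      # MNIST
print(resolve_project(env={"WANDB_PROJECT": "from-env"}))    # from-env
print(resolve_project(env={}))                               # lightning_logs
```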
- after_save_checkpoint(checkpoint_callback)
Called after model checkpoint callback saves a new checkpoint.
- Parameters:
checkpoint_callback (ModelCheckpoint) – the model checkpoint callback instance
- static download_artifact(artifact, save_dir=None, artifact_type=None, use_artifact=True)
Downloads an artifact from the wandb server.
- Returns:
The path to the downloaded artifact.
- log_audio(key, audios, step=None, **kwargs)
Log audios (numpy arrays, or file paths).
Optional kwargs are lists passed to each audio (ex: caption, sample_rate).
- log_image(key, images, step=None, **kwargs)
Log images (tensors, numpy arrays, PIL Images or file paths).
Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
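"Lists passed to each image" means that every keyword argument is a list with one entry per item, and the i-th entry goes with the i-th item. The sketch below shows that pairing in plain Python. `kwargs_per_item` is a hypothetical helper written for illustration, and the length check is an assumption about how a mismatch would reasonably be handled.

```python
from typing import Any, Dict, List


def kwargs_per_item(n_items: int, **kwargs: List[Any]) -> List[Dict[str, Any]]:
    """Split list-valued keyword arguments into one kwargs dict per item."""
    for name, values in kwargs.items():
        if len(values) != n_items:
            raise ValueError(f"Expected {n_items} values for '{name}' but found {len(values)}")
    return [{name: values[i] for name, values in kwargs.items()} for i in range(n_items)]


# Two images, so each optional kwarg carries two entries.
per_image = kwargs_per_item(2, caption=["tree", "person"])
print(per_image)  # [{'caption': 'tree'}, {'caption': 'person'}]
```

The same pairing applies to the audio and video variants, with keys such as sample_rate or fps in place of caption.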
- log_metrics(metrics, step=None)
Records metrics. This method logs metrics as soon as it receives them.
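When a prefix is configured on the logger, it is placed at the beginning of each metric key before the metrics are recorded. The sketch below illustrates the idea. `add_prefix` is a hypothetical helper, and the "-" separator between prefix and key is an assumption made for illustration, so check the keys that actually appear in your run.

```python
from typing import Dict


def add_prefix(metrics: Dict[str, float], prefix: str, separator: str = "-") -> Dict[str, float]:
    """Return a copy of metrics with the prefix placed before every key."""
    if not prefix:
        # An empty prefix leaves the keys unchanged.
        return dict(metrics)
    return {f"{prefix}{separator}{key}": value for key, value in metrics.items()}


print(add_prefix({"train/loss": 0.25}, prefix="run1"))  # {'run1-train/loss': 0.25}
```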
- log_table(key, columns=None, data=None, dataframe=None, step=None)
Log a Table containing any object type (text, image, audio, video, molecule, html, etc).
Can be defined either with columns and data or with dataframe.
- log_text(key, columns=None, data=None, dataframe=None, step=None)
Log text as a Table.
Can be defined either with columns and data or with dataframe.
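In the columns-and-data form, columns is a list of column names and data is a list of rows, each row holding one value per column in the same order. If your samples are stored as dictionaries, a small conversion like the one below produces that shape. `records_to_columns_and_data` is a hypothetical helper, and it assumes every record has the same keys as the first one.

```python
from typing import Any, Dict, List, Tuple


def records_to_columns_and_data(records: List[Dict[str, Any]]) -> Tuple[List[str], List[List[Any]]]:
    """Turn a list of dict records into (columns, data) lists."""
    if not records:
        return [], []
    # Column order follows the key order of the first record.
    columns = list(records[0])
    # A record missing one of these keys raises KeyError.
    data = [[record[column] for column in columns] for record in records]
    return columns, data


records = [
    {"input": "cheese", "label": "english", "prediction": "english"},
    {"input": "fromage", "label": "french", "prediction": "spanish"},
]
columns, data = records_to_columns_and_data(records)
print(columns)  # ['input', 'label', 'prediction']
print(data)     # [['cheese', 'english', 'english'], ['fromage', 'french', 'spanish']]
```

The two lists can then be passed as the columns and data arguments shown in the signature above.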
- log_video(key, videos, step=None, **kwargs)
Log videos (numpy arrays, or file paths).
Optional kwargs are lists passed to each video (ex: caption, fps, format).
- use_artifact(artifact, artifact_type=None)
Logs to the wandb dashboard that the mentioned artifact is used by the run.
- property experiment: Union[wandb.wandb_run.Run, wandb.sdk.lib.RunDisabled]
Actual wandb object. To use wandb features in your LightningModule, do the following.
Example:

.. code-block:: python

    self.logger.experiment.some_wandb_function()