WandB & PyTorch Lightning Implementation Workflow

Workflow: Manual Watch

The idea is to call the logger's watch() method yourself. Remove log_graph and its related arguments (log, log_freq) from the logger initialization. After you create your Trainer instance, manually call wandb_logger.watch(). It is important to do this after the Trainer is created, because the trainer's initialization is what actually calls wandb.init() and starts the run.

Here is the modified code for the workaround:

```python
# wandb Guidance

import wandb

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

# --- Model and Data Setup (same as before) ---
class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(32, 64)
        self.layer_2 = torch.nn.Linear(64, 10)

    def forward(self, x):
        return self.layer_2(torch.relu(self.layer_1(x)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

X = torch.randn(100, 32)
y = torch.randint(0, 10, (100,))
train_loader = DataLoader(TensorDataset(X, y), batch_size=8)
# --- End of Setup ---


# 1. Initialize the logger WITHOUT log_graph
wandb_logger = WandbLogger(
    project="my-awesome-project",
    log_model=True  # log_model is fine, it's a separate feature
)

model = LitModel()

# 2. Initialize the Trainer
trainer = pl.Trainer(
    max_epochs=5,
    logger=wandb_logger,
    log_every_n_steps=1
)

# 3. THE WORKAROUND: Manually call watch() on the logger instance
# This must be done AFTER the Trainer is initialized.
wandb_logger.watch(
    model=model,
    log="all",
    log_freq=8,  # default: 100
    log_graph=True,
)

# 4. Run training
trainer.fit(model, train_loader)

wandb.finish()
```

Summary

| Method | Pros | Cons | When to Use |
| --- | --- | --- | --- |
| Upgrade Libraries | Simpler, cleaner code (log_graph=True). Aligned with modern documentation. Access to all new features. | Requires updating your environment. | The recommended approach for almost all cases. |
| Manual watch() | Works on old library versions without an upgrade. | More verbose code. You need to remember the correct place to call watch(). | Only if you are strictly forbidden from upgrading your Python packages. |
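Before choosing between the two rows above, it can help to confirm what is actually installed and whether the installed WandbLogger.watch() already accepts log_graph. The following is a small diagnostic sketch, not part of the original code; it only assumes Python 3.8+ for importlib.metadata.

```python
# Diagnostic sketch: inspect installed versions and the watch() signature
# before deciding between upgrading and the manual-watch workaround.
import inspect
from importlib.metadata import version, PackageNotFoundError

from pytorch_lightning.loggers import WandbLogger

for pkg in ("pytorch_lightning", "wandb"):
    try:
        print(f"{pkg}: {version(pkg)}")
    except PackageNotFoundError:
        print(f"{pkg}: not installed")

# If log_graph is already a parameter of WandbLogger.watch(), the manual
# workaround shown above will work as written in this environment.
params = inspect.signature(WandbLogger.watch).parameters
print("WandbLogger.watch accepts log_graph:", "log_graph" in params)
```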

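If the installed pytorch_lightning is so old that WandbLogger.watch() does not accept log_graph at all, a possible fallback is to call wandb's own watch() on the underlying run instead of going through the Lightning wrapper. This is a minimal sketch under that assumption, reusing LitModel and train_loader from the example above; it also assumes a wandb version whose watch() supports the log_graph argument.

```python
# Fallback sketch: skip WandbLogger.watch() and use wandb.watch() directly.
# Reuses LitModel and train_loader from the full example above.
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

model = LitModel()
wandb_logger = WandbLogger(project="my-awesome-project", log_model=True)
trainer = pl.Trainer(max_epochs=5, logger=wandb_logger, log_every_n_steps=1)

# Touching .experiment forces the logger to start the underlying wandb run
# (it triggers wandb.init()), so an active run exists before watch().
_ = wandb_logger.experiment

# Hook gradients/parameters (and, where supported, the graph) into the run
# through wandb's own API rather than the Lightning wrapper.
wandb.watch(model, log="all", log_freq=8, log_graph=True)

trainer.fit(model, train_loader)
wandb.finish()
```

Because wandb.watch() attaches to the run that the logger already created, the gradients and parameters land in the same run as the metrics logged by the Trainer.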
June 10, 2025 · 2 min · 329 words · xxraincandyxx