timm/layers/fast_norm.py (13 additions, 5 deletions)
@@ -78,7 +78,11 @@ def fast_group_norm(
         # normally native AMP casts GN inputs to float32
         # here we use the low precision autocast dtype
         dt = get_autocast_dtype(x.device.type)
-        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
+        x, weight, bias = (
+            x.to(dt),
+            weight.to(dt) if weight is not None else None,
+            bias.to(dt) if bias is not None else None,
+        )

     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.group_norm(x, num_groups, weight, bias, eps)
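Note on the fix: the old one-liner only guarded bias, so a parameter-free norm (e.g. nn.GroupNorm(num_groups, num_channels, affine=False), where both weight and bias are None) hit weight.to(dt) and raised AttributeError under autocast. A minimal before/after sketch, independent of timm (assumes a PyTorch build where F.group_norm supports bfloat16 on your device):

    import torch
    import torch.nn.functional as F

    gn = torch.nn.GroupNorm(2, 8, affine=False)  # weight and bias are both None
    x = torch.randn(2, 8, 4, 4)
    dt = torch.bfloat16

    # old line: gn.weight.to(dt) raised AttributeError, only bias was guarded
    # new pattern: each parameter is guarded individually, None passes through
    x, w, b = (
        x.to(dt),
        gn.weight.to(dt) if gn.weight is not None else None,
        gn.bias.to(dt) if gn.bias is not None else None,
    )
    out = F.group_norm(x, 2, w, b, gn.eps)  # F.group_norm accepts None weight/bias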
@@ -102,7 +106,11 @@ def fast_layer_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        x, weight, bias = x.to(dt), weight.to(dt), bias.to(dt) if bias is not None else None
+        x, weight, bias = (
+            x.to(dt),
+            weight.to(dt) if weight is not None else None,
+            bias.to(dt) if bias is not None else None,
+        )

     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         return F.layer_norm(x, normalized_shape, weight, bias, eps)
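The surrounding pattern (cast inputs to the autocast dtype by hand, then run the functional norm with autocast disabled) is what keeps the op and its output in low precision instead of the float32 that native AMP would use. A standalone sketch of the same idea; apex_like_layer_norm is a hypothetical name, and torch.is_autocast_enabled / torch.get_autocast_dtype (PyTorch 2.4+) stand in for timm's compat helpers:

    import torch
    import torch.nn.functional as F

    def apex_like_layer_norm(x, normalized_shape, weight=None, bias=None, eps=1e-5):
        if torch.is_autocast_enabled(x.device.type):
            # cast inputs to the low precision autocast dtype ourselves ...
            dt = torch.get_autocast_dtype(x.device.type)
            x = x.to(dt)
            weight = weight.to(dt) if weight is not None else None
            bias = bias.to(dt) if bias is not None else None
        # ... then disable autocast so F.layer_norm runs (and returns) in that dtype
        with torch.amp.autocast(device_type=x.device.type, enabled=False):
            return F.layer_norm(x, normalized_shape, weight, bias, eps)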
@@ -151,7 +159,7 @@ def fast_rms_norm(
         # normally native AMP casts LN inputs to float32 and leaves the output as float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        x, weight = x.to(dt), weight.to(dt)
+        x, weight = x.to(dt), weight.to(dt) if weight is not None else None

     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         if has_torch_rms_norm:
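On the has_torch_rms_norm branch: F.rms_norm was only added in PyTorch 2.4, and it also accepts weight=None, so the guard above keeps both code paths consistent for norms without a learned scale. A small sketch of the availability check with a manual fallback (the fallback is the standard RMSNorm formula, not timm's code):

    import torch
    import torch.nn.functional as F

    has_torch_rms_norm = hasattr(F, 'rms_norm')  # added in PyTorch 2.4

    x = torch.randn(2, 8)
    if has_torch_rms_norm:
        y = F.rms_norm(x, (8,), weight=None, eps=1e-6)  # weight=None is legal here too
    else:
        # manual RMSNorm: x * rsqrt(mean(x^2) + eps), optionally scaled by weight
        y = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + 1e-6)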
@@ -199,7 +207,7 @@ def fast_rms_norm2d(
         # normally native AMP casts norm inputs to float32 and leaves the output as float32
         # apex does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        x, weight = x.to(dt), weight.to(dt)
+        x, weight = x.to(dt), weight.to(dt) if weight is not None else None

     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         x = rms_norm2d(x, normalized_shape, weight, eps)
@@ -243,7 +251,7 @@ def fast_simple_norm(
         # normally native AMP casts LN inputs to float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
-        x, weight = x.to(dt), weight.to(dt)
+        x, weight = x.to(dt), weight.to(dt) if weight is not None else None

     with torch.amp.autocast(device_type=x.device.type, enabled=False):
         x = simple_norm(x, normalized_shape, weight, eps)
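Taken together, the five guarded call sites mean every fast_* helper should now tolerate weight=None (and bias=None) under autocast. A hypothetical smoke test, assuming the helpers are importable from timm.layers.fast_norm with the signatures visible in this diff:

    import torch
    from timm.layers import fast_norm

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    x = torch.randn(2, 8, 4, 4, device=device)
    with torch.amp.autocast(device_type=device, dtype=torch.bfloat16):
        fast_norm.fast_group_norm(x, 2, None, None)     # affine=False GroupNorm case
        fast_norm.fast_layer_norm(x, (4,), None, None)  # LN over the last dim, no affine
        fast_norm.fast_rms_norm(x, (4,), None)          # RMSNorm without a learned scale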