#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
#

import argparse
import inspect
from typing import Any, Dict, Tuple

import torch
from torch.distributed.fsdp import (
    BackwardPrefetch,
    CPUOffload,
    FullStateDictConfig,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
    StateDictType,
)
from torch.distributed.fsdp.api import FullOptimStateDictConfig

from corenet.modeling.models import BaseAnyNNModel
from corenet.optims.base_optim import BaseOptim
from corenet.utils import logger
from corenet.utils.ddp_utils import is_master
FSDP_SHARDING_STRATEGY_MAP = {
    # In full_shard, parameters, gradients, and optimizer states are sharded (a.k.a. ZeRO-3).
    "full_shard": ShardingStrategy.FULL_SHARD,
    # hybrid_shard is the same as full_shard, except that sharding is done within a node.
    # TODO: Revisit hybrid sharding in the future because of the issue below.
    # https://github.com/pytorch/pytorch/issues/102904
    # "hybrid_shard": ShardingStrategy.HYBRID_SHARD,
    # In no_shard, parameters, gradients, and optimizer states are not sharded.
    "no_shard": ShardingStrategy.NO_SHARD,
    # In grad_op_shard, gradients and optimizer states are sharded (a.k.a. ZeRO-2).
    "grad_op_shard": ShardingStrategy.SHARD_GRAD_OP,
}

FSDP_DATATYPE_CONVERSION = {
    "float16": torch.float16,
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
}

FSDP_BACKWARD_PREFETCH = {
    # `pre` prefetches the next set of parameters before computing gradients for the current set of parameters.
    "pre": BackwardPrefetch.BACKWARD_PRE,
    # `post` prefetches the next set of parameters after computing gradients for the current set of parameters.
    "post": BackwardPrefetch.BACKWARD_POST,
}

class FullyShardedDataParallelWrapper(FullyShardedDataParallel):
    def __init__(
        self,
        opts: argparse.Namespace,
        model: BaseAnyNNModel,
    ) -> None:
        param_dtype = getattr(opts, "fsdp.parameter_datatype")
        reduce_dtype = getattr(opts, "fsdp.gradient_reduction_datatype")
        buffer_dtype = getattr(opts, "fsdp.buffer_datatype")
        if param_dtype not in FSDP_DATATYPE_CONVERSION:
            logger.error(
                f"Supported data types for parameters in FSDP are {list(FSDP_DATATYPE_CONVERSION.keys())}. "
                f"Got: {param_dtype}."
            )
        if reduce_dtype not in FSDP_DATATYPE_CONVERSION:
            logger.error(
                f"Supported data types for gradient reduction in FSDP are {list(FSDP_DATATYPE_CONVERSION.keys())}. "
                f"Got: {reduce_dtype}."
            )
        if buffer_dtype not in FSDP_DATATYPE_CONVERSION:
            logger.error(
                f"Supported data types for buffers in FSDP are {list(FSDP_DATATYPE_CONVERSION.keys())}. "
                f"Got: {buffer_dtype}."
            )

        prefetching_option = getattr(opts, "fsdp.backward_prefetching")
        if prefetching_option not in FSDP_BACKWARD_PREFETCH:
            logger.error(
                f"Supported backward pre-fetching options are {list(FSDP_BACKWARD_PREFETCH.keys())}. "
                f"Got: {prefetching_option}."
            )

        fsdp_precision_policy = MixedPrecision(
            param_dtype=FSDP_DATATYPE_CONVERSION[param_dtype],
            reduce_dtype=FSDP_DATATYPE_CONVERSION[reduce_dtype],
            buffer_dtype=FSDP_DATATYPE_CONVERSION[buffer_dtype],
        )

        fsdp_parameters = inspect.signature(FullyShardedDataParallel).parameters
        # Enabling `use_orig_params` tells FSDP not to flatten parameters, which lets us
        # specify different LR/weight-decay values per parameter group.
        # The `use_orig_params` feature is available in PyTorch versions > 2.0.
        extra_args_fsdp = (
            dict(use_orig_params=True)
            if "use_orig_params" in fsdp_parameters
            else dict()
        )
        if "limit_all_gathers" in fsdp_parameters and getattr(
            opts, "fsdp.limit_all_gathers"
        ):
            extra_args_fsdp["limit_all_gathers"] = True
        if "cpu_offload" in fsdp_parameters and getattr(opts, "fsdp.cpu_offload"):
            extra_args_fsdp["cpu_offload"] = CPUOffload(offload_params=True)

        sharding_strategy = getattr(opts, "fsdp.sharding_strategy")
        if sharding_strategy not in FSDP_SHARDING_STRATEGY_MAP:
            logger.error(
                f"Supported sharding strategies for FSDP are: {list(FSDP_SHARDING_STRATEGY_MAP.keys())}. "
                f"Got: {sharding_strategy}."
            )

        # Get the FSDP auto-wrapping policy from the model.
        fsdp_wrap_policy = model.get_fsdp_wrap_policy()

        super().__init__(
            model,
            sharding_strategy=FSDP_SHARDING_STRATEGY_MAP[sharding_strategy],
            auto_wrap_policy=fsdp_wrap_policy,
            mixed_precision=fsdp_precision_policy,
            backward_prefetch=FSDP_BACKWARD_PREFETCH[prefetching_option],
            **extra_args_fsdp,
        )
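
    # Illustrative sketch (an assumption, not part of the original file): because
    # `use_orig_params` keeps the original parameter objects, an optimizer with
    # per-parameter-group hyper-parameters can be built directly on the wrapped model, e.g.:
    #
    #   wrapped = FullyShardedDataParallelWrapper(opts, model)
    #   decay, no_decay = [], []
    #   for name, param in wrapped.named_parameters():
    #       (no_decay if name.endswith(".bias") else decay).append(param)
    #   optimizer = torch.optim.AdamW(
    #       [
    #           {"params": decay, "weight_decay": 0.05},
    #           {"params": no_decay, "weight_decay": 0.0},
    #       ],
    #       lr=1e-3,
    #   )
    #
    # The grouping rule and hyper-parameter values above are placeholders for illustration only.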
    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add FSDP-specific arguments."""
        if cls == FullyShardedDataParallelWrapper:
            group = parser.add_argument_group(title=cls.__name__)
            group.add_argument(
                "--fsdp.sharding-strategy",
                type=str,
                default=None,
                choices=list(FSDP_SHARDING_STRATEGY_MAP.keys()),
                help="Sharding strategy for FSDP. Defaults to None.",
            )
            group.add_argument(
                "--fsdp.backward-prefetching",
                type=str,
                default="pre",
                choices=["pre", "post"],
                help="Backward prefetching. Supported modes are `pre` and `post`. "
                "`pre` and `post` prefetch the next set of parameters before and after "
                "the current set of parameters' gradient computation, respectively. "
                "Defaults to `pre`.",
            )
            group.add_argument(
                "--fsdp.parameter-datatype",
                type=str,
                default="bfloat16",
                choices=list(FSDP_DATATYPE_CONVERSION.keys()),
                help="Specify the data type of model parameters. See FSDP documentation for details. "
                "Defaults to `bfloat16`.",
            )
            group.add_argument(
                "--fsdp.gradient-reduction-datatype",
                type=str,
                default="bfloat16",
                choices=list(FSDP_DATATYPE_CONVERSION.keys()),
                help="Specify the data type for gradient reduction. See FSDP documentation for details. "
                "Defaults to `bfloat16`.",
            )
            group.add_argument(
                "--fsdp.buffer-datatype",
                type=str,
                default="bfloat16",
                choices=list(FSDP_DATATYPE_CONVERSION.keys()),
                help="Specify the data type for buffers. See FSDP documentation for details. "
                "Defaults to `bfloat16`.",
            )
            group.add_argument(
                "--fsdp.limit-all-gathers",
                action="store_true",
                help="Enabling this flag allows FSDP to explicitly synchronize the CPU threads and "
                "prevent too many in-flight all-gathers. Enabling this can "
                "help lower the number of CUDA malloc retries. Defaults to `False`. "
                "Note: In older PyTorch versions, this flag may not be available.",
            )
            group.add_argument(
                "--fsdp.cpu-offload",
                action="store_true",
                help="Enable CPU offloading. Defaults to `False`. "
                "Note: In older PyTorch versions, this flag may not be available.",
            )
        return parser
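
# Illustrative example (an assumption, not part of the original file): the `--fsdp.*` flags
# registered above could be combined on a training command line roughly as follows; the
# entrypoint name is hypothetical, and only the flag names and defaults come from this module.
#
#   <training-entrypoint> \
#       --fsdp.sharding-strategy full_shard \
#       --fsdp.parameter-datatype bfloat16 \
#       --fsdp.gradient-reduction-datatype bfloat16 \
#       --fsdp.buffer-datatype bfloat16 \
#       --fsdp.backward-prefetching pre \
#       --fsdp.limit-all-gathers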

def get_fsdp_model_optimizer_state_dict_on_rank0(
    model: FullyShardedDataParallelWrapper, optimizer: BaseOptim
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """
    Aggregates the model and optimizer states from all shards on rank0 and returns them.

    Args:
        model: Model (partially) sharded by FSDP.
        optimizer: Optimizer.
    """
    with FullyShardedDataParallelWrapper.state_dict_type(
        model,
        StateDictType.FULL_STATE_DICT,
        # config for model state aggregation
        FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
        # config for optimizer state aggregation
        FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
    ):
        model_state = model.state_dict()
        # Returns the state dict of the optimizer for the ``model`` that is (partially) sharded by FSDP.
        optim_state = FullyShardedDataParallel.optim_state_dict(
            model=model, optim=optimizer
        )
    return model_state, optim_state
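
# Minimal usage sketch (an assumption, not part of the original file): saving a full,
# unsharded checkpoint from rank 0 after aggregating the shards. `checkpoint_path` is a
# hypothetical destination; `is_master` is the helper imported above.
#
#   model_state, optim_state = get_fsdp_model_optimizer_state_dict_on_rank0(model, optimizer)
#   if is_master(opts):
#       torch.save(
#           {"model": model_state, "optimizer": optim_state},
#           checkpoint_path,
#       )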