for t in chain(self.module.parameters(), self.module.buffers()):
    if t.device != self.src_device_obj:
        raise RuntimeError("module must have its parameters and buffers "
                           "on device {} (device_ids[0]) but found one of "
                           "them on device: {}".format(self.src_device_obj, t.device))
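This check fires whenever the wrapped module has not been moved to device_ids[0] before forward is called. A minimal sketch of the setup it expects (the model and sizes here are just illustrative):

import torch
import torch.nn as nn

if torch.cuda.device_count() >= 2:
    model = nn.Linear(10, 5).to("cuda:0")            # parameters live on device_ids[0]
    dp_model = nn.DataParallel(model, device_ids=[0, 1])
    out = dp_model(torch.randn(8, 10, device="cuda:0"))
    # Leaving the module on the CPU (or on cuda:1) would trigger the
    # RuntimeError above, because t.device != self.src_device_obj.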
def scatter(inputs, target_gpus, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            return Scatter.apply(target_gpus, None, dim, obj)
        if is_namedtuple(obj):
            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return [list(i) for i in zip(*map(scatter_map, obj))]
        if isinstance(obj, dict) and len(obj) > 0:
            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
        return [obj for targets in target_gpus]
    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        res = scatter_map(inputs)
    finally:
        scatter_map = None
    return res
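To see what this recursive mapping produces, here is a small hedged sketch: tensor leaves are chunked along dim 0 and sent one chunk per GPU, the surrounding container structure is rebuilt around each chunk, and non-tensor leaves (the string below) are simply replicated by reference. The example data is made up:

import torch
from torch.nn.parallel import scatter

if torch.cuda.device_count() >= 2:
    batch = {"x": torch.randn(8, 4), "meta": "unchanged"}
    chunks = scatter(batch, target_gpus=[0, 1], dim=0)
    # chunks is a list with one dict per GPU:
    #   chunks[0]["x"].shape == (4, 4) on cuda:0, chunks[1]["x"].shape == (4, 4) on cuda:1
    #   chunks[0]["meta"] == chunks[1]["meta"] == "unchanged"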
@staticmethod
def forward(ctx, target_gpus, chunk_sizes, dim, input):
    target_gpus = [_get_device_index(x, True) for x in target_gpus]
    ctx.dim = dim
    ctx.input_device = input.get_device() if input.device.type != "cpu" else -1
    streams = None
    if torch.cuda.is_available() and ctx.input_device == -1:
        # Perform CPU to GPU copies in a background stream
        # Create a CUDA stream for each target GPU
        streams = [_get_stream(device) for device in target_gpus]
    outputs = comm.scatter(input, target_gpus, chunk_sizes, ctx.dim, streams)
    # Synchronize with the copy stream
    if streams is not None:
        for i, output in enumerate(outputs):
            with torch.cuda.device(target_gpus[i]):
                main_stream = torch.cuda.current_stream()
                main_stream.wait_stream(streams[i])
                output.record_stream(main_stream)
    return outputs
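The pattern here is the standard producer/consumer stream idiom: do the host-to-device copy on a side stream, then make the main stream wait for it and record the tensor on the main stream so the caching allocator does not reuse that memory while the copy is still in flight. A standalone sketch of the same idea (not the library code, tensor sizes are arbitrary):

import torch

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    cpu_tensor = torch.randn(1024, 1024).pin_memory()

    with torch.cuda.stream(copy_stream):
        gpu_tensor = cpu_tensor.to("cuda", non_blocking=True)   # copy happens on copy_stream

    main_stream = torch.cuda.current_stream()
    main_stream.wait_stream(copy_stream)    # main stream waits for the copy to finish
    gpu_tensor.record_stream(main_stream)   # tie the allocation's lifetime to main_stream

    out = gpu_tensor.sum()                  # now safe to consume on the main stream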
# Now we start replicating the network.
# Preparation: turn network.modules() into a list,
# then set up empty lists and index maps for the replicas that will be created.
modules = list(network.modules())
module_copies = [[] for device in devices]
module_indices = {}
scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}
for i, module in enumerate(modules):
    module_indices[module] = i
    for j in range(num_replicas):
        replica = module._replicate_for_data_parallel()
        # This is a temporary fix for DDP. DDP needs to access the
        # replicated model parameters. It used to do so through
        # `mode.parameters()`. The fix added in #33907 for DP stops the
        # `parameters()` API from exposing the replicated parameters.
        # Hence, we add a `_former_parameters` dict here to support DDP.
        replica._former_parameters = OrderedDict()

        module_copies[j].append(replica)
# Next, copy the modules, parameters, and buffers into each replica
for i, module in enumerate(modules):
    for key, child in module._modules.items():
        if child is None:
            for j in range(num_replicas):
                replica = module_copies[j][i]
                replica._modules[key] = None
        else:
            module_idx = module_indices[child]
            for j in range(num_replicas):
                replica = module_copies[j][i]
                setattr(replica, key, module_copies[j][module_idx])
    for key, param in module._parameters.items():
        if param is None:
            for j in range(num_replicas):
                replica = module_copies[j][i]
                replica._parameters[key] = None
        else:
            param_idx = param_indices[param]
            for j in range(num_replicas):
                replica = module_copies[j][i]
                param = param_copies[j][param_idx]
                # parameters in replicas are no longer leaves,
                # so setattr them as non-parameter attributes
                setattr(replica, key, param)
                # expose the parameter for DDP
                replica._former_parameters[key] = param
    for key, buf in module._buffers.items():
        if buf is None:
            for j in range(num_replicas):
                replica = module_copies[j][i]
                replica._buffers[key] = None
        else:
            if buf.requires_grad and not detach:
                buffer_copies = buffer_copies_rg
                buffer_idx = buffer_indices_rg[buf]
            else:
                buffer_copies = buffer_copies_not_rg
                buffer_idx = buffer_indices_not_rg[buf]
            for j in range(num_replicas):
                replica = module_copies[j][i]
                setattr(replica, key, buffer_copies[j][buffer_idx])
return [module_copies[j][0] for j in range(num_replicas)]
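So replicate returns one copy of the root module per device. A minimal sketch of calling it directly (hypothetical model), showing that the replica "parameters" are broadcast copies set as plain attributes rather than leaf nn.Parameters:

import torch
import torch.nn as nn
from torch.nn.parallel import replicate

if torch.cuda.device_count() >= 2:
    net = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).to("cuda:0")
    replicas = replicate(net, devices=[0, 1])

    print(replicas[0][0].weight.device)   # cuda:0
    print(replicas[1][0].weight.device)   # cuda:1
    print(replicas[1][0].weight.is_leaf)  # False: a broadcast copy, not a leaf Parameter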
# Read the comment in the else branch first, since the non-detach path goes
# through the same function as well
if detach:
    return comm.broadcast_coalesced(tensors, devices)
else:
    # Use the autograd function to broadcast if not detach
    if len(tensors) > 0:
        tensor_copies = Broadcast.apply(devices, *tensors)
        return [tensor_copies[i:i + len(tensors)]
                for i in range(0, len(tensor_copies), len(tensors))]
    else:
        return []
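The detach=True path simply calls the low-level primitive directly. Here is a hedged sketch of what comm.broadcast_coalesced does on its own: it copies a flat list of tensors from the source GPU to every listed device in one coalesced transfer and returns one list of copies per device (example tensors are made up):

import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    tensors = [torch.randn(3, 3, device="cuda:0"), torch.randn(5, device="cuda:0")]
    copies = comm.broadcast_coalesced(tensors, devices=[0, 1])
    # copies[0] holds the tensors on cuda:0, copies[1] the copies on cuda:1
    print(copies[1][0].device, copies[1][1].device)  # cuda:1 cuda:1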
# Broadcast.apply
class Broadcast(Function):
    @staticmethod
    def forward(ctx, target_gpus, *inputs):
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Broadcast function not implemented for CPU tensors'
        )
        target_gpus = [_get_device_index(x, True) for x in target_gpus]
        ctx.target_gpus = target_gpus
        if len(inputs) == 0:
            return tuple()
        ctx.num_inputs = len(inputs)
        # the inputs live on device[0]
        ctx.input_device = inputs[0].get_device()
        outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
        non_differentiables = []
        for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
            if not input_requires_grad:
                for output in outputs:
                    non_differentiables.append(output[idx])
        ctx.mark_non_differentiable(*non_differentiables)
        return tuple([t for tensors in outputs for t in tensors])
# Source code
def gather(outputs, target_device, dim=0):
    r"""
    Gathers tensors from different GPUs on a specified device
    (-1 means the CPU).
    """
    def gather_map(outputs):
        out = outputs[0]
        if isinstance(out, torch.Tensor):
            return Gather.apply(target_device, dim, *outputs)
        if out is None:
            return None
        if isinstance(out, dict):
            if not all((len(out) == len(d) for d in outputs)):
                raise ValueError('All dicts must have the same number of keys')
            return type(out)(((k, gather_map([d[k] for d in outputs]))
                              for k in out))
        return type(out)(map(gather_map, zip(*outputs)))
    # Recursive function calls like this create reference cycles.
    # Setting the function to None clears the refcycle.
    try:
        res = gather_map(outputs)
    finally:
        gather_map = None
    return res
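gather is the mirror image of scatter: per-GPU outputs are concatenated along dim onto the target device, recursing into dicts and tuples the same way. A small hedged usage sketch (the output shapes are illustrative):

import torch
from torch.nn.parallel import gather

if torch.cuda.device_count() >= 2:
    outputs = [
        {"logits": torch.randn(4, 10, device="cuda:0")},
        {"logits": torch.randn(4, 10, device="cuda:1")},
    ]
    merged = gather(outputs, target_device=0, dim=0)
    print(merged["logits"].shape, merged["logits"].device)  # torch.Size([8, 10]) cuda:0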
# Source code of Gather
class Gather(Function):
    @staticmethod
    def forward(ctx, target_device, dim, *inputs):
        assert all(i.device.type != 'cpu' for i in inputs), (
            'Gather function not implemented for CPU tensors'
        )
        if (target_device == 'cpu'):
            ctx.target_device = 'cpu'
        else:
            target_device = _get_device_index(target_device, True)
            ctx.target_device = target_device
        ctx.dim = dim
        ctx.input_gpus = tuple(i.get_device() for i in inputs)
        if all(t.dim() == 0 for t in inputs) and dim == 0:
            inputs = tuple(t.view(1) for t in inputs)
            warnings.warn('Was asked to gather along dimension 0, but all '
                          'input tensors were scalars; will instead unsqueeze '
                          'and return a vector.')
            ctx.unsqueezed_scalar = True
        else:
            ctx.unsqueezed_scalar = False
        ctx.input_sizes = tuple(i.size(ctx.dim) for i in inputs)
        return comm.gather(inputs, ctx.dim, ctx.target_device)
    @staticmethod
    def backward(ctx, grad_output):
        scattered_grads = Scatter.apply(ctx.input_gpus, ctx.input_sizes, ctx.dim, grad_output)
        if ctx.unsqueezed_scalar:
            scattered_grads = tuple(g[0] for g in scattered_grads)
        return (None, None) + scattered_grads
# comm.gather drops down into C++, so we won't walk through that implementation here ;)
# Gathers tensors from multiple GPU devices.
def gather(tensors, dim=0, destination=None, *, out=None):
    tensors = [_handle_complex(t) for t in tensors]
    if out is None:
        if destination == -1:
            warnings.warn(
                'Using -1 to represent CPU tensor is deprecated. Please use a '
                'device object or string instead, e.g., "cpu".')
        destination = _get_device_index(destination, allow_cpu=True, optional=True)
        return torch._C._gather(tensors, dim, destination)
    else:
        if destination is not None:
            raise RuntimeError(
                "'destination' must not be specified when 'out' is specified, but "
                "got destination={}".format(destination))
        return torch._C._gather_out(tensors, out, dim)
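Even without reading the C++ side, the primitive is easy to exercise directly. A hedged sketch: torch.cuda.comm.gather concatenates tensors living on different GPUs along dim onto one destination device, and passing destination="cpu" gathers onto the host:

import torch
from torch.cuda import comm

if torch.cuda.device_count() >= 2:
    parts = [torch.ones(2, 3, device="cuda:0"), torch.ones(2, 3, device="cuda:1")]
    on_gpu0 = comm.gather(parts, dim=0, destination=0)      # shape (4, 3) on cuda:0
    on_cpu = comm.gather(parts, dim=0, destination="cpu")   # shape (4, 3) on the CPU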
The difference between DistributedDataParallel and DataParallel is that DistributedDataParallel uses multiprocessing, creating one process per GPU, while DataParallel uses multithreading within a single process. Because each GPU gets its own dedicated process, DistributedDataParallel avoids the performance overhead caused by the GIL of the Python interpreter.
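A minimal, hypothetical sketch of that one-process-per-GPU pattern (the model, sizes, and the rendezvous address/port are illustrative):

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def worker(rank, world_size):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)
    model = nn.Linear(10, 5).to(rank)
    ddp_model = DDP(model, device_ids=[rank])   # each process drives exactly one GPU

    out = ddp_model(torch.randn(8, 10, device=rank))
    out.sum().backward()                        # gradients are all-reduced across processes
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    if world_size >= 2:
        mp.spawn(worker, args=(world_size,), nprocs=world_size)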