forked from zlckanata/DeepGlobe-Road-Extraction-Challenge
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_ddp.py
More file actions
25 lines (20 loc) · 922 Bytes
/
Copy path train_ddp.py
File metadata and controls
25 lines (20 loc) · 922 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
import torch
from torch.utils.data import DataLoader, DistributedSampler
import os
from dataloader.fusion_dataset import TLCGISDataset, multi_loader
def get_dataloader(imagelist_path, batchsize, DATA_ROOT='data/TLCGIS/',
                   data_loader=multi_loader, num_workers=16,
                   num_replicas=None, rank=None):
    """Build a distributed DataLoader over the TLCGIS image list.

    Each non-empty line of ``imagelist_path`` is used as one sample id and
    passed to ``TLCGISDataset`` together with ``DATA_ROOT`` and
    ``data_loader``. A ``DistributedSampler`` splits the dataset across
    processes. When ``num_replicas``/``rank`` are left as ``None`` they are
    read from ``torch.distributed``, so the process group must already be
    initialised.

    Args:
        imagelist_path: Path to a text file with one sample id per line.
        batchsize: Per-process batch size.
        DATA_ROOT: Root directory handed to ``TLCGISDataset``.
        data_loader: Loader callable handed to ``TLCGISDataset``.
        num_workers: Number of DataLoader worker processes.
        num_replicas: Number of processes sharing the data; defaults to
            ``dist.get_world_size()`` (previously hard-coded to 2).
        rank: Rank of this process; defaults to ``dist.get_rank()``.

    Returns:
        The constructed ``DataLoader``. Call
        ``loader.sampler.set_epoch(epoch)`` at the start of every epoch so
        the shuffle order differs between epochs.
    """
    with open(imagelist_path) as file:
        # rstrip('\n') instead of x[:-1]: the last line may have no trailing
        # newline, and slicing would then drop a real character of the id.
        imagelist = [line.rstrip('\n') for line in file]
    # Drop blank lines (e.g. an empty line at end of file) so they are not
    # handed to the dataset as sample ids.
    imagelist = [name for name in imagelist if name.strip()]

    dataset = TLCGISDataset(imagelist, DATA_ROOT, loader=data_loader)

    # Derive the shard layout from the live process group rather than a
    # fixed replica count, so any world size partitions the data correctly.
    if num_replicas is None:
        num_replicas = dist.get_world_size()
    if rank is None:
        rank = dist.get_rank()
    sampler = DistributedSampler(dataset, num_replicas=num_replicas,
                                 rank=rank, shuffle=True)

    # shuffle stays False: ordering is delegated to the sampler, and
    # DataLoader does not accept shuffle=True together with a sampler.
    return DataLoader(dataset,
                      batch_size=batchsize,
                      num_workers=num_workers,
                      sampler=sampler,
                      shuffle=False)
def main():
    """Training entry point placeholder; currently performs no work."""
    return None