From b4d7ddd91e5338c4fdac116ec9e7443e56a54776 Mon Sep 17 00:00:00 2001 From: PSImera Date: Sat, 2 May 2026 21:49:07 +0400 Subject: [PATCH 1/2] Replace deprecated torch._utils._accumulate with itertools.accumulate --- trainer/dataset.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/trainer/dataset.py b/trainer/dataset.py index a6485eea3a9..98de77162ed 100644 --- a/trainer/dataset.py +++ b/trainer/dataset.py @@ -10,9 +10,10 @@ from PIL import Image import numpy as np from torch.utils.data import Dataset, ConcatDataset, Subset -from torch._utils import _accumulate +from itertools import accumulate as _accumulate import torchvision.transforms as transforms + def contrast_grey(img): high = np.percentile(img, 90) low = np.percentile(img, 10) @@ -98,12 +99,12 @@ def get_batch(self): for i, data_loader_iter in enumerate(self.dataloader_iter_list): try: - image, text = data_loader_iter.next() + image, text = next(data_loader_iter) balanced_batch_images.append(image) balanced_batch_texts += text except StopIteration: self.dataloader_iter_list[i] = iter(self.data_loader_list[i]) - image, text = self.dataloader_iter_list[i].next() + image, text = next(self.dataloader_iter_list[i]) balanced_batch_images.append(image) balanced_batch_texts += text except ValueError: From 59ee52fe42971c8a20d10f3159efd6b101fd2071 Mon Sep 17 00:00:00 2001 From: PSImera Date: Sat, 2 May 2026 21:52:03 +0400 Subject: [PATCH 2/2] Fix device mismatch in CTC loss and consolidate trainer into notebook Explicitly move tensors to device before CTCLoss call in validation(), and inline train.py and test.py into trainer.ipynb. --- trainer/test.py | 112 --------- trainer/train.py | 282 --------------------- trainer/trainer.ipynb | 568 ++++++++++++++++++++++++++++++++++++++---- 3 files changed, 522 insertions(+), 440 deletions(-) delete mode 100644 trainer/test.py delete mode 100644 trainer/train.py diff --git a/trainer/test.py b/trainer/test.py deleted file mode 100644 index 48eaa76c750..00000000000 --- a/trainer/test.py +++ /dev/null @@ -1,112 +0,0 @@ -import os -import time -import string -import argparse - -import torch -import torch.backends.cudnn as cudnn -import torch.utils.data -import torch.nn.functional as F -import numpy as np -from nltk.metrics.distance import edit_distance - -from utils import CTCLabelConverter, AttnLabelConverter, Averager -from dataset import hierarchical_dataset, AlignCollate -from model import Model - -def validation(model, criterion, evaluation_loader, converter, opt, device): - """ validation or evaluation """ - n_correct = 0 - norm_ED = 0 - length_of_data = 0 - infer_time = 0 - valid_loss_avg = Averager() - - for i, (image_tensors, labels) in enumerate(evaluation_loader): - batch_size = image_tensors.size(0) - length_of_data = length_of_data + batch_size - image = image_tensors.to(device) - # For max length prediction - length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(device) - text_for_pred = torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device) - - text_for_loss, length_for_loss = converter.encode(labels, batch_max_length=opt.batch_max_length) - - start_time = time.time() - if 'CTC' in opt.Prediction: - preds = model(image, text_for_pred) - forward_time = time.time() - start_time - - # Calculate evaluation loss for CTC decoder. - preds_size = torch.IntTensor([preds.size(1)] * batch_size) - # permute 'preds' to use CTCloss format - cost = criterion(preds.log_softmax(2).permute(1, 0, 2), text_for_loss, preds_size, length_for_loss) - - if opt.decode == 'greedy': - # Select max probabilty (greedy decoding) then decode index to character - _, preds_index = preds.max(2) - preds_index = preds_index.view(-1) - preds_str = converter.decode_greedy(preds_index.data, preds_size.data) - elif opt.decode == 'beamsearch': - preds_str = converter.decode_beamsearch(preds, beamWidth=2) - - else: - preds = model(image, text_for_pred, is_train=False) - forward_time = time.time() - start_time - - preds = preds[:, :text_for_loss.shape[1] - 1, :] - target = text_for_loss[:, 1:] # without [GO] Symbol - cost = criterion(preds.contiguous().view(-1, preds.shape[-1]), target.contiguous().view(-1)) - - # select max probabilty (greedy decoding) then decode index to character - _, preds_index = preds.max(2) - preds_str = converter.decode(preds_index, length_for_pred) - labels = converter.decode(text_for_loss[:, 1:], length_for_loss) - - infer_time += forward_time - valid_loss_avg.add(cost) - - # calculate accuracy & confidence score - preds_prob = F.softmax(preds, dim=2) - preds_max_prob, _ = preds_prob.max(dim=2) - confidence_score_list = [] - - for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob): - if 'Attn' in opt.Prediction: - gt = gt[:gt.find('[s]')] - pred_EOS = pred.find('[s]') - pred = pred[:pred_EOS] # prune after "end of sentence" token ([s]) - pred_max_prob = pred_max_prob[:pred_EOS] - - if pred == gt: - n_correct += 1 - - ''' - (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks - "For each word we calculate the normalized edit distance to the length of the ground truth transcription." - if len(gt) == 0: - norm_ED += 1 - else: - norm_ED += edit_distance(pred, gt) / len(gt) - ''' - - # ICDAR2019 Normalized Edit Distance - if len(gt) == 0 or len(pred) ==0: - norm_ED += 0 - elif len(gt) > len(pred): - norm_ED += 1 - edit_distance(pred, gt) / len(gt) - else: - norm_ED += 1 - edit_distance(pred, gt) / len(pred) - - # calculate confidence score (= multiply of pred_max_prob) - try: - confidence_score = pred_max_prob.cumprod(dim=0)[-1] - except: - confidence_score = 0 # for empty pred case, when prune after "end of sentence" token ([s]) - confidence_score_list.append(confidence_score) - # print(pred, gt, pred==gt, confidence_score) - - accuracy = n_correct / float(length_of_data) * 100 - norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance - - return valid_loss_avg.val(), accuracy, norm_ED, preds_str, confidence_score_list, labels, infer_time, length_of_data diff --git a/trainer/train.py b/trainer/train.py deleted file mode 100644 index e0066f3d078..00000000000 --- a/trainer/train.py +++ /dev/null @@ -1,282 +0,0 @@ -import os -import sys -import time -import random -import torch -import torch.backends.cudnn as cudnn -import torch.nn as nn -import torch.nn.init as init -import torch.optim as optim -import torch.utils.data -from torch.cuda.amp import autocast, GradScaler -import numpy as np - -from utils import CTCLabelConverter, AttnLabelConverter, Averager -from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset -from model import Model -from test import validation -device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - -def count_parameters(model): - print("Modules, Parameters") - total_params = 0 - for name, parameter in model.named_parameters(): - if not parameter.requires_grad: continue - param = parameter.numel() - #table.add_row([name, param]) - total_params+=param - print(name, param) - print(f"Total Trainable Params: {total_params}") - return total_params - -def train(opt, show_number = 2, amp=False): - """ dataset preparation """ - if not opt.data_filtering_off: - print('Filtering the images containing characters which are not in opt.character') - print('Filtering the images whose label is longer than opt.batch_max_length') - - opt.select_data = opt.select_data.split('-') - opt.batch_ratio = opt.batch_ratio.split('-') - train_dataset = Batch_Balanced_Dataset(opt) - - log = open(f'./saved_models/{opt.experiment_name}/log_dataset.txt', 'a', encoding="utf8") - AlignCollate_valid = AlignCollate(imgH=opt.imgH, imgW=opt.imgW, keep_ratio_with_pad=opt.PAD, contrast_adjust=opt.contrast_adjust) - valid_dataset, valid_dataset_log = hierarchical_dataset(root=opt.valid_data, opt=opt) - valid_loader = torch.utils.data.DataLoader( - valid_dataset, batch_size=min(32, opt.batch_size), - shuffle=True, # 'True' to check training progress with validation function. - num_workers=int(opt.workers), prefetch_factor=512, - collate_fn=AlignCollate_valid, pin_memory=True) - log.write(valid_dataset_log) - print('-' * 80) - log.write('-' * 80 + '\n') - log.close() - - """ model configuration """ - if 'CTC' in opt.Prediction: - converter = CTCLabelConverter(opt.character) - else: - converter = AttnLabelConverter(opt.character) - opt.num_class = len(converter.character) - - if opt.rgb: - opt.input_channel = 3 - model = Model(opt) - print('model input parameters', opt.imgH, opt.imgW, opt.num_fiducial, opt.input_channel, opt.output_channel, - opt.hidden_size, opt.num_class, opt.batch_max_length, opt.Transformation, opt.FeatureExtraction, - opt.SequenceModeling, opt.Prediction) - - if opt.saved_model != '': - pretrained_dict = torch.load(opt.saved_model) - if opt.new_prediction: - model.Prediction = nn.Linear(model.SequenceModeling_output, len(pretrained_dict['module.Prediction.weight'])) - - model = torch.nn.DataParallel(model).to(device) - print(f'loading pretrained model from {opt.saved_model}') - if opt.FT: - model.load_state_dict(pretrained_dict, strict=False) - else: - model.load_state_dict(pretrained_dict) - if opt.new_prediction: - model.module.Prediction = nn.Linear(model.module.SequenceModeling_output, opt.num_class) - for name, param in model.module.Prediction.named_parameters(): - if 'bias' in name: - init.constant_(param, 0.0) - elif 'weight' in name: - init.kaiming_normal_(param) - model = model.to(device) - else: - # weight initialization - for name, param in model.named_parameters(): - if 'localization_fc2' in name: - print(f'Skip {name} as it is already initialized') - continue - try: - if 'bias' in name: - init.constant_(param, 0.0) - elif 'weight' in name: - init.kaiming_normal_(param) - except Exception as e: # for batchnorm. - if 'weight' in name: - param.data.fill_(1) - continue - model = torch.nn.DataParallel(model).to(device) - - model.train() - print("Model:") - print(model) - count_parameters(model) - - """ setup loss """ - if 'CTC' in opt.Prediction: - criterion = torch.nn.CTCLoss(zero_infinity=True).to(device) - else: - criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(device) # ignore [GO] token = ignore index 0 - # loss averager - loss_avg = Averager() - - # freeze some layers - try: - if opt.freeze_FeatureFxtraction: - for param in model.module.FeatureExtraction.parameters(): - param.requires_grad = False - if opt.freeze_SequenceModeling: - for param in model.module.SequenceModeling.parameters(): - param.requires_grad = False - except: - pass - - # filter that only require gradient decent - filtered_parameters = [] - params_num = [] - for p in filter(lambda p: p.requires_grad, model.parameters()): - filtered_parameters.append(p) - params_num.append(np.prod(p.size())) - print('Trainable params num : ', sum(params_num)) - # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())] - - # setup optimizer - if opt.optim=='adam': - #optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999)) - optimizer = optim.Adam(filtered_parameters) - else: - optimizer = optim.Adadelta(filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps) - print("Optimizer:") - print(optimizer) - - """ final options """ - # print(opt) - with open(f'./saved_models/{opt.experiment_name}/opt.txt', 'a', encoding="utf8") as opt_file: - opt_log = '------------ Options -------------\n' - args = vars(opt) - for k, v in args.items(): - opt_log += f'{str(k)}: {str(v)}\n' - opt_log += '---------------------------------------\n' - print(opt_log) - opt_file.write(opt_log) - - """ start training """ - start_iter = 0 - if opt.saved_model != '': - try: - start_iter = int(opt.saved_model.split('_')[-1].split('.')[0]) - print(f'continue to train, start_iter: {start_iter}') - except: - pass - - start_time = time.time() - best_accuracy = -1 - best_norm_ED = -1 - i = start_iter - - scaler = GradScaler() - t1= time.time() - - while(True): - # train part - optimizer.zero_grad(set_to_none=True) - - if amp: - with autocast(): - image_tensors, labels = train_dataset.get_batch() - image = image_tensors.to(device) - text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) - batch_size = image.size(0) - - if 'CTC' in opt.Prediction: - preds = model(image, text).log_softmax(2) - preds_size = torch.IntTensor([preds.size(1)] * batch_size) - preds = preds.permute(1, 0, 2) - torch.backends.cudnn.enabled = False - cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) - torch.backends.cudnn.enabled = True - else: - preds = model(image, text[:, :-1]) # align with Attention.forward - target = text[:, 1:] # without [GO] Symbol - cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) - scaler.scale(cost).backward() - scaler.unscale_(optimizer) - torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) - scaler.step(optimizer) - scaler.update() - else: - image_tensors, labels = train_dataset.get_batch() - image = image_tensors.to(device) - text, length = converter.encode(labels, batch_max_length=opt.batch_max_length) - batch_size = image.size(0) - if 'CTC' in opt.Prediction: - preds = model(image, text).log_softmax(2) - preds_size = torch.IntTensor([preds.size(1)] * batch_size) - preds = preds.permute(1, 0, 2) - torch.backends.cudnn.enabled = False - cost = criterion(preds, text.to(device), preds_size.to(device), length.to(device)) - torch.backends.cudnn.enabled = True - else: - preds = model(image, text[:, :-1]) # align with Attention.forward - target = text[:, 1:] # without [GO] Symbol - cost = criterion(preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)) - cost.backward() - torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip) - optimizer.step() - loss_avg.add(cost) - - # validation part - if (i % opt.valInterval == 0) and (i!=0): - print('training time: ', time.time()-t1) - t1=time.time() - elapsed_time = time.time() - start_time - # for log - with open(f'./saved_models/{opt.experiment_name}/log_train.txt', 'a', encoding="utf8") as log: - model.eval() - with torch.no_grad(): - valid_loss, current_accuracy, current_norm_ED, preds, confidence_score, labels,\ - infer_time, length_of_data = validation(model, criterion, valid_loader, converter, opt, device) - model.train() - - # training loss and validation loss - loss_log = f'[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}' - loss_avg.reset() - - current_model_log = f'{"Current_accuracy":17s}: {current_accuracy:0.3f}, {"Current_norm_ED":17s}: {current_norm_ED:0.4f}' - - # keep best accuracy model (on valid dataset) - if current_accuracy > best_accuracy: - best_accuracy = current_accuracy - torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_accuracy.pth') - if current_norm_ED > best_norm_ED: - best_norm_ED = current_norm_ED - torch.save(model.state_dict(), f'./saved_models/{opt.experiment_name}/best_norm_ED.pth') - best_model_log = f'{"Best_accuracy":17s}: {best_accuracy:0.3f}, {"Best_norm_ED":17s}: {best_norm_ED:0.4f}' - - loss_model_log = f'{loss_log}\n{current_model_log}\n{best_model_log}' - print(loss_model_log) - log.write(loss_model_log + '\n') - - # show some predicted results - dashed_line = '-' * 80 - head = f'{"Ground Truth":25s} | {"Prediction":25s} | Confidence Score & T/F' - predicted_result_log = f'{dashed_line}\n{head}\n{dashed_line}\n' - - #show_number = min(show_number, len(labels)) - - start = random.randint(0,len(labels) - show_number ) - for gt, pred, confidence in zip(labels[start:start+show_number], preds[start:start+show_number], confidence_score[start:start+show_number]): - if 'Attn' in opt.Prediction: - gt = gt[:gt.find('[s]')] - pred = pred[:pred.find('[s]')] - - predicted_result_log += f'{gt:25s} | {pred:25s} | {confidence:0.4f}\t{str(pred == gt)}\n' - predicted_result_log += f'{dashed_line}' - print(predicted_result_log) - log.write(predicted_result_log + '\n') - print('validation time: ', time.time()-t1) - t1=time.time() - # save model per 1e+4 iter. - if (i + 1) % 1e+4 == 0: - torch.save( - model.state_dict(), f'./saved_models/{opt.experiment_name}/iter_{i+1}.pth') - - if i == opt.num_iter: - print('end the training') - sys.exit() - i += 1 diff --git a/trainer/trainer.ipynb b/trainer/trainer.ipynb index 712bf833412..33a16a71f1d 100644 --- a/trainer/trainer.ipynb +++ b/trainer/trainer.ipynb @@ -3,68 +3,544 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2021-07-23T04:19:23.488642Z", - "start_time": "2021-07-23T04:19:21.854534Z" - } - }, + "metadata": {}, "outputs": [], "source": [ "import os\n", - "import torch.backends.cudnn as cudnn\n", + "import sys\n", + "import time\n", "import yaml\n", - "from train import train\n", - "from utils import AttrDict\n", - "import pandas as pd" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2021-07-23T04:19:23.885144Z", - "start_time": "2021-07-23T04:19:23.880564Z" - }, - "code_folding": [] - }, - "outputs": [], - "source": [ + "import random\n", + "import string\n", + "import argparse\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import torch\n", + "import torch.backends.cudnn as cudnn\n", + "import torch.nn as nn\n", + "import torch.nn.init as init\n", + "import torch.optim as optim\n", + "import torch.utils.data\n", + "import torch.nn.functional as F\n", + "from torch.cuda.amp import autocast, GradScaler\n", + "\n", + "from nltk.metrics.distance import edit_distance\n", + "from utils import CTCLabelConverter, AttnLabelConverter, Averager, AttrDict\n", + "from dataset import hierarchical_dataset, AlignCollate, Batch_Balanced_Dataset\n", + "from model import Model\n", + "\n", + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "cudnn.benchmark = True\n", - "cudnn.deterministic = False" + "cudnn.deterministic = False\n", + "\n", + "if torch.cuda.is_available():\n", + " print(\"CUDA is availeble\")\n", + " print(\"devices =\", torch.cuda.device_count())\n", + " print(\"current device: \", torch.cuda.get_device_name(torch.cuda.current_device()))\n", + "else:\n", + " print(\"CUDA not availeble, recomended to fix it for faster train\")" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "ExecuteTime": { - "end_time": "2021-07-23T04:19:24.119144Z", - "start_time": "2021-07-23T04:19:24.112032Z" - }, - "code_folding": [ - 0 - ] - }, + "metadata": {}, "outputs": [], "source": [ + "def validation(model, criterion, evaluation_loader, converter, opt, device):\n", + " \"\"\"validation or evaluation\"\"\"\n", + " n_correct = 0\n", + " norm_ED = 0\n", + " length_of_data = 0\n", + " infer_time = 0\n", + " valid_loss_avg = Averager()\n", + "\n", + " for i, (image_tensors, labels) in enumerate(evaluation_loader):\n", + " batch_size = image_tensors.size(0)\n", + " length_of_data = length_of_data + batch_size\n", + " image = image_tensors.to(device)\n", + " # For max length prediction\n", + " length_for_pred = torch.IntTensor([opt.batch_max_length] * batch_size).to(\n", + " device\n", + " )\n", + " text_for_pred = (\n", + " torch.LongTensor(batch_size, opt.batch_max_length + 1).fill_(0).to(device)\n", + " )\n", + "\n", + " text_for_loss, length_for_loss = converter.encode(\n", + " labels, batch_max_length=opt.batch_max_length\n", + " )\n", + "\n", + " start_time = time.time()\n", + " if \"CTC\" in opt.Prediction:\n", + " preds = model(image, text_for_pred)\n", + " forward_time = time.time() - start_time\n", + "\n", + " # Calculate evaluation loss for CTC decoder.\n", + " preds_size = torch.IntTensor([preds.size(1)] * batch_size)\n", + " # permute 'preds' to use CTCloss format\n", + "\n", + " # Перед вызовом функции criterion\n", + " preds = preds.to(device)\n", + " text_for_loss = text_for_loss.to(device)\n", + " preds_size = preds_size.to(device)\n", + " length_for_loss = length_for_loss.to(device)\n", + "\n", + " cost = criterion(\n", + " preds.log_softmax(2).permute(1, 0, 2),\n", + " text_for_loss,\n", + " preds_size,\n", + " length_for_loss,\n", + " )\n", + "\n", + " if opt.decode == \"greedy\":\n", + " # Select max probabilty (greedy decoding) then decode index to character\n", + " _, preds_index = preds.max(2)\n", + " preds_index = preds_index.view(-1)\n", + " preds_str = converter.decode_greedy(preds_index.data, preds_size.data)\n", + " elif opt.decode == \"beamsearch\":\n", + " preds_str = converter.decode_beamsearch(preds, beamWidth=2)\n", + "\n", + " else:\n", + " preds = model(image, text_for_pred, is_train=False)\n", + " forward_time = time.time() - start_time\n", + "\n", + " preds = preds[:, : text_for_loss.shape[1] - 1, :]\n", + " target = text_for_loss[:, 1:] # without [GO] Symbol\n", + " cost = criterion(\n", + " preds.contiguous().view(-1, preds.shape[-1]),\n", + " target.contiguous().view(-1),\n", + " )\n", + "\n", + " # select max probabilty (greedy decoding) then decode index to character\n", + " _, preds_index = preds.max(2)\n", + " preds_str = converter.decode(preds_index, length_for_pred)\n", + " labels = converter.decode(text_for_loss[:, 1:], length_for_loss)\n", + "\n", + " infer_time += forward_time\n", + " valid_loss_avg.add(cost)\n", + "\n", + " # calculate accuracy & confidence score\n", + " preds_prob = F.softmax(preds, dim=2)\n", + " preds_max_prob, _ = preds_prob.max(dim=2)\n", + " confidence_score_list = []\n", + "\n", + " for gt, pred, pred_max_prob in zip(labels, preds_str, preds_max_prob):\n", + " if \"Attn\" in opt.Prediction:\n", + " gt = gt[: gt.find(\"[s]\")]\n", + " pred_EOS = pred.find(\"[s]\")\n", + " pred = pred[:pred_EOS] # prune after \"end of sentence\" token ([s])\n", + " pred_max_prob = pred_max_prob[:pred_EOS]\n", + "\n", + " if pred == gt:\n", + " n_correct += 1\n", + "\n", + " \"\"\"\n", + " (old version) ICDAR2017 DOST Normalized Edit Distance https://rrc.cvc.uab.es/?ch=7&com=tasks\n", + " \"For each word we calculate the normalized edit distance to the length of the ground truth transcription.\" \n", + " if len(gt) == 0:\n", + " norm_ED += 1\n", + " else:\n", + " norm_ED += edit_distance(pred, gt) / len(gt)\n", + " \"\"\"\n", + "\n", + " # ICDAR2019 Normalized Edit Distance\n", + " if len(gt) == 0 or len(pred) == 0:\n", + " norm_ED += 0\n", + " elif len(gt) > len(pred):\n", + " norm_ED += 1 - edit_distance(pred, gt) / len(gt)\n", + " else:\n", + " norm_ED += 1 - edit_distance(pred, gt) / len(pred)\n", + "\n", + " # calculate confidence score (= multiply of pred_max_prob)\n", + " try:\n", + " confidence_score = pred_max_prob.cumprod(dim=0)[-1]\n", + " except:\n", + " confidence_score = 0 # for empty pred case, when prune after \"end of sentence\" token ([s])\n", + " confidence_score_list.append(confidence_score)\n", + " # print(pred, gt, pred==gt, confidence_score)\n", + "\n", + " accuracy = n_correct / float(length_of_data) * 100\n", + " norm_ED = norm_ED / float(length_of_data) # ICDAR2019 Normalized Edit Distance\n", + "\n", + " return (\n", + " valid_loss_avg.val(),\n", + " accuracy,\n", + " norm_ED,\n", + " preds_str,\n", + " confidence_score_list,\n", + " labels,\n", + " infer_time,\n", + " length_of_data,\n", + " )\n", + "\n", + "\n", + "def count_parameters(model):\n", + " print(\"Modules, Parameters\")\n", + " total_params = 0\n", + " for name, parameter in model.named_parameters():\n", + " if not parameter.requires_grad:\n", + " continue\n", + " param = parameter.numel()\n", + " # table.add_row([name, param])\n", + " total_params += param\n", + " print(name, param)\n", + " print(f\"Total Trainable Params: {total_params}\")\n", + " return total_params\n", + "\n", + "\n", + "def train(opt, show_number=2, amp=False):\n", + " \"\"\"dataset preparation\"\"\"\n", + " if not opt.data_filtering_off:\n", + " print(\n", + " \"Filtering the images containing characters which are not in opt.character\"\n", + " )\n", + " print(\"Filtering the images whose label is longer than opt.batch_max_length\")\n", + "\n", + " opt.select_data = opt.select_data.split(\"-\")\n", + " opt.batch_ratio = opt.batch_ratio.split(\"-\")\n", + " train_dataset = Batch_Balanced_Dataset(opt)\n", + "\n", + " log = open(\n", + " f\"./saved_models/{opt.experiment_name}/log_dataset.txt\", \"a\", encoding=\"utf8\"\n", + " )\n", + " AlignCollate_valid = AlignCollate(\n", + " imgH=opt.imgH,\n", + " imgW=opt.imgW,\n", + " keep_ratio_with_pad=opt.PAD,\n", + " contrast_adjust=opt.contrast_adjust,\n", + " )\n", + " valid_dataset, valid_dataset_log = hierarchical_dataset(\n", + " root=opt.valid_data, opt=opt\n", + " )\n", + " valid_loader = torch.utils.data.DataLoader(\n", + " valid_dataset,\n", + " batch_size=min(32, opt.batch_size),\n", + " shuffle=True, # 'True' to check training progress with validation function.\n", + " num_workers=int(opt.workers),\n", + " prefetch_factor=512,\n", + " collate_fn=AlignCollate_valid,\n", + " pin_memory=True,\n", + " )\n", + " log.write(valid_dataset_log)\n", + " print(\"-\" * 80)\n", + " log.write(\"-\" * 80 + \"\\n\")\n", + " log.close()\n", + "\n", + " \"\"\" model configuration \"\"\"\n", + " if \"CTC\" in opt.Prediction:\n", + " converter = CTCLabelConverter(opt.character)\n", + " else:\n", + " converter = AttnLabelConverter(opt.character)\n", + " opt.num_class = len(converter.character)\n", + "\n", + " if opt.rgb:\n", + " opt.input_channel = 3\n", + " model = Model(opt)\n", + " print(\n", + " \"model input parameters\",\n", + " opt.imgH,\n", + " opt.imgW,\n", + " opt.num_fiducial,\n", + " opt.input_channel,\n", + " opt.output_channel,\n", + " opt.hidden_size,\n", + " opt.num_class,\n", + " opt.batch_max_length,\n", + " opt.Transformation,\n", + " opt.FeatureExtraction,\n", + " opt.SequenceModeling,\n", + " opt.Prediction,\n", + " )\n", + "\n", + " if opt.saved_model != \"\":\n", + " pretrained_dict = torch.load(opt.saved_model)\n", + " if opt.new_prediction:\n", + " model.Prediction = nn.Linear(\n", + " model.SequenceModeling_output,\n", + " len(pretrained_dict[\"module.Prediction.weight\"]),\n", + " )\n", + "\n", + " model = torch.nn.DataParallel(model).to(device)\n", + " print(f\"loading pretrained model from {opt.saved_model}\")\n", + " if opt.FT:\n", + " model.load_state_dict(pretrained_dict, strict=False)\n", + " else:\n", + " model.load_state_dict(pretrained_dict)\n", + " if opt.new_prediction:\n", + " model.module.Prediction = nn.Linear(\n", + " model.module.SequenceModeling_output, opt.num_class\n", + " )\n", + " for name, param in model.module.Prediction.named_parameters():\n", + " if \"bias\" in name:\n", + " init.constant_(param, 0.0)\n", + " elif \"weight\" in name:\n", + " init.kaiming_normal_(param)\n", + " model = model.to(device)\n", + " else:\n", + " # weight initialization\n", + " for name, param in model.named_parameters():\n", + " if \"localization_fc2\" in name:\n", + " print(f\"Skip {name} as it is already initialized\")\n", + " continue\n", + " try:\n", + " if \"bias\" in name:\n", + " init.constant_(param, 0.0)\n", + " elif \"weight\" in name:\n", + " init.kaiming_normal_(param)\n", + " except Exception as e: # for batchnorm.\n", + " if \"weight\" in name:\n", + " param.data.fill_(1)\n", + " continue\n", + " model = torch.nn.DataParallel(model).to(device)\n", + "\n", + " model.train()\n", + " print(\"Model:\")\n", + " print(model)\n", + " count_parameters(model)\n", + "\n", + " \"\"\" setup loss \"\"\"\n", + " if \"CTC\" in opt.Prediction:\n", + " criterion = torch.nn.CTCLoss(zero_infinity=True).to(device)\n", + " else:\n", + " criterion = torch.nn.CrossEntropyLoss(ignore_index=0).to(\n", + " device\n", + " ) # ignore [GO] token = ignore index 0\n", + " # loss averager\n", + " loss_avg = Averager()\n", + "\n", + " # freeze some layers\n", + " try:\n", + " if opt.freeze_FeatureFxtraction:\n", + " for param in model.module.FeatureExtraction.parameters():\n", + " param.requires_grad = False\n", + " if opt.freeze_SequenceModeling:\n", + " for param in model.module.SequenceModeling.parameters():\n", + " param.requires_grad = False\n", + " except:\n", + " pass\n", + "\n", + " # filter that only require gradient decent\n", + " filtered_parameters = []\n", + " params_num = []\n", + " for p in filter(lambda p: p.requires_grad, model.parameters()):\n", + " filtered_parameters.append(p)\n", + " params_num.append(np.prod(p.size()))\n", + " print(\"Trainable params num : \", sum(params_num))\n", + " # [print(name, p.numel()) for name, p in filter(lambda p: p[1].requires_grad, model.named_parameters())]\n", + "\n", + " # setup optimizer\n", + " if opt.optim == \"adam\":\n", + " # optimizer = optim.Adam(filtered_parameters, lr=opt.lr, betas=(opt.beta1, 0.999))\n", + " optimizer = optim.Adam(filtered_parameters)\n", + " else:\n", + " optimizer = optim.Adadelta(\n", + " filtered_parameters, lr=opt.lr, rho=opt.rho, eps=opt.eps\n", + " )\n", + " print(\"Optimizer:\")\n", + " print(optimizer)\n", + "\n", + " \"\"\" final options \"\"\"\n", + " # print(opt)\n", + " with open(\n", + " f\"./saved_models/{opt.experiment_name}/opt.txt\", \"a\", encoding=\"utf8\"\n", + " ) as opt_file:\n", + " opt_log = \"------------ Options -------------\\n\"\n", + " args = vars(opt)\n", + " for k, v in args.items():\n", + " opt_log += f\"{str(k)}: {str(v)}\\n\"\n", + " opt_log += \"---------------------------------------\\n\"\n", + " print(opt_log)\n", + " opt_file.write(opt_log)\n", + "\n", + " \"\"\" start training \"\"\"\n", + " start_iter = 0\n", + " if opt.saved_model != \"\":\n", + " try:\n", + " start_iter = int(opt.saved_model.split(\"_\")[-1].split(\".\")[0])\n", + " print(f\"continue to train, start_iter: {start_iter}\")\n", + " except:\n", + " pass\n", + "\n", + " start_time = time.time()\n", + " best_accuracy = -1\n", + " best_norm_ED = -1\n", + " i = start_iter\n", + "\n", + " scaler = GradScaler()\n", + " t1 = time.time()\n", + "\n", + " while True:\n", + " # train part\n", + " optimizer.zero_grad(set_to_none=True)\n", + "\n", + " if amp:\n", + " with autocast():\n", + " image_tensors, labels = train_dataset.get_batch()\n", + " image = image_tensors.to(device)\n", + " text, length = converter.encode(\n", + " labels, batch_max_length=opt.batch_max_length\n", + " )\n", + " batch_size = image.size(0)\n", + "\n", + " if \"CTC\" in opt.Prediction:\n", + " preds = model(image, text).log_softmax(2)\n", + " preds_size = torch.IntTensor([preds.size(1)] * batch_size)\n", + " preds = preds.permute(1, 0, 2)\n", + " torch.backends.cudnn.enabled = False\n", + " cost = criterion(\n", + " preds, text.to(device), preds_size.to(device), length.to(device)\n", + " )\n", + " torch.backends.cudnn.enabled = True\n", + " else:\n", + " preds = model(image, text[:, :-1]) # align with Attention.forward\n", + " target = text[:, 1:] # without [GO] Symbol\n", + " cost = criterion(\n", + " preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)\n", + " )\n", + " scaler.scale(cost).backward()\n", + " scaler.unscale_(optimizer)\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)\n", + " scaler.step(optimizer)\n", + " scaler.update()\n", + " else:\n", + " image_tensors, labels = train_dataset.get_batch()\n", + " image = image_tensors.to(device)\n", + " text, length = converter.encode(\n", + " labels, batch_max_length=opt.batch_max_length\n", + " )\n", + " batch_size = image.size(0)\n", + " if \"CTC\" in opt.Prediction:\n", + " preds = model(image, text).log_softmax(2)\n", + " preds_size = torch.IntTensor([preds.size(1)] * batch_size)\n", + " preds = preds.permute(1, 0, 2)\n", + " torch.backends.cudnn.enabled = False\n", + " cost = criterion(\n", + " preds, text.to(device), preds_size.to(device), length.to(device)\n", + " )\n", + " torch.backends.cudnn.enabled = True\n", + " else:\n", + " preds = model(image, text[:, :-1]) # align with Attention.forward\n", + " target = text[:, 1:] # without [GO] Symbol\n", + " cost = criterion(\n", + " preds.view(-1, preds.shape[-1]), target.contiguous().view(-1)\n", + " )\n", + " cost.backward()\n", + " torch.nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)\n", + " optimizer.step()\n", + " loss_avg.add(cost)\n", + "\n", + " # validation part\n", + " if (i % opt.valInterval == 0) and (i != 0):\n", + " print(\"training time: \", time.time() - t1)\n", + " t1 = time.time()\n", + " elapsed_time = time.time() - start_time\n", + " # for log\n", + " with open(\n", + " f\"./saved_models/{opt.experiment_name}/log_train.txt\",\n", + " \"a\",\n", + " encoding=\"utf8\",\n", + " ) as log:\n", + " model.eval()\n", + " with torch.no_grad():\n", + " (\n", + " valid_loss,\n", + " current_accuracy,\n", + " current_norm_ED,\n", + " preds,\n", + " confidence_score,\n", + " labels,\n", + " infer_time,\n", + " length_of_data,\n", + " ) = validation(\n", + " model, criterion, valid_loader, converter, opt, device\n", + " )\n", + " model.train()\n", + "\n", + " # training loss and validation loss\n", + " loss_log = f\"[{i}/{opt.num_iter}] Train loss: {loss_avg.val():0.5f}, Valid loss: {valid_loss:0.5f}, Elapsed_time: {elapsed_time:0.5f}\"\n", + " loss_avg.reset()\n", + "\n", + " current_model_log = f'{\"Current_accuracy\":17s}: {current_accuracy:0.3f}, {\"Current_norm_ED\":17s}: {current_norm_ED:0.4f}'\n", + "\n", + " # keep best accuracy model (on valid dataset)\n", + " if current_accuracy > best_accuracy:\n", + " best_accuracy = current_accuracy\n", + " torch.save(\n", + " model.state_dict(),\n", + " f\"./saved_models/{opt.experiment_name}/best_accuracy.pth\",\n", + " )\n", + " if current_norm_ED > best_norm_ED:\n", + " best_norm_ED = current_norm_ED\n", + " torch.save(\n", + " model.state_dict(),\n", + " f\"./saved_models/{opt.experiment_name}/best_norm_ED.pth\",\n", + " )\n", + " best_model_log = f'{\"Best_accuracy\":17s}: {best_accuracy:0.3f}, {\"Best_norm_ED\":17s}: {best_norm_ED:0.4f}'\n", + "\n", + " loss_model_log = f\"{loss_log}\\n{current_model_log}\\n{best_model_log}\"\n", + " print(loss_model_log)\n", + " log.write(loss_model_log + \"\\n\")\n", + "\n", + " # show some predicted results\n", + " dashed_line = \"-\" * 80\n", + " head = f'{\"Ground Truth\":25s} | {\"Prediction\":25s} | Confidence Score & T/F'\n", + " predicted_result_log = f\"{dashed_line}\\n{head}\\n{dashed_line}\\n\"\n", + "\n", + " # show_number = min(show_number, len(labels))\n", + "\n", + " start = random.randint(0, len(labels) - show_number)\n", + " for gt, pred, confidence in zip(\n", + " labels[start : start + show_number],\n", + " preds[start : start + show_number],\n", + " confidence_score[start : start + show_number],\n", + " ):\n", + " if \"Attn\" in opt.Prediction:\n", + " gt = gt[: gt.find(\"[s]\")]\n", + " pred = pred[: pred.find(\"[s]\")]\n", + "\n", + " predicted_result_log += f\"{gt:25s} | {pred:25s} | {confidence:0.4f}\\t{str(pred == gt)}\\n\"\n", + " predicted_result_log += f\"{dashed_line}\"\n", + " print(predicted_result_log)\n", + " log.write(predicted_result_log + \"\\n\")\n", + " print(\"validation time: \", time.time() - t1)\n", + " t1 = time.time()\n", + " # save model per 1e+4 iter.\n", + " if (i + 1) % 1e4 == 0:\n", + " torch.save(\n", + " model.state_dict(),\n", + " f\"./saved_models/{opt.experiment_name}/iter_{i+1}.pth\",\n", + " )\n", + "\n", + " if i == opt.num_iter:\n", + " print(\"end the training\")\n", + " sys.exit()\n", + " i += 1\n", + "\n", + "\n", "def get_config(file_path):\n", - " with open(file_path, 'r', encoding=\"utf8\") as stream:\n", + " with open(file_path, \"r\", encoding=\"utf8\") as stream:\n", " opt = yaml.safe_load(stream)\n", " opt = AttrDict(opt)\n", - " if opt.lang_char == 'None':\n", - " characters = ''\n", - " for data in opt['select_data'].split('-'):\n", - " csv_path = os.path.join(opt['train_data'], data, 'labels.csv')\n", - " df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python', usecols=['filename', 'words'], keep_default_na=False)\n", - " all_char = ''.join(df['words'])\n", - " characters += ''.join(set(all_char))\n", + " if opt.lang_char == \"None\":\n", + " characters = \"\"\n", + " for data in opt[\"select_data\"].split(\"-\"):\n", + " csv_path = os.path.join(opt[\"train_data\"], data, \"labels.csv\")\n", + " df = pd.read_csv(\n", + " csv_path,\n", + " sep=\"^([^,]+),\",\n", + " engine=\"python\",\n", + " usecols=[\"filename\", \"words\"],\n", + " keep_default_na=False,\n", + " )\n", + " all_char = \"\".join(df[\"words\"])\n", + " characters += \"\".join(set(all_char))\n", " characters = sorted(set(characters))\n", - " opt.character= ''.join(characters)\n", + " opt.character = \"\".join(characters)\n", " else:\n", " opt.character = opt.number + opt.symbol + opt.lang_char\n", - " os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)\n", + " os.makedirs(f\"./saved_models/{opt.experiment_name}\", exist_ok=True)\n", " return opt" ] }, @@ -79,7 +555,7 @@ }, "outputs": [], "source": [ - "opt = get_config(\"config_files/en_filtered_config.yaml\")\n", + "opt = get_config(\"config_files/custom_data_train.yaml\")\n", "train(opt, amp=False)" ] }, @@ -107,7 +583,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.11" + "version": "3.9.0" } }, "nbformat": 4,