我想做什么
由于用fairseq训练的模型需要很长时间才能加载,因此创建一个行为与fairseq-interactive相同、但把模型常驻在内存中的类。
fairseq的版本是0.9.0。
代码
from collections import namedtuple

import torch

from fairseq import checkpoint_utils, options, tasks, utils
from fairseq.data import encoders

# A minibatch ready for inference: example ids plus padded source tensors.
Batch = namedtuple('Batch', 'ids src_tokens src_lengths')
# Kept for parity with fairseq_cli/interactive.py (not used below).
Translation = namedtuple('Translation', 'src_str hypos pos_scores alignments')


def make_batches(lines, args, task, max_positions, encode_fn):
    """Yield inference ``Batch``es built from raw source strings.

    Args:
        lines: iterable of source sentences (whitespace-tokenized strings).
        args: parsed fairseq generation arguments (uses ``max_tokens`` and
            ``max_sentences`` to size batches).
        task: fairseq task providing the source dictionary.
        max_positions: maximum source positions supported by the model(s).
        encode_fn: preprocessing callable applied to each string before
            dictionary encoding (e.g. BPE), or the identity.

    Yields:
        Batch: ids and padded ``src_tokens``/``src_lengths`` tensors.
    """
    tokens = [
        task.source_dictionary.encode_line(
            encode_fn(src_str), add_if_not_exist=False
        ).long()
        for src_str in lines
    ]
    lengths = torch.LongTensor([t.numel() for t in tokens])
    itr = task.get_batch_iterator(
        dataset=task.build_dataset_for_inference(tokens, lengths),
        max_tokens=args.max_tokens,
        max_sentences=args.max_sentences,
        max_positions=max_positions,
    ).next_epoch_itr(shuffle=False)
    for batch in itr:
        yield Batch(
            ids=batch['id'],
            src_tokens=batch['net_input']['src_tokens'],
            src_lengths=batch['net_input']['src_lengths'],
        )


class Generator:
    """Load a fairseq model ensemble once and translate strings on demand.

    Mimics ``fairseq-interactive`` but keeps the (slow-to-load) models
    resident in memory so repeated ``generate()`` calls are cheap.
    """

    def __init__(self, data_path, checkpoint_path="checkpoint_best.pt"):
        """Parse generation args, load the checkpoint(s), and prepare decoding.

        Args:
            data_path: path to the preprocessed data directory (supplies the
                dictionaries), passed as the positional fairseq argument.
            checkpoint_path: model checkpoint path; ``':'``-separated for an
                ensemble.
        """
        self.parser = options.get_generation_parser(interactive=True)
        self.parser.set_defaults(
            path=checkpoint_path,
            remove_bpe=None,
            dataset_impl="lazy",
            # BUG FIX: was misspelled ``num_wokers``, so the intended default
            # of 5 dataloader workers was silently never applied.
            num_workers=5,
        )
        self.args = options.parse_args_and_arch(
            self.parser, input_args=[data_path]
        )

        utils.import_user_module(self.args)

        # Same sanity adjustments as fairseq_cli/interactive.py.
        if self.args.buffer_size < 1:
            self.args.buffer_size = 1
        if self.args.max_tokens is None and self.args.max_sentences is None:
            self.args.max_sentences = 1

        assert not self.args.sampling or self.args.nbest == self.args.beam, \
            '--sampling requires --nbest to be equal to --beam'
        assert not self.args.max_sentences or \
            self.args.max_sentences <= self.args.buffer_size, \
            '--max-sentences/--batch-size cannot be larger than --buffer-size'

        self.use_cuda = torch.cuda.is_available() and not self.args.cpu

        self.task = tasks.setup_task(self.args)

        # NOTE(security): ``eval`` on --model-overrides mirrors upstream
        # fairseq, but never pass untrusted strings through this option.
        self.models, self._model_args = checkpoint_utils.load_model_ensemble(
            self.args.path.split(':'),
            arg_overrides=eval(self.args.model_overrides),
            task=self.task,
        )

        self.src_dict = self.task.source_dictionary
        self.tgt_dict = self.task.target_dictionary

        # Optimize each ensemble member for inference.
        for model in self.models:
            model.make_generation_fast_(
                beamable_mm_beam_size=None if self.args.no_beamable_mm else self.args.beam,
                need_attn=self.args.print_alignment,
            )
            if self.args.fp16:
                model.half()
            if self.use_cuda:
                model.cuda()

        self.generator = self.task.build_generator(self.models, self.args)

        if self.args.remove_bpe == 'gpt2':
            from fairseq.gpt2_bpe.gpt2_encoding import get_encoder
            self.decoder = get_encoder(
                'fairseq/gpt2_bpe/encoder.json',
                'fairseq/gpt2_bpe/vocab.bpe',
            )
            self.encode_fn = lambda x: ' '.join(map(str, self.decoder.encode(x)))
        else:
            self.decoder = None
            self.encode_fn = lambda x: x

        self.align_dict = utils.load_align_dict(self.args.replace_unk)

        self.max_positions = utils.resolve_max_positions(
            self.task.max_positions(),
            *[model.max_positions() for model in self.models]
        )

    def generate(self, string):
        """Translate one source string and return the best hypothesis.

        Args:
            string: a single (pre-tokenized) source sentence.

        Returns:
            The detokenized best hypothesis string, or ``None`` when the
            input produced no hypotheses.
        """
        start_id = 0
        inputs = [string]
        results = []
        for batch in make_batches(inputs, self.args, self.task,
                                  self.max_positions, self.encode_fn):
            src_tokens = batch.src_tokens
            src_lengths = batch.src_lengths
            if self.use_cuda:
                src_tokens = src_tokens.cuda()
                src_lengths = src_lengths.cuda()

            sample = {
                'net_input': {
                    'src_tokens': src_tokens,
                    'src_lengths': src_lengths,
                },
            }
            translations = self.task.inference_step(
                self.generator, self.models, sample
            )
            for i, (sample_id, hypos) in enumerate(
                    zip(batch.ids.tolist(), translations)):
                src_tokens_i = utils.strip_pad(src_tokens[i], self.tgt_dict.pad())
                results.append((start_id + sample_id, src_tokens_i, hypos))

        # Process results in input order (single input here, but keep the
        # upstream structure); return as soon as the best hypothesis of the
        # first example is post-processed.
        for sample_id, src_tokens, hypos in sorted(results, key=lambda x: x[0]):
            # ROBUSTNESS: initialize so post-processing cannot hit an
            # unbound local if the task has no source dictionary.
            src_str = None
            if self.src_dict is not None:
                src_str = self.src_dict.string(src_tokens, self.args.remove_bpe)

            for hypo in hypos[:min(len(hypos), self.args.nbest)]:
                hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
                    hypo_tokens=hypo['tokens'].int().cpu(),
                    src_str=src_str,
                    alignment=hypo['alignment'].int().cpu()
                    if hypo['alignment'] is not None else None,
                    align_dict=self.align_dict,
                    tgt_dict=self.tgt_dict,
                    remove_bpe=self.args.remove_bpe,
                )
                if self.decoder is not None:
                    hypo_str = self.decoder.decode(
                        map(int, hypo_str.strip().split())
                    )
                return hypo_str
如何使用
gen = Generator("/path/to/data.src_trg", "/path/to/checkpoint_best.pt")
gen.generate("分か ち 書き し た 文 章")
> 生 成 され た 文 章
参考
下面参考的实现不能原样在版本0.9.0中使用。
https://github.com/sharad461/nepali-translator/blob/master/translator/app/modules/interactive.py
原始的fairseq-interactive
https://github.com/pytorch/fairseq/blob/master/fairseq_cli/interactive.py