I noticed that when input_ids is generated, the prompt does not include tokens such as [SEG]. Will this prevent seg_token_idx from being used to locate the index of the [SEG] token?
code:
if isinstance(self.seg_token_idx, list):
    seg_token_num = self.seg_token_num
    seg_token_mask = torch.zeros_like(input_ids[:, 1:]).bool()
    for seg_token_idx in self.seg_token_idx:
        seg_token_mask = seg_token_mask | (input_ids[:, 1:] == seg_token_idx)
else:
    seg_token_num = self.seg_token_num
    seg_token_mask = input_ids[:, 1:] == self.seg_token_idx
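
To make the concern concrete, here is a minimal standalone sketch of the same masking logic. It is not the repository's code: the function name build_seg_token_mask, the token id 32000, and the example sequences are all made up for illustration. It only shows that the mask can be True where the [SEG] id actually occurs in input_ids, so a sequence without that id yields an all-False mask.

```python
import torch


def build_seg_token_mask(input_ids, seg_token_idx):
    """Hypothetical helper mirroring the masking logic in the snippet above."""
    shifted = input_ids[:, 1:]  # drop the first position, as the original code does
    if isinstance(seg_token_idx, list):
        # Several candidate ids: OR together one equality mask per id.
        mask = torch.zeros_like(shifted).bool()
        for idx in seg_token_idx:
            mask = mask | (shifted == idx)
        return mask
    # Single id: a plain element-wise comparison is enough.
    return shifted == seg_token_idx


SEG_ID = 32000  # assumed id for [SEG]; the real value depends on the tokenizer

with_seg = torch.tensor([[1, 15, 27, SEG_ID, 2]])  # sequence containing [SEG]
without_seg = torch.tensor([[1, 15, 27, 42, 2]])   # sequence with no [SEG]

mask_with = build_seg_token_mask(with_seg, SEG_ID)
mask_without = build_seg_token_mask(without_seg, SEG_ID)
mask_list = build_seg_token_mask(with_seg, [SEG_ID, 32001])

print(mask_with)           # one True at the [SEG] position
print(mask_without.any())  # no position matches
```

So whether the mask is empty comes down to whether the [SEG] id is present in input_ids at the point this code runs, which is what I am trying to confirm.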