Commit c2d554d
Fix issue #368
michaelbenayoun committed Dec 14, 2023
1 parent 88e589f commit c2d554d
Showing 1 changed file with 7 additions and 1 deletion.
optimum/neuron/trainers.py: 8 changes (7 additions & 1 deletion)
@@ -15,6 +15,7 @@
"""Defines Trainer subclasses to perform training on AWS Neuron instances."""

import contextlib
import copy
import glob
import os
import random
@@ -395,7 +396,12 @@ def _save_xla(self, output_dir: Optional[str] = None):
         if self.accelerator.distributed_type is NeuronDistributedType.TENSOR_PARALLELISM:
             logger.info("Model parallelism is enabled, only saving the model sharded state dict.")
             if isinstance(self.model, PreTrainedModel):
-                self.model.config.save_pretrained(output_dir)
+                from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size
+
+                config = copy.deepcopy(self.model.config)
+                if self.args.tp_plugin.parallelize_embeddings:
+                    config.vocab_size = config.vocab_size * get_tensor_model_parallel_size()
+                config.save_pretrained(output_dir)
 
             parallelizer = ParallelizersManager.parallelizer_for_model(self.model)
             # This mark_step is needed to avoid hang issues.
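What the change does: previously the live self.model.config was saved directly. The multiplication above suggests that, when parallelize_embeddings is enabled, the in-memory config carries the per-rank (sharded) vocabulary size, so saving it as-is would record a wrong vocab_size in the checkpoint, which is plausibly the bug behind issue #368. The commit mutates a deep copy instead, so the runtime config is never altered. Below is a minimal sketch of that deep-copy behavior; the vocab size of 32000 and tensor-parallel degree of 8 are stand-in values, not taken from the commit:

    import copy

    from transformers import PretrainedConfig

    tp_size = 8  # stand-in for get_tensor_model_parallel_size()
    config = PretrainedConfig(vocab_size=32000)  # hypothetical per-rank vocab size

    # Mutate a deep copy, as the commit does, so the live config is unaffected.
    saved_config = copy.deepcopy(config)
    saved_config.vocab_size = saved_config.vocab_size * tp_size

    assert config.vocab_size == 32000         # runtime config unchanged
    assert saved_config.vocab_size == 256000  # checkpoint records the full vocab size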
