diff --git a/.gitignore b/.gitignore
index 898feb8f4f..6890911c14 100644
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,4 @@ develop-eggs/
 dist/
 downloads/
 .pytest_cache/
+compile_commands.json
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 5c1410e3e8..0b381258a4 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -126,12 +126,9 @@ def forward(
 
         # Check scale
         if scale is None and fp8_meta is None:
-            scale = 1
+            scale = torch.full([1], 1, dtype=torch.float32, device=device)
         if scale is not None:
-            if isinstance(scale, torch.Tensor):
-                scale = scale.to(device=device, dtype=torch.float32)
-            else:
-                scale = torch.full([1], scale, dtype=torch.float32, device=device)
+            scale = scale.to(device=device, dtype=torch.float32)
 
         # Check scale-inverse
         if scale_inv is None:
@@ -335,6 +332,18 @@ class Float8Tensor(QuantizedTensor):
 
     """
 
+    _data: torch.Tensor
+    _fp8_attrs: Dict[str, Any]
+    _fp8_meta: Optional[Dict[str, Any]]
+    _fp8_meta_forward: bool
+    _fp8_meta_index: Optional[int]
+    _fp8_dtype: TE_DType
+    _scale_inv: torch.Tensor
+
+    # FP8 transpose cache
+    _transpose: Optional[torch.Tensor]
+    _transpose_invalid: bool
+
     def __new__(
         cls,
         *,
@@ -371,13 +380,12 @@ def __new__(
             requires_grad=requires_grad,
             device=data.device,
         )
-        self._data: torch.Tensor = data
+        self._data = data
 
         # Initialize dict of class attributes
        # Note: We store FP8 attributes in a dictionary so we can
        # share them between tensors with the same data, e.g. detached
        # tensors.
-        self._fp8_attrs: dict
         if fp8_attrs is None:
             self._fp8_attrs = {}
         else:
@@ -390,16 +398,16 @@ def __new__(
                     "To initialize Float8Tensor with FP8 meta tensors, "
                     "the FP8 meta tensor index must also be provided"
                 )
-        self._fp8_meta: Optional[Dict[str, Any]] = fp8_meta
-        self._fp8_meta_forward: bool = fp8_meta_forward
-        self._fp8_meta_index: Optional[int] = fp8_meta_index
+        self._fp8_meta = fp8_meta
+        self._fp8_meta_forward = fp8_meta_forward
+        self._fp8_meta_index = fp8_meta_index
 
         # FP8 dtype
         assert fp8_dtype in (
             TE_DType.kFloat8E4M3,
             TE_DType.kFloat8E5M2,
         ), f"Unsupported fp8_dtype {fp8_dtype}."
-        self._fp8_dtype: TE_DType = fp8_dtype
+        self._fp8_dtype = fp8_dtype
 
         # FP8 scale-inverse
         if fp8_scale_inv is None and self._fp8_meta is not None:
@@ -412,13 +420,6 @@ def __new__(
             raise ValueError(
                 "Attempted to initialize Float8Tensor without specifying scale-inverse"
             )
-        if not isinstance(fp8_scale_inv, torch.Tensor):
-            fp8_scale_inv = torch.full(
-                [1],
-                fp8_scale_inv,
-                dtype=torch.float32,
-                device=self._data.device,
-            )
         if fp8_scale_inv.numel() != 1:
             raise ValueError(
                 "Attempted to initialize Float8Tensor with invalid scale-inverse tensor"
@@ -433,11 +434,11 @@ def __new__(
                 device=self._data.device,
                 dtype=torch.float32,
             )
-        self._scale_inv: Optional[torch.Tensor] = fp8_scale_inv
+        self._scale_inv = fp8_scale_inv
 
         # FP8 transpose cache
-        self._transpose: Optional[Float8Tensor] = data_transpose
-        self._transpose_invalid: bool = self._transpose is None
+        self._transpose = data_transpose
+        self._transpose_invalid = self._transpose is None
 
         return self
 
@@ -477,7 +478,7 @@ def __repr__(self):
             ")"
         )
 
-    def dequantize(self, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
 
         # Convert PyTorch dtype to TE dtype
         if dtype is None:
@@ -603,11 +604,8 @@ def quantize_(
 
         # Make sure FP8 scaling factors are in expected format
         if scale is not None:
-            if isinstance(scale, torch.Tensor):
-                if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
-                    scale = scale.to(device=dst.device, dtype=torch.float32)
-            else:
-                scale = torch.full([1], scale, dtype=torch.float32, device=dst.device)
+            if not devices_match(scale.device, dst.device) or scale.dtype != torch.float32:
+                scale = scale.to(device=dst.device, dtype=torch.float32)
         if amax is not None:
             while amax.dim() < 2:
                 amax = amax.unsqueeze(0)
@@ -781,23 +779,21 @@ def transpose_2d(
             fill_cache = False
 
         # Need to compute transpose if cache is invalid
-        need_compute = force_compute
-        if self._transpose is None:
-            need_compute = True
-        elif self._transpose_invalid:
-            need_compute = True
-
-        # Need to apply transpose kernel if noop flag is applied
-        if noop_flag is not None:
-            need_compute = True
+        need_compute = (
+            force_compute
+            or (self._transpose is None)
+            or self._transpose_invalid
+            or (noop_flag is not None)
+        )
 
         # Return cached transpose if possible
         if not need_compute:
+            assert self._transpose is not None
             return self._transpose
 
         # Allocate output if needed
         data = self._data.contiguous().reshape(-1, self.size(-1))
-        out = self._transpose
+        out: Optional[torch.Tensor] = self._transpose
         if out is None:
             out = torch.empty(
                 (data.size(1), data.size(0)),
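
The common thread in the `forward`, `__new__`, and `quantize_` hunks is that `scale` and `fp8_scale_inv` are now expected to already be `torch.Tensor`s; the float-to-tensor coercion is removed from the Float8Tensor internals. A minimal sketch of what a call site that still holds Python-float scales might do instead (the helper `as_scale_tensor` is hypothetical, not part of this PR; it just reproduces the removed coercion):

    import torch

    def as_scale_tensor(scale, device) -> torch.Tensor:
        # Hypothetical call-site helper: floats become 1-element float32
        # tensors on the target device; tensors are moved/cast as needed.
        # This mirrors the coercion that this PR deletes from Float8Tensor.
        if not isinstance(scale, torch.Tensor):
            return torch.full([1], scale, dtype=torch.float32, device=device)
        return scale.to(device=device, dtype=torch.float32)

    scale = as_scale_tensor(0.5, device="cpu")  # shape [1], float32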
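The `dequantize` hunk makes `dtype` keyword-only by inserting a bare `*` in the signature, so positional calls now raise `TypeError`. A sketch of the caller-side difference, assuming `t` is an existing `Float8Tensor`:

    # Before this PR, a positional dtype was accepted:
    #     hi = t.dequantize(torch.float32)
    # After this PR, dtype must be passed by keyword:
    hi = t.dequantize(dtype=torch.float32)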
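The block of class-body annotations added to `Float8Tensor` declares each instance attribute once for static analysis, which is why the per-assignment annotations in `__new__` (e.g. `self._data: torch.Tensor = data`) can be dropped. A self-contained illustration of the idiom (the `Toy` class is invented for this example):

    from typing import Optional

    class Toy:
        # PEP 526 class-body annotations: visible to type checkers
        # (mypy, pyright), recorded in __annotations__, but they create
        # no class attributes and assign nothing at runtime.
        _data: list
        _cache: Optional[list]

        def __init__(self, data: list):
            self._data = data   # checker already knows the declared type
            self._cache = None

    assert not hasattr(Toy, "_data")  # annotation alone creates no attribute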
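In the `transpose_2d` hunk, the added `assert self._transpose is not None` is the standard type-narrowing idiom: with `_transpose` now annotated `Optional[torch.Tensor]`, the assert both documents the invariant (a cached transpose exists whenever `need_compute` is false) and narrows the type for static checkers before the cached value is returned. A small standalone illustration under that reading (names invented):

    from typing import Optional

    def cached_or_recompute(cache: Optional[str], cache_invalid: bool) -> str:
        # Collapse the separate checks into one boolean, as the PR does.
        need_compute = (cache is None) or cache_invalid
        if not need_compute:
            assert cache is not None  # narrows Optional[str] -> str
            return cache
        return "recomputed"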