@@ -1524,3 +1524,112 @@ def test_causal_lm_training_multi_gpu(self):
 
         # assert loss is not None
         assert trainer.state.log_history[-1]["train_loss"] is not None
+
+
+PRECISIONS = [torch.float32, torch.float16, torch.bfloat16]
+
+LORA_PARAMS = {
+    "r": 8,
+    "lora_alpha": 16,
+    "lora_dropout": 0.05,
+}
+
+
+class SimpleModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.linear_transform = torch.nn.Linear(768, 256)
+
+    def forward(self, input_ids):
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+        linear_output = self.linear_transform(norm_output)
+
+        return linear_output
+
+
+class SimpleConv2DModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        self.embedding_layer = torch.nn.Embedding(1000, 768)
+        self.layer_norm = torch.nn.LayerNorm(768)
+        self.conv2d_transform = torch.nn.Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+
+    def forward(self, input_ids):
+        # Embed and normalize the input ids
+        embedded_output = self.embedding_layer(input_ids)
+        norm_output = self.layer_norm(embedded_output)
+
+        # Reshape for Conv2d input (add a channel dimension)
+        norm_output = norm_output.unsqueeze(1)
+        conv_output = self.conv2d_transform(norm_output)
+
+        # Remove the channel dimension
+        conv_output = conv_output.squeeze(1)
+
+        return conv_output
+
+
+@require_torch_gpu
+class TestAutoCast(unittest.TestCase):
+    # These tests make sure that LoRA dtypes are consistent with the dtypes
+    # inferred by torch.autocast under the tested PRECISIONS
+    @parameterized.expand(PRECISIONS)
+    def test_simple_model(self, *args, **kwargs):
+        self._test_model(SimpleModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_linear_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["linear_transform"],
+        )
+
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_embedding_model(self, *args, **kwargs):
+        simple_model = SimpleModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["embedding_layer"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+
+        self._test_model(lora_model, *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_conv2d_model(self, *args, **kwargs):
+        self._test_model(SimpleConv2DModel(), *args, **kwargs)
+
+    @parameterized.expand(PRECISIONS)
+    def test_simple_lora_conv2d_model(self, *args, **kwargs):
+        simple_model = SimpleConv2DModel()
+        config = LoraConfig(
+            **LORA_PARAMS,
+            target_modules=["conv2d_transform"],
+        )
+        lora_model = get_peft_model(simple_model, config)
+        self._test_model(lora_model, *args, **kwargs)
+
+    def _test_model(self, model, precision):
+        # Move model to GPU
+        model = model.cuda()
+
+        # Prepare dummy inputs
+        input_ids = torch.randint(0, 1000, (2, 10)).cuda()
+        if precision == torch.bfloat16:
+            if not torch.cuda.is_bf16_supported():
+                self.skipTest("Bfloat16 not supported on this device")
+
+        # Forward pass with test precision
+        with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
+            outputs = model(input_ids)
+            assert outputs.dtype == precision
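
For reference, here is a minimal standalone sketch of the dtype behavior these tests assert. It is a hypothetical snippet, not part of the commit: the Sequential model is a stand-in that mirrors SimpleModel's embedding -> norm -> linear flow, and a CUDA device is assumed. Under autocast, Linear runs in the autocast dtype, so the final output dtype follows the requested precision (with float32, weights are simply left uncast, which is why the test above also passes for it).

import torch

# Hypothetical stand-in mirroring SimpleModel's embedding -> norm -> linear flow.
model = torch.nn.Sequential(
    torch.nn.Embedding(1000, 768),
    torch.nn.LayerNorm(768),
    torch.nn.Linear(768, 256),
).cuda()

input_ids = torch.randint(0, 1000, (2, 10)).cuda()

for precision in (torch.float16, torch.bfloat16):
    if precision == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        continue  # skip bf16 on GPUs without support, as _test_model does
    with torch.autocast(enabled=True, dtype=precision, device_type="cuda"):
        outputs = model(input_ids)
        # Linear is autocast-eligible, so its output carries the autocast dtype.
        assert outputs.dtype == precision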
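
Likewise, a hedged sketch of what the target_modules settings above do: get_peft_model wraps only the named submodules with LoRA adapters and freezes everything else. TinyModel below is a hypothetical stand-in for SimpleModel, and peft is assumed to be installed.

import torch
from peft import LoraConfig, get_peft_model

class TinyModel(torch.nn.Module):  # hypothetical stand-in for SimpleModel
    def __init__(self):
        super().__init__()
        self.embedding_layer = torch.nn.Embedding(1000, 768)
        self.linear_transform = torch.nn.Linear(768, 256)

    def forward(self, input_ids):
        return self.linear_transform(self.embedding_layer(input_ids))

config = LoraConfig(r=8, lora_alpha=16, lora_dropout=0.05, target_modules=["linear_transform"])
lora_model = get_peft_model(TinyModel(), config)

# Only linear_transform is replaced by a LoRA layer carrying lora_A/lora_B;
# with r=8 that adds 8*768 + 256*8 = 8,192 trainable parameters.
print(lora_model.base_model.model.linear_transform)
lora_model.print_trainable_parameters()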