@@ -1514,195 +1514,3 @@ def get_simplified_key(
15141514 next_power_of_2 (page_size ),
15151515 next_power_of_2 (max_num_tokens ),
15161516 )
1517-
1518-
# Auto-tuned `num_kv_pages_per_block` values for the paged-attention decode
# kernel, keyed per device.
#
# Per-device key:
#   (q dtype name, kv dtype name, q head count, kv head count,
#    head dim, page size)
# Value:
#   num_kv_pages_per_block
#
# NOTE: the tuned entries for "TPU v6e" and "TPU v7" are currently identical,
# so the shared table is written once below and each device gets its own copy.
_SHARED_TUNED_DECODE_TABLE = {
    # page_size == 128
    ("bfloat16", "bfloat16", 1, 1, 128, 128): 256,
    ("bfloat16", "bfloat16", 2, 1, 128, 128): 256,
    ("bfloat16", "bfloat16", 4, 1, 128, 128): 256,
    ("bfloat16", "bfloat16", 8, 1, 128, 128): 256,
    ("bfloat16", "bfloat16", 16, 1, 128, 128): 256,
    ("bfloat16", "bfloat16", 2, 2, 128, 128): 128,
    ("bfloat16", "bfloat16", 4, 2, 128, 128): 128,
    ("bfloat16", "bfloat16", 8, 2, 128, 128): 128,
    ("bfloat16", "bfloat16", 16, 2, 128, 128): 128,
    ("bfloat16", "bfloat16", 32, 2, 128, 128): 128,
    ("bfloat16", "bfloat16", 4, 4, 128, 128): 64,
    ("bfloat16", "bfloat16", 8, 4, 128, 128): 64,
    ("bfloat16", "bfloat16", 16, 4, 128, 128): 64,
    ("bfloat16", "bfloat16", 32, 4, 128, 128): 64,
    ("bfloat16", "bfloat16", 8, 8, 128, 128): 32,
    ("bfloat16", "bfloat16", 16, 8, 128, 128): 32,
    ("bfloat16", "bfloat16", 32, 8, 128, 128): 32,
    ("bfloat16", "bfloat16", 64, 8, 128, 128): 32,
    ("bfloat16", "bfloat16", 16, 16, 128, 128): 16,
    ("bfloat16", "bfloat16", 32, 16, 128, 128): 16,
    ("bfloat16", "bfloat16", 64, 16, 128, 128): 16,
    ("bfloat16", "bfloat16", 128, 16, 128, 128): 16,
    # page_size == 256
    ("bfloat16", "bfloat16", 1, 1, 128, 256): 128,
    ("bfloat16", "bfloat16", 2, 1, 128, 256): 128,
    ("bfloat16", "bfloat16", 4, 1, 128, 256): 128,
    ("bfloat16", "bfloat16", 8, 1, 128, 256): 128,
    ("bfloat16", "bfloat16", 16, 1, 128, 256): 128,
    ("bfloat16", "bfloat16", 2, 2, 128, 256): 64,
    ("bfloat16", "bfloat16", 4, 2, 128, 256): 64,
    ("bfloat16", "bfloat16", 8, 2, 128, 256): 64,
    ("bfloat16", "bfloat16", 16, 2, 128, 256): 64,
    ("bfloat16", "bfloat16", 4, 4, 128, 256): 32,
    ("bfloat16", "bfloat16", 8, 4, 128, 256): 32,
    ("bfloat16", "bfloat16", 16, 4, 128, 256): 32,
    ("bfloat16", "bfloat16", 32, 4, 128, 256): 32,
    ("bfloat16", "bfloat16", 8, 8, 128, 256): 16,
    ("bfloat16", "bfloat16", 16, 8, 128, 256): 16,
    ("bfloat16", "bfloat16", 32, 8, 128, 256): 16,
    ("bfloat16", "bfloat16", 64, 8, 128, 256): 16,
    ("bfloat16", "bfloat16", 16, 16, 128, 256): 8,
    ("bfloat16", "bfloat16", 32, 16, 128, 256): 8,
    ("bfloat16", "bfloat16", 64, 16, 128, 256): 8,
    ("bfloat16", "bfloat16", 128, 16, 128, 256): 8,
    ("bfloat16", "bfloat16", 256, 16, 128, 256): 8,
    ("bfloat16", "bfloat16", 512, 16, 128, 256): 8,
}

TUNED_KV_PAGES_FOR_DECODE = {
    "TPU v6e": dict(_SHARED_TUNED_DECODE_TABLE),
    "TPU v7": dict(_SHARED_TUNED_DECODE_TABLE),
}
1625-
1626-
def get_kv_pages_for_decode(
    q_dtype,
    kv_dtype,
    actual_num_q_heads,
    actual_num_kv_heads,
    head_dim,
    page_size,
    pages_per_seq,
    causal=True,
) -> int:
    """Look up the best num_kv_pages_per_blk from the auto-tuned table.

    Fix: the original placed this docstring as a dead string expression after
    the non-causal early return; it is now a real docstring.

    Args:
        q_dtype: query dtype (anything accepted by `jnp.dtype`).
        kv_dtype: key/value dtype.
        actual_num_q_heads: number of query heads before any rounding.
        actual_num_kv_heads: number of KV heads before any rounding.
        head_dim: per-head dimension.
        page_size: number of tokens per KV page.
        pages_per_seq: maximum number of KV pages per sequence.
        causal: whether the attention mask is causal.

    Returns:
        The number of KV pages per kernel block, capped at `pages_per_seq`.

    Raises:
        NotImplementedError: if running on a TPU older than v4.
    """
    if not causal:
        # FIXME(pc): hack to avoid OOM during precompile; we currently have no
        # better choice for the non-causal mask. This should be optimized in
        # the future.
        return 4

    tpu_version = get_tpu_version()
    if tpu_version < 4:
        raise NotImplementedError("TPU version must be 4 or higher.")

    keys = get_simplified_key_for_decode(
        page_size,
        q_dtype,
        kv_dtype,
        actual_num_q_heads,
        actual_num_kv_heads,
        head_dim,
    )
    device_name = keys[0]

    # Default block sizes.
    bkv_p = 1024 // page_size
    if tpu_version == 4:
        # TPUv4 has much smaller VMEM size so we pick fixed block sizes.
        bkv_p = 512 // page_size
    else:
        if (
            device_name in TUNED_KV_PAGES_FOR_DECODE
            and keys[1:] in TUNED_KV_PAGES_FOR_DECODE[device_name]
        ):
            bkv_p = TUNED_KV_PAGES_FOR_DECODE[device_name][keys[1:]]
        else:
            logger.info(
                "Tuned RPA kv page not found for %s: page_size=%s, actual_num_q_heads=%s, "
                "actual_num_kv_heads=%s, head_dim=%s, pages_per_seq=%s.",
                device_name,
                page_size,
                actual_num_q_heads,
                actual_num_kv_heads,
                head_dim,
                pages_per_seq,
            )
            logger.info("Using default block size: bkv_p=%s.", bkv_p)

    # Never request more pages per block than the sequence can have.
    return min(pages_per_seq, bkv_p)
1682-
1683-
def get_simplified_key_for_decode(
    page_size,
    q_dtype,
    kv_dtype,
    num_q_heads,
    num_kv_heads,
    head_dim,
):
    """Get the simplified key to reduce the number of combinations."""
    assert num_q_heads % num_kv_heads == 0
    # Round the head counts and page size up to powers of two, and head_dim up
    # to the next multiple of 128, so nearby configurations share one entry.
    rounded_head_dim = -(-head_dim // 128) * 128
    return (
        get_device_name(),
        jnp.dtype(q_dtype).name,
        jnp.dtype(kv_dtype).name,
        next_power_of_2(num_q_heads),
        next_power_of_2(num_kv_heads),
        rounded_head_dim,
        next_power_of_2(page_size),
    )
0 commit comments