From c85bd6224c46f23eaed45af2371a36a05f187e6a Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Thu, 24 Oct 2024 00:45:32 +0000
Subject: [PATCH 1/8] feat: automatic plugin feature

---
 examples/dynamo/automatic_plugin/custom_op.py |  93 +++++++++++
 .../conversion/plugin/plugin_generator.py     | 151 ++++++++++++++++++
 .../conversion/plugin_ops_converters.py       |  47 ++++++
 3 files changed, 291 insertions(+)
 create mode 100644 examples/dynamo/automatic_plugin/custom_op.py
 create mode 100644 py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py
 create mode 100644 py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py

diff --git a/examples/dynamo/automatic_plugin/custom_op.py b/examples/dynamo/automatic_plugin/custom_op.py
new file mode 100644
index 0000000000..043c75d1e6
--- /dev/null
+++ b/examples/dynamo/automatic_plugin/custom_op.py
@@ -0,0 +1,93 @@
+import triton
+import triton.language as tl
+
+@triton.jit
+def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr):
+    # Program ID determines the block of data this program instance will process
+    pid = tl.program_id(0)
+    # Compute the range of elements that this program instance will work on
+    block_start = pid * BLOCK_SIZE
+    # Range of indices this program instance will handle
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    # Load elements from the X and Y tensors
+    x_vals = tl.load(X + offsets)
+    y_vals = tl.load(Y + offsets)
+    # Perform the element-wise addition
+    z_vals = x_vals + y_vals
+    # Store the result in Z
+    tl.store(Z + offsets, z_vals)
+
+
+import torch
+from torch.library import custom_op
+
+
+@custom_op("torchtrt_ex::elementwise_add", mutates_args=())  # type: ignore[misc]
+def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
+    # Ensure the tensors are on the GPU
+    assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device."
+    assert X.shape == Y.shape, "Tensors must have the same shape."
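+    # NOTE: the launch below assumes X.numel() is an exact multiple of
+    # BLOCK_SIZE; a more defensive kernel would size the grid with
+    # triton.cdiv(X.numel(), BLOCK_SIZE) and pass a mask such as
+    # offsets < n_elements to tl.load/tl.store to guard the tail elements.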
+ + # Create output tensor + Z = torch.empty_like(X) + + # Define block size + BLOCK_SIZE = 1024 + + # Grid of programs + grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],) + + # Launch the kernel + elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) + + return Z + + +# Using the module in PyTorch +# X = torch.randn(1024, device='cuda', requires_grad=True) +# Y = torch.randn(1024, device='cuda', requires_grad=True) +# X = torch.full((128, 128), 2, device='cuda',) +# Y = torch.full((128, 128), 2, device='cuda',) +# # elementwise_mul_op = ElementwiseMulModule() +# Z = torch.ops.torchtrt_ex.elementwise_add(X, Y) +# print(Z) +# print(X + Y) +# print(X) +# print(Y) +# print(Z) +# print(X+Y) +# Z.sum().backward() + + +from torch import nn + + +class MyModel(nn.Module): # type: ignore[misc] + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + z = torch.mul(x, y) + res = torch.ops.torchtrt_ex.elementwise_add(x, z) + + return res + + +my_model = MyModel().to("cuda") +m = torch.full((64, 64), 2, device='cuda',) +n = torch.full((64, 64), 3, device='cuda',) +# print(torch.ops.torchtrt_ex.elementwise_add(m, n)) +# print(my_model.forward(m, n)) + + +@torch.library.register_fake("torchtrt_ex::elementwise_add") +def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + +import torch_tensorrt as torchtrt + + +with torchtrt.logging.info(): + model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1) + res = model_trt(m, n) + print(res) \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py new file mode 100644 index 0000000000..9319b9d045 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -0,0 +1,151 @@ +import tensorrt as trt + + + +class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc] + def __init__( + self, plugin_name : str, fc = None, phase = None + ): + # TODO: needs an additional passed in arguments to specify the needs for each plugin + # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83 + trt.IPluginV3.__init__(self) + trt.IPluginV3OneCore.__init__(self) + trt.IPluginV3OneBuild.__init__(self) + trt.IPluginV3OneRuntime.__init__(self) + + # + # setattr(, ) + # self.pads = [] + # self.X_shape: List[int] = [] + + self.num_outputs = 1 # Defined by schema + self.plugin_namespace = "" + self.plugin_name = plugin_name + self.plugin_version = "1" + + # + # ex. 
+ # TODO: need to parse the field collection here + # if fc is not None: + # assert fc[0].name == "pads" + # self.pads = fc[0].data + + if phase is not None: + self.phase = phase + + def get_capability_interface(self, type): + return self + + def get_output_datatypes( + self, input_types: List[trt.DataType] + ) -> trt.DataType: + # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE + # with torch.fake_tensor(): + # + # fake_outputs = torch.ops..(*fake_inputs) + + # return fake_outputs[index] + + # The example case here is simple for experiment + return [input_types[0]] + + def get_output_shapes( + self, + output_index: int, + inputs: List[trt.DimsExprs], + exprBuilder: trt.IExprBuilder, + ) -> trt.DimsExprs: + + + # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR + # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE + # SHAPE MAP. + output_shape = trt.DimsExprs(inputs[0]) + + return [output_shape] + + def get_fields_to_serialize(self): + # should be passed in as another argument + return trt.PluginFieldCollection([ + trt.PluginField("pads", self.pads, trt.PluginFieldType.INT32) + ]) + + def configure_plugin(self, inp, out): + pass + + def on_shape_change(self, inp, out): + X_dims = inp[0].dims + self.X_shape = np.zeros((len(X_dims),)) + for i in range(len(X_dims)): + self.X_shape[i] = X_dims[i] + + def supports_format_combination(self, pos, in_out, num_inputs): + assert num_inputs == 1 + assert pos < len(in_out) + + desc = in_out[pos].desc + if desc.format != trt.TensorFormat.LINEAR: + return False + + # first input should be float16 or float32 + if pos == 0: + return desc.type == trt.DataType.FLOAT or desc.type == trt.DataType.HALF + + # output should have the same type as the input + if pos == 1: + return in_out[0].desc.type == desc.type + + assert False + + + def enqueue( + self, + input_desc: List[trt.PluginTensorDesc], + output_desc: List[trt.PluginTensorDesc], + inputs: List[int], + outputs: List[int], + workspace: int, + stream: int, + ) -> None: + ... + + def attach_to_context(self, context): + return self.clone() + + def get_valid_tactics(self): + return [int(Tactic.TORCH), int(Tactic.TRITON)] + + def set_tactic(self, tactic): + self.tactic = Tactic(tactic) + + if self.phase == trt.TensorRTPhase.RUNTIME: + logger.info(f"Best tactic chosen: {self.tactic}") + + def clone(self) -> Self: + # + + +class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] + def __init__(self, plugin_name : str, plugin_field_names : trt.PluginFieldCollection): + super().__init__() + + self.name = plugin_name + self.plugin_namespace = "" + self.plugin_version = "1" + self.field_names = plugin_field_names + + def create_plugin( + self, name: str, field_collection: trt.PluginFieldCollection_ + ) -> CustomPlugin: + return CustomPlugin(field_collection) + + +# Looks like deserilaize required? 
Not found in the example here: https://github.com/NVIDIA/TensorRT/blob/main/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py
    # def deserialize_plugin(self, name: str, data: bytes) -> CircularPaddingPlugin:
    #     dict = pkl.loads(data)
    #     deserialized = ()
    #     deserialized.__dict__.update(dict)
    #     return deserialized

TRT_PLUGIN_REGISTRY = trt.get_plugin_registry()
TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "")
\ No newline at end of file
diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py
new file mode 100644
index 0000000000..b9136d492d
--- /dev/null
+++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py
@@ -0,0 +1,47 @@
+import logging
+from typing import Dict, Sequence, Tuple, Union
+
+import torch
+from torch.fx.node import Argument, Target
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo._SourceIR import SourceIR
+from torch_tensorrt.dynamo.conversion import impl
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    dynamo_tensorrt_converter,
+)
+from torch_tensorrt.fx.types import TRTTensor

+logger = logging.getLogger(__name__)
+
+@dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default)
+def torchtrt_ex_elementwise_add(
+    ctx: ConversionContext,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+):
+    logger.debug("plugin stuff here2")
+    return torch.add(*args)
+
+    # How to retrieve a plugin if it is defined elsewhere (e.g. linked library)
+    # plugin_registry = trt.get_plugin_registry()
+    # plugin_creator = plugin_registry.get_plugin_creator(
+    #     type="", version="1", plugin_namespace=""
+    # )
+    # assert plugin_creator, f"Unable to find creator"
+
+    # # Pass configurations to the plugin implementation
+    # field_configs = 
+    # plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs)
+    # assert plugin, "Unable to create "
+
+    # 
+    # 
+    # 
+
+    # return layer.get_output(0)
+
+
+# 1. generate plugin for any pytorch op
\ No newline at end of file
From eae249954e4ab4af296550aa5e834c7361cbe373 Mon Sep 17 00:00:00 2001
From: Bo Wang
Date: Sat, 2 Nov 2024 02:16:37 +0000
Subject: [PATCH 2/8] update

---
 .../dynamo/conversion/__init__.py             |   2 +-
 .../conversion/plugin/plugin_generator.py     | 131 +++++++++++++++---
 2 files changed, 116 insertions(+), 17 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py
index 5351f02bb6..235d1456b0 100644
--- a/py/torch_tensorrt/dynamo/conversion/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/__init__.py
@@ -1,4 +1,4 @@
-from . import aten_ops_converters, ops_evaluators, prims_ops_converters
+from . 
import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters from ._conversion import convert_module, interpret_module_to_result from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 9319b9d045..21468ca9be 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -1,19 +1,55 @@ import tensorrt as trt +import cupy as cp +import torch +import numpy as np +import logging +from enum import IntEnum + +logger = logging.getLogger("CustomPlugin") + + + +_numpy_to_plugin_field_type = { + np.dtype('int32'): trt.PluginFieldType.INT32, + np.dtype('int16'): trt.PluginFieldType.INT16, + np.dtype('int8'): trt.PluginFieldType.INT8, + np.dtype('bool'): trt.PluginFieldType.INT8, + np.dtype('int64'): trt.PluginFieldType.INT64, + np.dtype('float32'): trt.PluginFieldType.FLOAT32, + np.dtype('float64'): trt.PluginFieldType.FLOAT64, + np.dtype('float16'): trt.PluginFieldType.FLOAT16 +} + + +_built_in_to_plugin_field_type = { + int: trt.PluginFieldType.INT64, + float: trt.PluginFieldType.FLOAT64, + bool: trt.PluginFieldType.INT8, + # str is handled separately, so not needed here +} + +class Tactic(IntEnum): + TORCH = 1 + TRITON = 2 + class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc] def __init__( - self, plugin_name : str, fc = None, phase = None + self, plugin_name : str, attrs, phase = None ): # TODO: needs an additional passed in arguments to specify the needs for each plugin # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83 trt.IPluginV3.__init__(self) + # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime trt.IPluginV3OneCore.__init__(self) + # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder. trt.IPluginV3OneBuild.__init__(self) + # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable trt.IPluginV3OneRuntime.__init__(self) - # + # # setattr(, ) # self.pads = [] # self.X_shape: List[int] = [] @@ -21,7 +57,14 @@ def __init__( self.num_outputs = 1 # Defined by schema self.plugin_namespace = "" self.plugin_name = plugin_name - self.plugin_version = "1" + self.plugin_version = "1" + + # Set the timing cache ID to prevent unnecessary timing of second plugin instance + self.timing_cache_id = "" + + self.attrs = attrs + + self.tactic = None # # ex. 
@@ -66,18 +109,44 @@ def get_output_shapes( def get_fields_to_serialize(self): # should be passed in as another argument - return trt.PluginFieldCollection([ - trt.PluginField("pads", self.pads, trt.PluginFieldType.INT32) - ]) + field_names = [] + + for key, value in self.attrs.items(): + if isinstance(value, np.ndarray): + field_names.append( + trt.PluginField( + key, + value, + _numpy_to_plugin_field_type[np.dtype(value.dtype)], + ) + ) + elif isinstance(value, str): + field_names.append( + trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR) + ) + elif isinstance(value, bytes): + field_names.append( + trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN) + ) + else: + field_names.append( + trt.PluginField( + key, + np.array([value]), + _built_in_to_plugin_field_type[type(value)], + ) + ) + + return trt.PluginFieldCollection(field_names) def configure_plugin(self, inp, out): pass - def on_shape_change(self, inp, out): - X_dims = inp[0].dims - self.X_shape = np.zeros((len(X_dims),)) - for i in range(len(X_dims)): - self.X_shape[i] = X_dims[i] + # def on_shape_change(self, inp, out): + # X_dims = inp[0].dims + # self.X_shape = np.zeros((len(X_dims),)) + # for i in range(len(X_dims)): + # self.X_shape[i] = X_dims[i] def supports_format_combination(self, pos, in_out, num_inputs): assert num_inputs == 1 @@ -102,12 +171,40 @@ def enqueue( self, input_desc: List[trt.PluginTensorDesc], output_desc: List[trt.PluginTensorDesc], - inputs: List[int], - outputs: List[int], + inputs, + outputs, workspace: int, stream: int, ) -> None: - ... + # input and output memory handling + input_mems = [None] * (len(inputs)) + + for i in range(len(inputs)): + input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self) + + output_mems = [None] * (len(outputs)) + + for i in range(len(outputs)): + output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self) + + + input_data = [None] * ((len(inputs))) + for i in range(len(inputs)): + input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0)) + + output_data = [None] * ((len(outputs))) + for i in range(len(outputs)): + output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0)) + + #TODO: This is just for a simple case for elementwise operations + # using Torch implementation for now + input_torch_0 = torch.as_tensor(input_data[0], device='cuda') + input_torch_1 = torch.as_tensor(input_data[1], device='cuda') + + output = torch.add(input_torch_0, input_torch_1) + + cp.copyto(output_data, output) + def attach_to_context(self, context): return self.clone() @@ -121,8 +218,10 @@ def set_tactic(self, tactic): if self.phase == trt.TensorRTPhase.RUNTIME: logger.info(f"Best tactic chosen: {self.tactic}") - def clone(self) -> Self: - # + def clone(self): + cloned_plugin = CustomPlugin(self.plugin_name, self.attrs) + cloned_plugin.__dict__.update(self.__dict__) + return cloned_plugin class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] From 702c149cf048c4014455e5e84a355efd7a3ceccf Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 6 Nov 2024 01:12:57 +0000 Subject: [PATCH 3/8] update --- .../dynamo/conversion/plugin/__init__.py | 1 + .../conversion/plugin/plugin_generator.py | 56 ++++++++++++++++--- .../conversion/plugin_ops_converters.py | 44 
++++++++++++--- 3 files changed, 85 insertions(+), 16 deletions(-) create mode 100644 py/torch_tensorrt/dynamo/conversion/plugin/__init__.py diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py new file mode 100644 index 0000000000..016c091425 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py @@ -0,0 +1 @@ +from .plugin_generator import PluginCreator \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 21468ca9be..4cd4c6f0c5 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -10,8 +10,6 @@ logger = logging.getLogger("CustomPlugin") - - _numpy_to_plugin_field_type = { np.dtype('int32'): trt.PluginFieldType.INT32, np.dtype('int16'): trt.PluginFieldType.INT16, @@ -23,7 +21,6 @@ np.dtype('float16'): trt.PluginFieldType.FLOAT16 } - _built_in_to_plugin_field_type = { int: trt.PluginFieldType.INT64, float: trt.PluginFieldType.FLOAT64, @@ -225,18 +222,59 @@ def clone(self): class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] - def __init__(self, plugin_name : str, plugin_field_names : trt.PluginFieldCollection): - super().__init__() + def __init__(self, plugin_name : str, plugin_namespace : str, attrs): + trt.IPluginCreatorV3One.__init__(self) self.name = plugin_name - self.plugin_namespace = "" + self.plugin_namespace = plugin_namespace self.plugin_version = "1" - self.field_names = plugin_field_names + + field_names = [] + for name, (builtin, type_) in attrs.items(): + if builtin: + if type_ is str: + field_names.append( + trt.PluginField(name, b"", trt.PluginFieldType.CHAR) + ) + elif type_ is bytes: + field_names.append( + trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN) + ) + else: + field_names.append( + trt.PluginField( + name, np.array([]), _built_in_to_plugin_field_type[type_] + ) + ) + else: + field_names.append( + trt.PluginField( + name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)] + ) + ) + + self.field_names = trt.PluginFieldCollection(field_names) def create_plugin( - self, name: str, field_collection: trt.PluginFieldCollection_ + self, name: str, fc, phase ) -> CustomPlugin: - return CustomPlugin(field_collection) + + + attrs = {} + # for f in fc: + # if f.name not in desc.input_attrs: + # raise AssertionError( + # f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}." + # ) + + # if _is_numpy_array(desc.input_attrs[f.name]): + # attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name])) + # else: + # attrs[f.name] = desc.input_attrs[f.name](f.data) + + custom_plugin = CustomPlugin(name, attrs, fc) + + return custom_plugin # Looks like deserilaize required? 
Not found in the example here: https://github.com/NVIDIA/TensorRT/blob/main/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py index b9136d492d..7aa8b4b5d1 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -11,9 +11,14 @@ dynamo_tensorrt_converter, ) from torch_tensorrt.fx.types import TRTTensor +from plugin import PluginCreator +import tensorrt as trt +from converter_utils import get_trt_tensor logger = logging.getLogger(__name__) +TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() + @dynamo_tensorrt_converter(torch.ops.torchtrt_ex.elementwise_add.default) def torchtrt_ex_elementwise_add( ctx: ConversionContext, @@ -22,15 +27,17 @@ def torchtrt_ex_elementwise_add( kwargs: Dict[str, Argument], name: str, ): - logger.debug(f"plugin stuff here2") - return torch.add(args) + # logger.debug(f"plugin stuff here2") + # return torch.add(args) # How to retrieve a plugin if it is defined elsewhere (e.g. linked library) - # plugin_registry = trt.get_plugin_registry() - # plugin_creator = plugin_registry.get_plugin_creator( - # type="", version="1", plugin_namespace="" - # ) - # assert plugin_creator, f"Unable to find creator" + plugin_creator = PluginCreator("elementwise_add_plugin") + TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "") + + plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator( + type=plugin_creator, version="1", plugin_namespace="" + ) + assert plugin_creator, f"Unable to find elementwise_add_plugin creator" # # Pass configurations to the plugin implementation # field_configs = @@ -42,6 +49,29 @@ def torchtrt_ex_elementwise_add( # # return layer.get_output(0) + field_configs = trt.PluginFieldCollection([]) + + plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs) + assert plugin, "Unable to create CircularPaddingPlugin" + + # input_tensor = args[ + # 0 + # ] # Arg 0 `torch.ops.torchtrt_ex.triton_circular_pad` is the input tensor + # if not isinstance(input_tensor, trt.ITensor): + # # Freeze input tensor if not TensorRT Tensor already + # input_tensor = get_trt_tensor(ctx, input_tensor, f"{name}_input") + + lhs_dtype = None + rhs_dtype = None + + lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) + + layer = ctx.net.add_plugin_v2( + [lhs_val, rhs_val], plugin + ) # Add the plugin to the network being constructed + layer.name = f"automatic-{name}" + return layer.get_output(0) # 1. 
generate plugin for any pytorch op \ No newline at end of file From ec1d50396003f9d5ef11d9e662f30f04d83286ad Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Sat, 9 Nov 2024 01:55:48 +0000 Subject: [PATCH 4/8] update --- .../conversion/plugin/plugin_generator.py | 38 +++++++++++-------- .../conversion/plugin_ops_converters.py | 16 ++++---- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 4cd4c6f0c5..9cfbd6b1b9 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -7,6 +7,8 @@ from enum import IntEnum +from typing import List + logger = logging.getLogger("CustomPlugin") @@ -62,6 +64,7 @@ def __init__( self.attrs = attrs self.tactic = None + # # ex. @@ -76,7 +79,7 @@ def __init__( def get_capability_interface(self, type): return self - def get_output_datatypes( + def get_output_data_types( self, input_types: List[trt.DataType] ) -> trt.DataType: # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE @@ -91,18 +94,19 @@ def get_output_datatypes( def get_output_shapes( self, - output_index: int, inputs: List[trt.DimsExprs], + shape_inputs, exprBuilder: trt.IExprBuilder, ) -> trt.DimsExprs: + print(inputs) # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE # SHAPE MAP. - output_shape = trt.DimsExprs(inputs[0]) + output_dims = trt.DimsExprs(inputs[0]) - return [output_shape] + return [output_dims] def get_fields_to_serialize(self): # should be passed in as another argument @@ -139,13 +143,15 @@ def get_fields_to_serialize(self): def configure_plugin(self, inp, out): pass - # def on_shape_change(self, inp, out): - # X_dims = inp[0].dims - # self.X_shape = np.zeros((len(X_dims),)) - # for i in range(len(X_dims)): - # self.X_shape[i] = X_dims[i] + def on_shape_change(self, inp, out): + return + X_dims = inp[0].dims + self.X_shape = np.zeros((len(X_dims),)) + for i in range(len(X_dims)): + self.X_shape[i] = X_dims[i] def supports_format_combination(self, pos, in_out, num_inputs): + return assert num_inputs == 1 assert pos < len(in_out) @@ -198,7 +204,7 @@ def enqueue( input_torch_0 = torch.as_tensor(input_data[0], device='cuda') input_torch_1 = torch.as_tensor(input_data[1], device='cuda') - output = torch.add(input_torch_0, input_torch_1) + output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1) cp.copyto(output_data, output) @@ -212,8 +218,8 @@ def get_valid_tactics(self): def set_tactic(self, tactic): self.tactic = Tactic(tactic) - if self.phase == trt.TensorRTPhase.RUNTIME: - logger.info(f"Best tactic chosen: {self.tactic}") + # if self.phase == trt.TensorRTPhase.RUNTIME: + # logger.info(f"Best tactic chosen: {self.tactic}") def clone(self): cloned_plugin = CustomPlugin(self.plugin_name, self.attrs) @@ -256,7 +262,7 @@ def __init__(self, plugin_name : str, plugin_namespace : str, attrs): self.field_names = trt.PluginFieldCollection(field_names) def create_plugin( - self, name: str, fc, phase + self, name: str, field_collection, phase=None ) -> CustomPlugin: @@ -272,7 +278,7 @@ def create_plugin( # else: # attrs[f.name] = desc.input_attrs[f.name](f.data) - custom_plugin = CustomPlugin(name, attrs, fc) + custom_plugin = CustomPlugin(name, attrs) return custom_plugin @@ -284,5 +290,5 @@ def 
create_plugin( # deserialized.__dict__.update(dict) # return deserialized -TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() -TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file +# TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() +# TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py index 7aa8b4b5d1..f87c129265 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -11,9 +11,9 @@ dynamo_tensorrt_converter, ) from torch_tensorrt.fx.types import TRTTensor -from plugin import PluginCreator +from torch_tensorrt.dynamo.conversion.plugin import PluginCreator import tensorrt as trt -from converter_utils import get_trt_tensor +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor logger = logging.getLogger(__name__) @@ -31,11 +31,11 @@ def torchtrt_ex_elementwise_add( # return torch.add(args) # How to retrieve a plugin if it is defined elsewhere (e.g. linked library) - plugin_creator = PluginCreator("elementwise_add_plugin") + plugin_creator = PluginCreator("elementwise_add_plugin", plugin_namespace="", attrs={}) TRT_PLUGIN_REGISTRY.register_creator(plugin_creator, "") plugin_creator = TRT_PLUGIN_REGISTRY.get_plugin_creator( - type=plugin_creator, version="1", plugin_namespace="" + type="elementwise_add_plugin", version="1", plugin_namespace="" ) assert plugin_creator, f"Unable to find elementwise_add_plugin creator" @@ -63,14 +63,16 @@ def torchtrt_ex_elementwise_add( lhs_dtype = None rhs_dtype = None + lhs_val = args[0] + rhs_val = args[1] lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) - layer = ctx.net.add_plugin_v2( - [lhs_val, rhs_val], plugin + layer = ctx.net.add_plugin_v3( + [lhs_val, rhs_val], [], plugin ) # Add the plugin to the network being constructed - layer.name = f"automatic-{name}" + # layer.name = f"automatic-{name}" return layer.get_output(0) From 764acad121cbda825d66a1744c7a627aef0a7253 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Wed, 13 Nov 2024 02:16:11 +0000 Subject: [PATCH 5/8] support first example --- .../dynamo/conversion/plugin_ops_converters.py | 11 ++++++++++- setup.py | 2 ++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py index f87c129265..4b8f4b4311 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin_ops_converters.py @@ -51,7 +51,7 @@ def torchtrt_ex_elementwise_add( # return layer.get_output(0) field_configs = trt.PluginFieldCollection([]) - plugin = plugin_creator.create_plugin(name=name, field_collection=field_configs) + plugin = plugin_creator.create_plugin(name="elementwise_add_plugin", field_collection=field_configs) assert plugin, "Unable to create CircularPaddingPlugin" # input_tensor = args[ @@ -66,6 +66,15 @@ def torchtrt_ex_elementwise_add( lhs_val = args[0] rhs_val = args[1] + if isinstance(lhs_val, TRTTensor): + lhs_dtype = lhs_val.dtype + # is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = rhs_val.dtype + # is_rhs_trt_tensor = True + + print(lhs_dtype) + lhs_val = get_trt_tensor(ctx, lhs_val, 
f"{name}_lhs", lhs_dtype) rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype) diff --git a/setup.py b/setup.py index 0b8f47fb6f..bd490ec1be 100644 --- a/setup.py +++ b/setup.py @@ -440,6 +440,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.normalization", "torch_tensorrt.dynamo.conversion.impl.slice", "torch_tensorrt.dynamo.conversion.impl.unary", + "torch_tensorrt.dynamo.conversion.plugin", "torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", @@ -468,6 +469,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.normalization": "py/torch_tensorrt/dynamo/conversion/impl/normalization", "torch_tensorrt.dynamo.conversion.impl.slice": "py/torch_tensorrt/dynamo/conversion/impl/slice", "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary", + "torch_tensorrt.dynamo.conversion.plugin": "py/torch_tensorrt/dynamo/conversion/plugin", "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes", "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning", From f0b0a0f0191c487664c015387dba2fa183d7ec49 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 22 Nov 2024 01:17:40 +0000 Subject: [PATCH 6/8] remove some comments --- .../dynamo/conversion/plugin/plugin_generator.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py index 9cfbd6b1b9..56efe1c714 100644 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py @@ -282,13 +282,3 @@ def create_plugin( return custom_plugin - -# Looks like deserilaize required? 
Not found in the example here: https://github.com/NVIDIA/TensorRT/blob/main/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py - # def deserialize_plugin(self, name: str, data: bytes) -> CircularPaddingPlugin: - # dict = pkl.loads(data) - # deserialized = () - # deserialized.__dict__.update(dict) - # return deserialized - -# TRT_PLUGIN_REGISTRY = trt.get_plugin_registry() -# TRT_PLUGIN_REGISTRY.register_creator(CircularPaddingPluginCreator(), "") \ No newline at end of file From 47b13d84b47cec8d0823a2c6a7abad3fab1fa1f1 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Sat, 30 Nov 2024 18:40:16 -0800 Subject: [PATCH 7/8] wip: Trying out the new plugin API for autogen'd plugins + converters Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- .pre-commit-config.yaml | 12 +- examples/dynamo/automatic_plugin/custom_op.py | 93 --- .../dynamo/automatic_plugin_generation.py | 150 +++++ .../dynamo/conversion/__init__.py | 2 +- .../dynamo/conversion/plugin/__init__.py | 1 - .../conversion/plugin/plugin_generator.py | 284 --------- .../dynamo/conversion/plugins/__init__.py | 2 + .../conversion/plugins/plugin_generator.py | 189 ++++++ pyproject.toml | 88 ++- setup.py | 4 +- uv.lock | 579 ++++++++---------- 11 files changed, 670 insertions(+), 734 deletions(-) delete mode 100644 examples/dynamo/automatic_plugin/custom_op.py create mode 100644 examples/dynamo/automatic_plugin_generation.py delete mode 100644 py/torch_tensorrt/dynamo/conversion/plugin/__init__.py delete mode 100644 py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py create mode 100644 py/torch_tensorrt/dynamo/conversion/plugins/__init__.py create mode 100644 py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 29fb7b4d65..9449d9d3ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,10 +23,10 @@ repos: - repo: https://github.com/keith/pre-commit-buildifier rev: 6.4.0 hooks: - - id: buildifier + - id: buildifier args: - --warnings=all - - id: buildifier-lint + - id: buildifier-lint - repo: https://github.com/abravalheri/validate-pyproject rev: v0.16 hooks: @@ -37,9 +37,9 @@ repos: - id: isort name: isort (python) - repo: https://github.com/pre-commit/mirrors-mypy - rev: 'v1.9.0' + rev: "v1.9.0" hooks: - - id: mypy + - id: mypy exclude: "^py/torch_tensorrt/fx|^examples|^tests|^py/torch_tensorrt/dynamo/_experimental|^tools|^docs|noxfile.py|setup.py|versions.py" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. @@ -57,13 +57,13 @@ repos: - id: typos - repo: https://github.com/astral-sh/uv-pre-commit # uv version. 
- rev: 0.4.10 + rev: 0.5.5 hooks: # Update the uv lockfile - id: uv-lock - repo: local hooks: - - id: dont-commit-upstream + - id: dont-commit-upstream name: NVIDIA-INTERNAL check entry: "!NVIDIA-INTERNAL" exclude: "^.pre-commit-config.yaml" diff --git a/examples/dynamo/automatic_plugin/custom_op.py b/examples/dynamo/automatic_plugin/custom_op.py deleted file mode 100644 index 043c75d1e6..0000000000 --- a/examples/dynamo/automatic_plugin/custom_op.py +++ /dev/null @@ -1,93 +0,0 @@ -import triton -import triton.language as tl - -@triton.jit -def elementwise_add_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr): - # Program ID determines the block of data each thread will process - pid = tl.program_id(0) - # Compute the range of elements that this thread block will work on - block_start = pid * BLOCK_SIZE - # Range of indices this thread will handle - offsets = block_start + tl.arange(0, BLOCK_SIZE) - # Load elements from the X and Y tensors - x_vals = tl.load(X + offsets) - y_vals = tl.load(Y + offsets) - # Perform the element-wise multiplication - z_vals = x_vals + y_vals - # Store the result in Z - tl.store(Z + offsets, z_vals) - - -import torch -from torch.library import custom_op - - -@custom_op("torchtrt_ex::elementwise_add", mutates_args=()) # type: ignore[misc] -def elementwise_add(X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor: - # Ensure the tensors are on the GPU - assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device." - assert X.shape == Y.shape, "Tensors must have the same shape." - - # Create output tensor - Z = torch.empty_like(X) - - # Define block size - BLOCK_SIZE = 1024 - - # Grid of programs - grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],) - - # Launch the kernel - elementwise_add_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) - - return Z - - -# Using the module in PyTorch -# X = torch.randn(1024, device='cuda', requires_grad=True) -# Y = torch.randn(1024, device='cuda', requires_grad=True) -# X = torch.full((128, 128), 2, device='cuda',) -# Y = torch.full((128, 128), 2, device='cuda',) -# # elementwise_mul_op = ElementwiseMulModule() -# Z = torch.ops.torchtrt_ex.elementwise_add(X, Y) -# print(Z) -# print(X + Y) -# print(X) -# print(Y) -# print(Z) -# print(X+Y) -# Z.sum().backward() - - -from torch import nn - - -class MyModel(nn.Module): # type: ignore[misc] - def __init__(self): - super().__init__() - - def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - z = torch.mul(x, y) - res = torch.ops.torchtrt_ex.elementwise_add(x, z) - - return res - - -my_model = MyModel().to("cuda") -m = torch.full((64, 64), 2, device='cuda',) -n = torch.full((64, 64), 3, device='cuda',) -# print(torch.ops.torchtrt_ex.elementwise_add(m, n)) -# print(my_model.forward(m, n)) - - -@torch.library.register_fake("torchtrt_ex::elementwise_add") -def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - return x - -import torch_tensorrt as torchtrt - - -with torchtrt.logging.info(): - model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1) - res = model_trt(m, n) - print(res) \ No newline at end of file diff --git a/examples/dynamo/automatic_plugin_generation.py b/examples/dynamo/automatic_plugin_generation.py new file mode 100644 index 0000000000..8bbb75ebfa --- /dev/null +++ b/examples/dynamo/automatic_plugin_generation.py @@ -0,0 +1,150 @@ +import triton +import triton.language as tl + +from typing import Tuple +import torch_tensorrt + +@triton.jit +def elementwise_mul_kernel(X, Y, Z, BLOCK_SIZE: tl.constexpr): + # Program ID determines 
the block of data each thread will process + pid = tl.program_id(0) + # Compute the range of elements that this thread block will work on + block_start = pid * BLOCK_SIZE + # Range of indices this thread will handle + offsets = block_start + tl.arange(0, BLOCK_SIZE) + # Load elements from the X and Y tensors + x_vals = tl.load(X + offsets) + y_vals = tl.load(Y + offsets) + # Perform the element-wise multiplication + z_vals = x_vals * y_vals + # Store the result in Z + tl.store(Z + offsets, z_vals) + + +import torch +from torch.library import custom_op + +#@torch_tensorrt.dynamo.conversion.plugin.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) +@torch.library.custom_op("torchtrt_ex::elementwise_mul", mutates_args=()) # type: ignore[misc] +def elementwise_mul(X: torch.Tensor, Y: torch.Tensor, b: float=.2, a: int=2) -> torch.Tensor: + # Ensure the tensors are on the GPU + assert X.is_cuda and Y.is_cuda, "Tensors must be on CUDA device." + assert X.shape == Y.shape, "Tensors must have the same shape." + + # Create output tensor + Z = torch.empty_like(X) + + # Define block size + BLOCK_SIZE = 1024 + + # Grid of programs + grid = lambda meta: (X.numel() // meta['BLOCK_SIZE'],) + + # Launch the kernel + elementwise_mul_kernel[grid](X, Y, Z, BLOCK_SIZE=BLOCK_SIZE) + + return Z + +from torch import nn + + +class MyModel(nn.Module): # type: ignore[misc] + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + z = torch.add(x, y) + res = torch.ops.torchtrt_ex.elementwise_mul.default(x, z, a=1) + + return res + + +my_model = MyModel().to("cuda") +m = torch.full((64, 64), 2, device='cuda', dtype=torch.float) +n = torch.full((64, 64), 3, device='cuda', dtype=torch.float) + +def mksym(shape_env, value, source, dynamic_dim): + return shape_env.create_symintnode( + shape_env.create_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + +@torch.library.register_fake("torchtrt_ex::elementwise_mul") +def _(x: torch.Tensor, y: torch.Tensor, b: float=.2, a: int=2) -> torch.Tensor: + return x + +import tensorrt_bindings.plugin as trtp +from torch._dynamo.source import LocalSource +from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv +from sympy import lambdify + +@trtp.register("torchtrt_ex::elementwise_mul") +def _(x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int) -> Tuple[trtp.TensorDesc]: + from torch._subclasses.fake_tensor import FakeTensorMode + from torch.fx.experimental.symbolic_shapes import ShapeEnv + from sympy import lambdify + shape_env = ShapeEnv() + fake_mode = FakeTensorMode(shape_env=shape_env) + sample_x = {f"x{i}": 5 for i in range(x.ndim)} + sample_y = {f"y{i}": 5 for i in range(y.ndim)} + syms_x = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k,v in sample_x.items()] + syms_y = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k,v in sample_y.items()] + with FakeTensorMode() as fake_mode: + fake_x = torch.randn(syms_x) + fake_y = torch.randn(syms_y) + z = torch.ops.torchtrt_ex.elementwise_mul(fake_x, fake_y, b, a) + + shape_calc_fns = [None] * x.ndim + for i in range(x.ndim): + shape_calc_fns[i] = lambdify((syms_x[i].node.expr, syms_y[i].node.expr), z.shape[i].node.expr, "math") + + out_desc = x.like() + for i in range(out_desc.ndim): + out_desc.shape_expr[i] = shape_calc_fns[i](x.shape_expr[i], y.shape_expr[i]) + return out_desc + + +@trtp.impl("torchtrt_ex::elementwise_mul") +def _(x: trtp.Tensor, y: trtp.Tensor, b: 
float, a: int, outputs: Tuple[trtp.Tensor], stream: int): + # This should be based on Torch schema + in_tensors = [ + torch.as_tensor(i, device="cuda") for i in (x, y) + ] # What is the right device?? + dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + + stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(stream): + out_tensors = torch.ops.torchtrt_ex.elementwise_mul(*in_tensors, b, a) + [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + +# @trtp.impl("torchtrt_ex::elementwise_mul") +# def _(x: trtp.Tensor, y: trtp.Tensor, b: float, a: int, outputs: Tuple[trtp.Tensor], stream: int): +# # Define block size +# BLOCK_SIZE = 1024 + +# # Grid of programs +# grid = lambda meta: (x.numel() // meta['BLOCK_SIZE'],) + +# x_t = torch.as_tensor(x, device="cuda") +# y_t = torch.as_tensor(y, device="cuda") +# z_t = torch.as_tensor(outputs[0], device="cuda") +# # Launch the kernel +# elementwise_mul_kernel[grid](x_t, y_t, z_t, BLOCK_SIZE=BLOCK_SIZE) + +_ = torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter("torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True) + +from torch_tensorrt.dynamo.conversion._ConverterRegistry import DYNAMO_CONVERTERS + +import torch_tensorrt as torchtrt +import tensorrt as trt +with torchtrt.logging.errors(): + model_trt = torchtrt.compile(my_model, inputs=[m, n], debug=True, min_block_size=1) + for i in range(300): + res = model_trt(m, n) + print(res) + assert torch.allclose(res, my_model(m,n)) diff --git a/py/torch_tensorrt/dynamo/conversion/__init__.py b/py/torch_tensorrt/dynamo/conversion/__init__.py index 235d1456b0..62c2504b03 100644 --- a/py/torch_tensorrt/dynamo/conversion/__init__.py +++ b/py/torch_tensorrt/dynamo/conversion/__init__.py @@ -1,4 +1,4 @@ -from . import aten_ops_converters, ops_evaluators, prims_ops_converters, plugin_ops_converters +from . 
import aten_ops_converters, ops_evaluators, prims_ops_converters, plugins from ._conversion import convert_module, interpret_module_to_result from ._ConversionContext import ConversionContext from ._ConverterRegistry import * # noqa: F403 diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py deleted file mode 100644 index 016c091425..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/plugin/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .plugin_generator import PluginCreator \ No newline at end of file diff --git a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py deleted file mode 100644 index 56efe1c714..0000000000 --- a/py/torch_tensorrt/dynamo/conversion/plugin/plugin_generator.py +++ /dev/null @@ -1,284 +0,0 @@ -import tensorrt as trt -import cupy as cp -import torch -import numpy as np - -import logging - - -from enum import IntEnum -from typing import List - - -logger = logging.getLogger("CustomPlugin") - -_numpy_to_plugin_field_type = { - np.dtype('int32'): trt.PluginFieldType.INT32, - np.dtype('int16'): trt.PluginFieldType.INT16, - np.dtype('int8'): trt.PluginFieldType.INT8, - np.dtype('bool'): trt.PluginFieldType.INT8, - np.dtype('int64'): trt.PluginFieldType.INT64, - np.dtype('float32'): trt.PluginFieldType.FLOAT32, - np.dtype('float64'): trt.PluginFieldType.FLOAT64, - np.dtype('float16'): trt.PluginFieldType.FLOAT16 -} - -_built_in_to_plugin_field_type = { - int: trt.PluginFieldType.INT64, - float: trt.PluginFieldType.FLOAT64, - bool: trt.PluginFieldType.INT8, - # str is handled separately, so not needed here -} - -class Tactic(IntEnum): - TORCH = 1 - TRITON = 2 - -class CustomPlugin(trt.IPluginV3, trt.IPluginV3OneCore, trt.IPluginV3OneBuild, trt.IPluginV3OneRuntime): # type: ignore[misc] - def __init__( - self, plugin_name : str, attrs, phase = None - ): - # TODO: needs an additional passed in arguments to specify the needs for each plugin - # such as the one here: https://github.com/NVIDIA/TensorRT/blob/40efe7e9f2492657bbc455c4e2876e2ec792b812/samples/python/python_plugin/circ_pad_plugin_multi_tactic.py#L83 - trt.IPluginV3.__init__(self) - # Core capability, plugin attributes and behaviors common to both the build and runtime phases of a plugin’s lifetime - trt.IPluginV3OneCore.__init__(self) - # Build capability, plugin attributes and behaviors that the plugin must exhibit for the TensorRT builder. - trt.IPluginV3OneBuild.__init__(self) - # Runtime capability, plugin attributes and behaviors that the plugin must exhibit for it to be executable - trt.IPluginV3OneRuntime.__init__(self) - - # - # setattr(, ) - # self.pads = [] - # self.X_shape: List[int] = [] - - self.num_outputs = 1 # Defined by schema - self.plugin_namespace = "" - self.plugin_name = plugin_name - self.plugin_version = "1" - - # Set the timing cache ID to prevent unnecessary timing of second plugin instance - self.timing_cache_id = "" - - self.attrs = attrs - - self.tactic = None - - - # - # ex. 
- # TODO: need to parse the field collection here - # if fc is not None: - # assert fc[0].name == "pads" - # self.pads = fc[0].data - - if phase is not None: - self.phase = phase - - def get_capability_interface(self, type): - return self - - def get_output_data_types( - self, input_types: List[trt.DataType] - ) -> trt.DataType: - # WE CAN USE THE FAKE TENSOR IMPLEMENTATION TO FIGURE OUT THE EXPECTED OUTPUT DATA TYPE - # with torch.fake_tensor(): - # - # fake_outputs = torch.ops..(*fake_inputs) - - # return fake_outputs[index] - - # The example case here is simple for experiment - return [input_types[0]] - - def get_output_shapes( - self, - inputs: List[trt.DimsExprs], - shape_inputs, - exprBuilder: trt.IExprBuilder, - ) -> trt.DimsExprs: - - print(inputs) - - # WE NEED TO FIND A WAY TO GO FROM FAKE TENSOR IMPL TO CONSTRUCTING A DIMSEXPR - # THIS IS SOLVED IN SHAPE PROP IN PYTORCH WHERE SHAPE PROP CAN GIVE SYMINTS THAT ENCODE THE - # SHAPE MAP. - output_dims = trt.DimsExprs(inputs[0]) - - return [output_dims] - - def get_fields_to_serialize(self): - # should be passed in as another argument - field_names = [] - - for key, value in self.attrs.items(): - if isinstance(value, np.ndarray): - field_names.append( - trt.PluginField( - key, - value, - _numpy_to_plugin_field_type[np.dtype(value.dtype)], - ) - ) - elif isinstance(value, str): - field_names.append( - trt.PluginField(key, value.encode(), trt.PluginFieldType.CHAR) - ) - elif isinstance(value, bytes): - field_names.append( - trt.PluginField(key, value, trt.PluginFieldType.UNKNOWN) - ) - else: - field_names.append( - trt.PluginField( - key, - np.array([value]), - _built_in_to_plugin_field_type[type(value)], - ) - ) - - return trt.PluginFieldCollection(field_names) - - def configure_plugin(self, inp, out): - pass - - def on_shape_change(self, inp, out): - return - X_dims = inp[0].dims - self.X_shape = np.zeros((len(X_dims),)) - for i in range(len(X_dims)): - self.X_shape[i] = X_dims[i] - - def supports_format_combination(self, pos, in_out, num_inputs): - return - assert num_inputs == 1 - assert pos < len(in_out) - - desc = in_out[pos].desc - if desc.format != trt.TensorFormat.LINEAR: - return False - - # first input should be float16 or float32 - if pos == 0: - return desc.type == trt.DataType.FLOAT or desc.type == trt.DataType.HALF - - # output should have the same type as the input - if pos == 1: - return in_out[0].desc.type == desc.type - - assert False - - - def enqueue( - self, - input_desc: List[trt.PluginTensorDesc], - output_desc: List[trt.PluginTensorDesc], - inputs, - outputs, - workspace: int, - stream: int, - ) -> None: - # input and output memory handling - input_mems = [None] * (len(inputs)) - - for i in range(len(inputs)): - input_mems[i] = cp.cuda.UnownedMemory(inputs[i], np.prod(input_desc[i].dims) * cp.dtype(trt.nptype(input_desc[i].type)).itemsize, self) - - output_mems = [None] * (len(outputs)) - - for i in range(len(outputs)): - output_mems[i] = cp.cuda.UnownedMemory(outputs[i], np.prod(output_desc[i].dims) * cp.dtype(trt.nptype(output_desc[i].type)).itemsize, self) - - - input_data = [None] * ((len(inputs))) - for i in range(len(inputs)): - input_data[i] = cp.ndarray(tuple(input_desc[i].dims), dtype=input_desc[i].type, memptr = cp.cuda.MemoryPointer(input_mems[i], 0)) - - output_data = [None] * ((len(outputs))) - for i in range(len(outputs)): - output_data[i] = cp.ndarray((np.prod(output_desc[i].dims)), dtype = output_desc[i].type, memptr = cp.cuda.MemoryPointer(output_mems[i], 0)) - - #TODO: This is just for a 
simple case for elementwise operations - # using Torch implementation for now - input_torch_0 = torch.as_tensor(input_data[0], device='cuda') - input_torch_1 = torch.as_tensor(input_data[1], device='cuda') - - output = torch.ops.torchtrt_ex.elementwise_add(input_torch_0, input_torch_1) - - cp.copyto(output_data, output) - - - def attach_to_context(self, context): - return self.clone() - - def get_valid_tactics(self): - return [int(Tactic.TORCH), int(Tactic.TRITON)] - - def set_tactic(self, tactic): - self.tactic = Tactic(tactic) - - # if self.phase == trt.TensorRTPhase.RUNTIME: - # logger.info(f"Best tactic chosen: {self.tactic}") - - def clone(self): - cloned_plugin = CustomPlugin(self.plugin_name, self.attrs) - cloned_plugin.__dict__.update(self.__dict__) - return cloned_plugin - - -class PluginCreator(trt.IPluginCreatorV3One): # type: ignore[misc] - def __init__(self, plugin_name : str, plugin_namespace : str, attrs): - trt.IPluginCreatorV3One.__init__(self) - - self.name = plugin_name - self.plugin_namespace = plugin_namespace - self.plugin_version = "1" - - field_names = [] - for name, (builtin, type_) in attrs.items(): - if builtin: - if type_ is str: - field_names.append( - trt.PluginField(name, b"", trt.PluginFieldType.CHAR) - ) - elif type_ is bytes: - field_names.append( - trt.PluginField(name, b"", trt.PluginFieldType.UNKNOWN) - ) - else: - field_names.append( - trt.PluginField( - name, np.array([]), _built_in_to_plugin_field_type[type_] - ) - ) - else: - field_names.append( - trt.PluginField( - name, np.array([]), _numpy_to_plugin_field_type[np.dtype(type_)] - ) - ) - - self.field_names = trt.PluginFieldCollection(field_names) - - def create_plugin( - self, name: str, field_collection, phase=None - ) -> CustomPlugin: - - - attrs = {} - # for f in fc: - # if f.name not in desc.input_attrs: - # raise AssertionError( - # f"Unexpected attribute {f.name} provided to create_plugin. Expected one of {desc.input_attrs.keys()}." 
- # ) - - # if _is_numpy_array(desc.input_attrs[f.name]): - # attrs[f.name] = f.data.astype(_infer_numpy_type(desc.input_attrs[f.name])) - # else: - # attrs[f.name] = desc.input_attrs[f.name](f.data) - - custom_plugin = CustomPlugin(name, attrs) - - return custom_plugin - diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py new file mode 100644 index 0000000000..5f0ff5f8f9 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugins/__init__.py @@ -0,0 +1,2 @@ +from .plugin_generator import custom_op +from .plugin_generator import generate_plugin_converter diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py new file mode 100644 index 0000000000..bc50a92998 --- /dev/null +++ b/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py @@ -0,0 +1,189 @@ +import inspect +import logging +from enum import IntEnum +from types import FunctionType +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union + +import numpy as np + +# Seems like a bug in TensorRT +import tensorrt_bindings.plugin as trtp +from tensorrt_bindings.plugin._lib import QDP_REGISTRY +import torch +from torch._guards import detect_fake_mode +from torch._library.custom_ops import CustomOpDef, device_types_t +from torch._ops import OpOverload +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx.node import Argument, Node, Target, _get_qualified_name +from torch_tensorrt._enums import dtype +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext +from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( + DYNAMO_CONVERTERS, + ConverterPriority, + ConverterSupport, + dynamo_tensorrt_converter, +) +from torch_tensorrt.dynamo.conversion.converter_utils import get_trt_tensor + +import tensorrt as trt + +_LOGGER: logging.Logger = logging.getLogger(__name__) + + +class Tactic(IntEnum): + TORCH = 1 + TRITON = 2 + + +def custom_op( + name: str, + fn: Optional[Callable] = None, + /, + *, + mutates_args: Union[str, Iterable[str]], + device_types: device_types_t = None, + schema: Optional[str] = None, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, +) -> Callable: + def inner(fn): + torch_custom_op_def = torch.library.custom_op( + name, mutates_args=mutates_args, device_types=device_types, schema=schema + )(fn) + + def _tensorrt_plugin_desc(args) -> Tuple[trtp.TensorDesc]: + print(args) + return (args[0].like(),) + + def tensorrt_plugin_desc( + in0: trtp.TensorDesc, in1: trtp.TensorDesc + ) -> Tuple[trtp.TensorDesc]: + return in0.like() + + tensorrt_plugin_reg = trtp.register(name)(tensorrt_plugin_desc) + print(tensorrt_plugin_reg) + + def _tensorrt_plugin_impl(args) -> None: + print(args) + + @trtp.impl(name) + def tensorrt_plugin_impl( + in0: trtp.Tensor, in1: trtp.Tensor, outputs: Tuple[trtp.Tensor], stream: int + ) -> None: + # This should be based on Torch schema + in_tensors = [ + torch.as_tensor(i, device="cuda") for i in (in0, in1) + ] # What is the right device?? 
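+            # NOTE: hard-coding device="cuda" assumes the plugin always runs
+            # on the current default GPU; deriving the device from
+            # torch.cuda.current_device() would be safer on multi-GPU systems.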
+ dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs] + + stream = torch.cuda.ExternalStream(stream) + with torch.cuda.stream(stream): + out_tensors = torch_custom_op_def._opoverload(*in_tensors) + [d.copy_(o) for (d, o) in zip(dest_tensors, out_tensors)] + + op_converter = generate_torch_op_converter( + torch_custom_op_def, capability_validator, priority, supports_dynamic_shapes + ) + + torch_custom_op_def._tensorrt_plugin_desc = tensorrt_plugin_desc + torch_custom_op_def._tensorrt_plugin_impl = tensorrt_plugin_impl + torch_custom_op_def._torch_tensorrt_converter = op_converter + + print(torch_custom_op_def._schema) + return torch_custom_op_def + + if fn is None: + return inner + + return inner(fn) + +def _generate_plugin_converter( + namespace: str, + op_name: str, + overload: Optional[str] = None, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, +): + torch_target = getattr(getattr(torch.ops, namespace), op_name) + overload_str = overload if overload else "" + overload_name = overload_str if overload else "default" + torch_overload = getattr(torch_target, overload_name) + assert f"{namespace}::{op_name}" in QDP_REGISTRY, f"Could not find a tensorrt plugin registered for op {namespace}::{op_name}, unable to generate converter" + torch_schema = torch_target._schemas[overload_str] + + def custom_kernel_converter( + ctx: ConversionContext, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, + ): + plugin = getattr(getattr(trtp.op, namespace), op_name) + tensor_inputs = plugin.input_tensor_names + tensor_args = args[0:len(tensor_inputs)] + itensor_args = [get_trt_tensor(ctx, t, f"{t_name}") for (t, t_name) in zip(tensor_args, tensor_inputs)] + + # Assuming TensorRT preserves kwargs order like PyTorch does + non_tensor_inputs = plugin.input_attrs + + non_tensor_args = args[len(tensor_inputs):] + non_tensor_kwargs = {k:v for k, v in zip(list(non_tensor_inputs.keys()), non_tensor_args)} + for (k,v) in non_tensor_kwargs.items(): + if isinstance(v, torch.fx.immutable_collections.immutable_list): + non_tensor_kwargs[k] = np.array(v) + + layer = ctx.net.add_plugin(plugin(*itensor_args, **non_tensor_kwargs)) + assert ( + layer + ), f"{namespace}::{name} plugin layer was not able to be created" + _LOGGER.debug( + f"Adding generated plugin for {namespace}::{name} to tensorrt network" + ) + layer.name = f"[{target}]-[{name}]" + return layer.get_output(0) + + custom_kernel_converter = dynamo_tensorrt_converter( + torch_overload, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes, + )( + custom_kernel_converter + ) # type: ignore + assert torch_overload in DYNAMO_CONVERTERS, f"Generated dynamo converter for {namespace}::{name} did not get properly registered in the converter registry" + return custom_kernel_converter + + +def generate_torch_op_converter( + op_reg: CustomOpDef, + capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None, + priority: ConverterPriority = ConverterPriority.STANDARD, + supports_dynamic_shapes: bool = False, +): + return _generate_plugin_converter( + op_reg._namespace, + op_reg._name, + capability_validator=capability_validator, + priority=priority, + supports_dynamic_shapes=supports_dynamic_shapes + ) + + + +def generate_plugin_converter( + plugin_id: str, + capability_validator: 
diff --git a/pyproject.toml b/pyproject.toml
index 1284e458f4..4ff1cf3af2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ requires = [
     "cffi>=1.15.1",
     "typing-extensions>=4.7.0",
     "future>=0.18.3",
-    "tensorrt-cu12==10.3.0",
+    "tensorrt-cu12==10.6.0",
     "torch>=2.6.0.dev,<2.7.0",
     "pybind11==2.6.2",
     "numpy",
@@ -55,15 +55,32 @@ keywords = [
 ]
 dependencies = [
     "torch>=2.6.0.dev,<2.7.0",
-    "tensorrt-cu12==10.3.0",
-    "tensorrt-cu12-bindings==10.3.0",
-    "tensorrt-cu12-libs==10.3.0",
+    "tensorrt-cu12==10.6.0",
+    "tensorrt-cu12-bindings==10.6.0",
+    "tensorrt-cu12-libs==10.6.0",
     "packaging>=23",
     "numpy",
     "typing-extensions>=4.7.0",
 ]
+
 dynamic = ["version"]
 
+[dependency-groups]
+dev = [
+    "pre-commit>=2.20.0",
+    "black>=22.6.0",
+    "clang-format==14.0.6",
+    "typos",
+    "mypy",
+    "isort",
+    "ruff",
+    "pytest",
+    "pytest-xdist",
+    "parameterized>=0.2.0",
+    "expecttest==0.1.6",
+    "pyyaml",
+]
+
 [project.optional-dependencies]
 torchvision = [
     "torchvision",
@@ -83,29 +100,54 @@ package-dir = { "" = "py" }
 include-package-data = false
 
 [tool.uv]
-dev-dependencies = [
-    "pre-commit>=2.20.0",
-    "black>=22.6.0",
-    "clang-format==14.0.6",
-    "typos",
-    "mypy",
-    "isort",
-    "ruff",
-    "pytest",
-    "pytest-xdist",
-    "parameterized>=0.2.0",
-    "expecttest==0.1.6",
-    "pyyaml",
-]
-
 environments = ["sys_platform == 'linux'", "sys_platform == 'windows'"]
-extra-index-url = [
-    "https://download.pytorch.org/whl/nightly/cu124", # We are going to define the dev enviorment as latest supported CUDA, and allow CI to handle the others, change as needed
+prerelease = "if-necessary-or-explicit"
+
+
+[tool.uv.sources]
+torch = [
+    { index = "pytorch-nightly-cu126"},
+]
+torchvision = [
+    { index = "pytorch-nightly-cu126"},
 ]
-prerelease = "if-necessary-or-explicit"
-index-strategy = "unsafe-best-match"
 
+[[tool.uv.index]]
+name = "pytorch-nightly-cu126"
+url = "https://download.pytorch.org/whl/nightly/cu126"
+explicit = false
+
+# [[tool.uv.index]]
+# name = "pytorch-nightly-cu124"
+# url = "https://download.pytorch.org/whl/nightly/cu124"
+# explicit = true
+
+# [[tool.uv.index]]
+# name = "pytorch-nightly-cu118"
+# url = "https://download.pytorch.org/whl/nightly/cu118"
+# explicit = true
+
+# [[tool.uv.index]]
+# name = "pytorch-test-cu124"
+# url = "https://download.pytorch.org/whl/test/cu124"
+# explicit = false
+
+# [[tool.uv.index]]
+# name = "pytorch-test-cu118"
+# url = "https://download.pytorch.org/whl/test/cu118"
+# explicit = false
+
+# [[tool.uv.index]]
+# name = "pytorch-release-cu124"
+# url = "https://download.pytorch.org/whl/cu124"
+# explicit = false
+
+# [[tool.uv.index]]
+# name = "pytorch-release-cu118"
+# url = "https://download.pytorch.org/whl/cu118"
+# explicit = false
+
 [tool.ruff]
 # NOTE: Synchronize the ignores with .flake8
diff --git a/setup.py b/setup.py
index bd490ec1be..46114996e6 100644
--- a/setup.py
+++ b/setup.py
@@ -440,7 +440,7 @@ def run(self):
         "torch_tensorrt.dynamo.conversion.impl.normalization",
         "torch_tensorrt.dynamo.conversion.impl.slice",
         "torch_tensorrt.dynamo.conversion.impl.unary",
-        "torch_tensorrt.dynamo.conversion.plugin",
+        "torch_tensorrt.dynamo.conversion.plugins",
"torch_tensorrt.dynamo.lowering", "torch_tensorrt.dynamo.lowering.passes", "torch_tensorrt.dynamo.partitioning", @@ -469,7 +469,7 @@ def run(self): "torch_tensorrt.dynamo.conversion.impl.normalization": "py/torch_tensorrt/dynamo/conversion/impl/normalization", "torch_tensorrt.dynamo.conversion.impl.slice": "py/torch_tensorrt/dynamo/conversion/impl/slice", "torch_tensorrt.dynamo.conversion.impl.unary": "py/torch_tensorrt/dynamo/conversion/impl/unary", - "torch_tensorrt.dynamo.conversion.plugin": "py/torch_tensorrt/dynamo/conversion/plugin", + "torch_tensorrt.dynamo.conversion.plugins": "py/torch_tensorrt/dynamo/conversion/plugins", "torch_tensorrt.dynamo.lowering": "py/torch_tensorrt/dynamo/lowering", "torch_tensorrt.dynamo.lowering.passes": "py/torch_tensorrt/dynamo/lowering/passes", "torch_tensorrt.dynamo.partitioning": "py/torch_tensorrt/dynamo/partitioning", diff --git a/uv.lock b/uv.lock index 493873c773..13a94bdb23 100644 --- a/uv.lock +++ b/uv.lock @@ -1,9 +1,11 @@ version = 1 requires-python = ">=3.9" resolution-markers = [ - "python_full_version < '3.13' and sys_platform == 'linux'", + "python_full_version < '3.12' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and sys_platform == 'linux'", "python_full_version >= '3.13' and sys_platform == 'linux'", - "python_full_version < '3.13' and sys_platform == 'windows'", + "python_full_version < '3.12' and sys_platform == 'windows'", + "python_full_version == '3.12.*' and sys_platform == 'windows'", "python_full_version >= '3.13' and sys_platform == 'windows'", ] supported-markers = [ @@ -75,10 +77,9 @@ wheels = [ [[package]] name = "certifi" version = "2024.8.30" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/b0/ee/9b19140fe824b367c04c5e1b369942dd754c4c5462d5674002f75c4dedc1/certifi-2024.8.30.tar.gz", hash = "sha256:bec941d2aa8195e248a60b31ff9f0558284cf01a52591ceda73ea9afffd69fd9", size = 168507 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/90/3c9ff0512038035f59d279fddeb79f5f1eccd8859f06d6163c58798b9487/certifi-2024.8.30-py3-none-any.whl", hash = "sha256:922820b53db7a7257ffbda3f597266d435245903d80737e34f8a45ff3e3230d8", size = 167321 }, + { url = "https://download.pytorch.org/whl/nightly/certifi-2024.8.30-py3-none-any.whl" }, ] [[package]] @@ -93,7 +94,7 @@ wheels = [ [[package]] name = "charset-normalizer" version = "3.3.2" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, @@ -155,7 +156,7 @@ wheels = [ [[package]] name = "colorama" version = "0.4.6" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/colorama-0.4.6-py2.py3-none-any.whl" }, ] @@ -248,25 +249,23 @@ wheels = [ [[package]] name = "filelock" version = "3.16.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/9d/db/3ef5bb276dae18d6ec2124224403d1d67bccdbefc17af4cc8f553e341ab1/filelock-3.16.1.tar.gz", hash = 
"sha256:c249fbfcd5db47e5e2d6d62198e565475ee65e4831e2561c8e313fa7eb961435", size = 18037 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b9/f8/feced7779d755758a52d1f6635d990b8d98dc0a29fa568bbe0625f18fdf3/filelock-3.16.1-py3-none-any.whl", hash = "sha256:2082e5703d51fbf98ea75855d9d5527e33d8ff23099bec374a134febee6946b0", size = 16163 }, + { url = "https://download.pytorch.org/whl/nightly/filelock-3.16.1-py3-none-any.whl" }, ] [[package]] name = "fsspec" -version = "2024.9.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/62/7c/12b0943011daaaa9c35c2a2e22e5eb929ac90002f08f1259d69aedad84de/fsspec-2024.9.0.tar.gz", hash = "sha256:4b0afb90c2f21832df142f292649035d80b421f60a9e1c027802e5a0da2b04e8", size = 286206 } +version = "2024.10.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1d/a0/6aaea0c2fbea2f89bfd5db25fb1e3481896a423002ebe4e55288907a97a3/fsspec-2024.9.0-py3-none-any.whl", hash = "sha256:a0947d552d8a6efa72cc2c730b12c41d043509156966cca4fb157b0f2a0c574b", size = 179253 }, + { url = "https://download.pytorch.org/whl/nightly/fsspec-2024.10.0-py3-none-any.whl" }, ] [[package]] name = "huggingface-hub" -version = "0.25.0" -source = { registry = "https://pypi.org/simple" } +version = "0.25.1" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, @@ -276,9 +275,8 @@ dependencies = [ { name = "tqdm", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/81/18/52812091169325bf609feac958d5612e9c49788155a170352c7d55b6a74c/huggingface_hub-0.25.0.tar.gz", hash = "sha256:fb5fbe6c12fcd99d187ec7db95db9110fb1a20505f23040a5449a717c1a0db4d", size = 365666 } wheels = [ - { url = "https://files.pythonhosted.org/packages/d5/ce/1f8e61cd63175cc2e79233b954b1c4e85363c788fb3a1fa23c87a25c9b81/huggingface_hub-0.25.0-py3-none-any.whl", hash = "sha256:e2f357b35d72d5012cfd127108c4e14abcd61ba4ebc90a5a374dc2456cb34e12", size = 436429 }, + { url = "https://download.pytorch.org/whl/nightly/huggingface_hub-0.25.1-py3-none-any.whl" }, ] [[package]] @@ -293,22 +291,20 @@ wheels = [ [[package]] name = "idna" version = "3.10" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f1/70/7703c29685631f5a7590aa73f1f1d3fa9a380e654b86af429e0934a32f7d/idna-3.10.tar.gz", hash = "sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9", size = 190490 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/76/c6/c88e154df9c4e1a2a66ccf0005a88dfb2650c1dffb6f5ce603dfbd452ce3/idna-3.10-py3-none-any.whl", hash = "sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3", size = 70442 }, + { url = "https://download.pytorch.org/whl/nightly/idna-3.10-py3-none-any.whl" }, ] [[package]] name = "importlib-metadata" -version = "8.5.0" -source = { registry = "https://pypi.org/simple" } +version = "7.1.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } 
dependencies = [ { name = "zipp", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/cd/12/33e59336dca5be0c398a7482335911a33aa0e20776128f038019f1a95f1b/importlib_metadata-8.5.0.tar.gz", hash = "sha256:71522656f0abace1d072b9e5481a48f07c138e00f079c38c8f883823f9c26bd7", size = 55304 } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/d9/a1e041c5e7caa9a05c925f4bdbdfb7f006d1f74996af53467bc394c97be7/importlib_metadata-8.5.0-py3-none-any.whl", hash = "sha256:45e54197d28b7a7f1559e60b95e7c567032b602131fbd588f1497f47880aa68b", size = 26514 }, + { url = "https://download.pytorch.org/whl/nightly/importlib_metadata-7.1.0-py3-none-any.whl" }, ] [[package]] @@ -381,12 +377,13 @@ wheels = [ [[package]] name = "jinja2" version = "3.1.4" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "markupsafe", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ { url = "https://download.pytorch.org/whl/nightly/Jinja2-3.1.4-py3-none-any.whl" }, + { url = "https://download.pytorch.org/whl/nightly/jinja2-3.1.4-py3-none-any.whl" }, ] [[package]] @@ -413,10 +410,10 @@ wheels = [ [[package]] name = "markupsafe" version = "2.1.5" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } sdist = { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5.tar.gz" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61659ba32cf2cf1481e575d0462554625196a1f2fc06a1c777d3f48e8865d46" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, { url = "https://download.pytorch.org/whl/nightly/MarkupSafe-2.1.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, @@ -454,7 +451,7 @@ wheels = [ [[package]] name = "mpmath" version = "1.3.0" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/mpmath-1.3.0-py3-none-any.whl" }, ] @@ -484,7 +481,7 @@ wheels = [ [[package]] name = "mypy-extensions" version = "1.0.0" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/mypy_extensions-1.0.0-py3-none-any.whl" }, ] @@ -492,7 +489,7 @@ wheels = [ [[package]] name = "networkx" version = "3.3" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/networkx-3.3-py3-none-any.whl" }, ] @@ -527,7 +524,7 @@ wheels = [ [[package]] name = "numpy" version = "1.26.4" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = 
"https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, { url = "https://download.pytorch.org/whl/nightly/numpy-1.26.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, @@ -541,93 +538,99 @@ wheels = [ [[package]] name = "nvidia-cublas-cu12" -version = "12.4.5.8" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.6.3.3" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cublas_cu12-12.6.3.3-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-cuda-cupti-cu12" -version = "12.4.127" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.6.80" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cuda_cupti_cu12-12.6.80-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-cuda-nvrtc-cu12" -version = "12.4.127" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.6.77" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cuda_nvrtc_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-cuda-runtime-cu12" -version = "12.4.127" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.6.77" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cuda_runtime_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-cudnn-cu12" -version = "9.1.0.70" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "9.5.1.17" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_aarch64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cu126/nvidia_cudnn_cu12-9.5.1.17-py3-none-manylinux_2_28_x86_64.whl" }, ] [[package]] name = "nvidia-cufft-cu12" -version = "11.2.1.3" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "11.3.0.4" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +dependencies = [ + { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, +] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cufft_cu12-11.3.0.4-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-curand-cu12" -version = "10.3.5.147" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "10.3.7.77" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_curand_cu12-10.3.7.77-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-cusolver-cu12" -version = "11.6.1.9" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "11.7.1.2" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "nvidia-cublas-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "nvidia-cusparse-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cusolver_cu12-11.7.1.2-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-cusparse-cu12" -version = "12.3.1.170" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.5.4.2" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "nvidia-nvjitlink-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cusparse_cu12-12.5.4.2-py3-none-manylinux2014_x86_64.whl" }, +] + +[[package]] +name = "nvidia-cusparselt-cu12" +version = "0.6.3" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_aarch64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cu126/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64%20(1).whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_cusparselt_cu12-0.6.3-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] @@ -675,36 +678,33 @@ torch = [ [[package]] name = "nvidia-nccl-cu12" version = "2.21.5" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-nvjitlink-cu12" -version = "12.4.127" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.6.77" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_nvjitlink_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "nvidia-nvtx-cu12" -version = "12.4.127" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = "12.6.77" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/nvidia_nvtx_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/nvidia_nvtx_cu12-12.6.77-py3-none-manylinux2014_x86_64.whl" }, ] [[package]] name = "packaging" version = "24.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/51/65/50db4dda066951078f0a96cf12f4b9ada6e4b811516bf0262c0f4f7064d4/packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002", size = 148788 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/08/aa/cc0199a5f0ad350994d660967a8efb233fe0416e4639146c089643407ce6/packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124", size = 53985 }, + { url = "https://download.pytorch.org/whl/nightly/packaging-24.1-py3-none-any.whl" }, ] [[package]] @@ -748,48 +748,31 @@ wheels = [ [[package]] name = "pillow" -version = "10.4.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cd/74/ad3d526f3bf7b6d3f408b73fde271ec69dfac8b81341a318ce825f2b3812/pillow-10.4.0.tar.gz", hash = "sha256:166c1cd4d24309b30d61f79f4a9114b7b2313d7450912277855ff5dfd7cd4a06", size = 46555059 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8a/25/1fc45761955f9359b1169aa75e241551e74ac01a09f487adaaf4c3472d11/pillow-10.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7928ecbf1ece13956b95d9cbcfc77137652b02763ba384d9ab508099a2eca856", size = 4332075 }, - { url = 
"https://files.pythonhosted.org/packages/5e/dd/425b95d0151e1d6c951f45051112394f130df3da67363b6bc75dc4c27aba/pillow-10.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e4d49b85c4348ea0b31ea63bc75a9f3857869174e2bf17e7aba02945cd218e6f", size = 4444808 }, - { url = "https://files.pythonhosted.org/packages/b1/84/9a15cc5726cbbfe7f9f90bfb11f5d028586595907cd093815ca6644932e3/pillow-10.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:6c762a5b0997f5659a5ef2266abc1d8851ad7749ad9a6a5506eb23d314e4f46b", size = 4356290 }, - { url = "https://files.pythonhosted.org/packages/b5/5b/6651c288b08df3b8c1e2f8c1152201e0b25d240e22ddade0f1e242fc9fa0/pillow-10.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:a985e028fc183bf12a77a8bbf36318db4238a3ded7fa9df1b9a133f1cb79f8fc", size = 4525163 }, - { url = "https://files.pythonhosted.org/packages/07/8b/34854bf11a83c248505c8cb0fcf8d3d0b459a2246c8809b967963b6b12ae/pillow-10.4.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:812f7342b0eee081eaec84d91423d1b4650bb9828eb53d8511bcef8ce5aecf1e", size = 4463100 }, - { url = "https://files.pythonhosted.org/packages/78/63/0632aee4e82476d9cbe5200c0cdf9ba41ee04ed77887432845264d81116d/pillow-10.4.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:ac1452d2fbe4978c2eec89fb5a23b8387aba707ac72810d9490118817d9c0b46", size = 4592880 }, - { url = "https://files.pythonhosted.org/packages/73/d5/c4011a76f4207a3c151134cd22a1415741e42fa5ddecec7c0182887deb3d/pillow-10.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dc6761a6efc781e6a1544206f22c80c3af4c8cf461206d46a1e6006e4429ff3", size = 4340304 }, - { url = "https://files.pythonhosted.org/packages/ac/10/c67e20445a707f7a610699bba4fe050583b688d8cd2d202572b257f46600/pillow-10.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5e84b6cc6a4a3d76c153a6b19270b3526a5a8ed6b09501d3af891daa2a9de7d6", size = 4452804 }, - { url = "https://files.pythonhosted.org/packages/a9/83/6523837906d1da2b269dee787e31df3b0acb12e3d08f024965a3e7f64665/pillow-10.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:bbc527b519bd3aa9d7f429d152fea69f9ad37c95f0b02aebddff592688998abe", size = 4365126 }, - { url = "https://files.pythonhosted.org/packages/ba/e5/8c68ff608a4203085158cff5cc2a3c534ec384536d9438c405ed6370d080/pillow-10.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:76a911dfe51a36041f2e756b00f96ed84677cdeb75d25c767f296c1c1eda1319", size = 4533541 }, - { url = "https://files.pythonhosted.org/packages/f4/7c/01b8dbdca5bc6785573f4cee96e2358b0918b7b2c7b60d8b6f3abf87a070/pillow-10.4.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:59291fb29317122398786c2d44427bbd1a6d7ff54017075b22be9d21aa59bd8d", size = 4471616 }, - { url = "https://files.pythonhosted.org/packages/c8/57/2899b82394a35a0fbfd352e290945440e3b3785655a03365c0ca8279f351/pillow-10.4.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:416d3a5d0e8cfe4f27f574362435bc9bae57f679a7158e0096ad2beb427b8696", size = 4600802 }, - { url = "https://files.pythonhosted.org/packages/84/48/6e394b86369a4eb68b8a1382c78dc092245af517385c086c5094e3b34428/pillow-10.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:29dbdc4207642ea6aad70fbde1a9338753d33fb23ed6956e706936706f52dd80", size = 4343799 }, - { url = "https://files.pythonhosted.org/packages/3b/f3/a8c6c11fa84b59b9df0cd5694492da8c039a24cd159f0f6918690105c3be/pillow-10.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:bf2342ac639c4cf38799a44950bbc2dfcb685f052b9e262f446482afaf4bffca", size = 4459973 }, - { url = "https://files.pythonhosted.org/packages/7d/1b/c14b4197b80150fb64453585247e6fb2e1d93761fa0fa9cf63b102fde822/pillow-10.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:f5b92f4d70791b4a67157321c4e8225d60b119c5cc9aee8ecf153aace4aad4ef", size = 4370054 }, - { url = "https://files.pythonhosted.org/packages/55/77/40daddf677897a923d5d33329acd52a2144d54a9644f2a5422c028c6bf2d/pillow-10.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:86dcb5a1eb778d8b25659d5e4341269e8590ad6b4e8b44d9f4b07f8d136c414a", size = 4539484 }, - { url = "https://files.pythonhosted.org/packages/40/54/90de3e4256b1207300fb2b1d7168dd912a2fb4b2401e439ba23c2b2cabde/pillow-10.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:780c072c2e11c9b2c7ca37f9a2ee8ba66f44367ac3e5c7832afcfe5104fd6d1b", size = 4477375 }, - { url = "https://files.pythonhosted.org/packages/13/24/1bfba52f44193860918ff7c93d03d95e3f8748ca1de3ceaf11157a14cf16/pillow-10.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:37fb69d905be665f68f28a8bba3c6d3223c8efe1edf14cc4cfa06c241f8c81d9", size = 4608773 }, - { url = "https://files.pythonhosted.org/packages/46/2b/99c28c4379a85e65378211971c0b430d9c7234b1ec4d59b2668f6299e011/pillow-10.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bee197b30783295d2eb680b311af15a20a8b24024a19c3a26431ff83eb8d1f70", size = 4339837 }, - { url = "https://files.pythonhosted.org/packages/f1/74/b1ec314f624c0c43711fdf0d8076f82d9d802afd58f1d62c2a86878e8615/pillow-10.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ef61f5dd14c300786318482456481463b9d6b91ebe5ef12f405afbba77ed0be", size = 4455562 }, - { url = "https://files.pythonhosted.org/packages/4a/2a/4b04157cb7b9c74372fa867096a1607e6fedad93a44deeff553ccd307868/pillow-10.4.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:297e388da6e248c98bc4a02e018966af0c5f92dfacf5a5ca22fa01cb3179bca0", size = 4366761 }, - { url = "https://files.pythonhosted.org/packages/ac/7b/8f1d815c1a6a268fe90481232c98dd0e5fa8c75e341a75f060037bd5ceae/pillow-10.4.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:e4db64794ccdf6cb83a59d73405f63adbe2a1887012e308828596100a0b2f6cc", size = 4536767 }, - { url = "https://files.pythonhosted.org/packages/e5/77/05fa64d1f45d12c22c314e7b97398ffb28ef2813a485465017b7978b3ce7/pillow-10.4.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bd2880a07482090a3bcb01f4265f1936a903d70bc740bfcb1fd4e8a2ffe5cf5a", size = 4477989 }, - { url = "https://files.pythonhosted.org/packages/12/63/b0397cfc2caae05c3fb2f4ed1b4fc4fc878f0243510a7a6034ca59726494/pillow-10.4.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:4b35b21b819ac1dbd1233317adeecd63495f6babf21b7b2512d244ff6c6ce309", size = 4610255 }, - { url = "https://files.pythonhosted.org/packages/60/a3/7ebbeabcd341eab722896d1a5b59a3df98c4b4d26cf4b0385f8aa94296f7/pillow-10.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:134ace6dc392116566980ee7436477d844520a26a4b1bd4053f6f47d096997fd", size = 4328295 }, - { url = "https://files.pythonhosted.org/packages/32/3f/c02268d0c6fb6b3958bdda673c17b315c821d97df29ae6969f20fb49388a/pillow-10.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:930044bb7679ab003b14023138b50181899da3f25de50e9dbee23b61b4de2126", size = 4440810 }, - { url = 
"https://files.pythonhosted.org/packages/67/5d/1c93c8cc35f2fdd3d6cc7e4ad72d203902859a2867de6ad957d9b708eb8d/pillow-10.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:c76e5786951e72ed3686e122d14c5d7012f16c8303a674d18cdcd6d89557fc5b", size = 4352283 }, - { url = "https://files.pythonhosted.org/packages/bc/a8/8655557c9c7202b8abbd001f61ff36711cefaf750debcaa1c24d154ef602/pillow-10.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:b2724fdb354a868ddf9a880cb84d102da914e99119211ef7ecbdc613b8c96b3c", size = 4521800 }, - { url = "https://files.pythonhosted.org/packages/58/78/6f95797af64d137124f68af1bdaa13b5332da282b86031f6fa70cf368261/pillow-10.4.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:dbc6ae66518ab3c5847659e9988c3b60dc94ffb48ef9168656e0019a93dbf8a1", size = 4459177 }, - { url = "https://files.pythonhosted.org/packages/8a/6d/2b3ce34f1c4266d79a78c9a51d1289a33c3c02833fe294ef0dcbb9cba4ed/pillow-10.4.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:06b2f7898047ae93fad74467ec3d28fe84f7831370e3c258afa533f81ef7f3df", size = 4589079 }, - { url = "https://files.pythonhosted.org/packages/d7/ac/4184edd511b14f760c73f5bb8a5d6fd85c591c8aff7c2229677a355c4179/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9f4727572e2918acaa9077c919cbbeb73bd2b3ebcfe033b72f858fc9fbef0026", size = 3435020 }, - { url = "https://files.pythonhosted.org/packages/da/21/1749cd09160149c0a246a81d646e05f35041619ce76f6493d6a96e8d1103/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ff25afb18123cea58a591ea0244b92eb1e61a1fd497bf6d6384f09bc3262ec3e", size = 3490539 }, - { url = "https://files.pythonhosted.org/packages/b6/f5/f71fe1888b96083b3f6dfa0709101f61fc9e972c0c8d04e9d93ccef2a045/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:dc3e2db6ba09ffd7d02ae9141cfa0ae23393ee7687248d46a7507b75d610f4f5", size = 3476125 }, - { url = "https://files.pythonhosted.org/packages/96/b9/c0362c54290a31866c3526848583a2f45a535aa9d725fd31e25d318c805f/pillow-10.4.0-pp310-pypy310_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02a2be69f9c9b8c1e97cf2713e789d4e398c751ecfd9967c18d0ce304efbf885", size = 3579373 }, - { url = "https://files.pythonhosted.org/packages/0a/22/492f9f61e4648422b6ca39268ec8139277a5b34648d28f400faac14e0f48/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b885f89040bb8c4a1573566bbb2f44f5c505ef6e74cec7ab9068c900047f04b", size = 3434958 }, - { url = "https://files.pythonhosted.org/packages/f9/19/559a48ad4045704bb0547965b9a9345f5cd461347d977a56d178db28819e/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87dd88ded2e6d74d31e1e0a99a726a6765cda32d00ba72dc37f0651f306daaa8", size = 3490340 }, - { url = "https://files.pythonhosted.org/packages/d9/de/cebaca6fb79905b3a1aa0281d238769df3fb2ede34fd7c0caa286575915a/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:2db98790afc70118bd0255c2eeb465e9767ecf1f3c25f9a1abb8ffc8cfd1fe0a", size = 3476048 }, - { url = "https://files.pythonhosted.org/packages/71/f0/86d5b2f04693b0116a01d75302b0a307800a90d6c351a8aa4f8ae76cd499/pillow-10.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:f7baece4ce06bade126fb84b8af1c33439a76d8a6fd818970215e0560ca28c27", size = 3579366 }, +version = "11.0.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = 
"https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp310-cp310-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp310-cp310-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp311-cp311-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp311-cp311-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp312-cp312-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp312-cp312-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp313-cp313-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp313-cp313-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp313-cp313t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp39-cp39-manylinux_2_28_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pillow-11.0.0-cp39-cp39-manylinux_2_28_x86_64.whl" }, ] [[package]] @@ -841,12 +824,11 @@ wheels = [ [[package]] name = "psutil" version = "6.0.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/18/c7/8c6872f7372eb6a6b2e4708b88419fb46b857f7a2e1892966b851cc79fc9/psutil-6.0.0.tar.gz", hash = "sha256:8faae4f310b6d969fa26ca0545338b21f73c6b15db7c4a8d934a5482faa818f2", size = 508067 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/35/56/72f86175e81c656a01c4401cd3b1c923f891b31fbcebe98985894176d7c9/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ed2440ada7ef7d0d608f20ad89a04ec47d2d3ab7190896cd62ca5fc4fe08bf0", size = 287478 }, - { url = 
"https://files.pythonhosted.org/packages/19/74/f59e7e0d392bc1070e9a70e2f9190d652487ac115bb16e2eff6b22ad1d24/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd9a97c8e94059b0ef54a7d4baf13b405011176c3b6ff257c247cae0d560ecd", size = 290455 }, - { url = "https://files.pythonhosted.org/packages/cd/5f/60038e277ff0a9cc8f0c9ea3d0c5eb6ee1d2470ea3f9389d776432888e47/psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e2e8d0054fc88153ca0544f5c4d554d42e33df2e009c4ff42284ac9ebdef4132", size = 292046 }, + { url = "https://download.pytorch.org/whl/nightly/psutil-6.0.0-cp36-abi3-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/psutil-6.0.0-cp36-abi3-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/psutil-6.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, ] [[package]] @@ -991,124 +973,94 @@ wheels = [ [[package]] name = "pytorch-triton" -version = "3.1.0+5fe38ffd73" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } -dependencies = [ - { name = "filelock", marker = "(python_full_version < '3.13' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform == 'windows')" }, -] -wheels = [ - { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp310-cp310-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp311-cp311-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp312-cp312-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.1.0%2B5fe38ffd73-cp39-cp39-linux_x86_64.whl" }, +version = "3.2.0+git35c6c7c6" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp310-cp310-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp311-cp311-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp312-cp312-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp313-cp313-linux_x86_64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp39-cp39-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/pytorch_triton-3.2.0%2Bgit35c6c7c6-cp39-cp39-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl" }, ] [[package]] name = "pyyaml" version = "6.0.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/54/ed/79a089b6be93607fa5cdaedf301d7dfb23af5f25c398d5ead2525b063e17/pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e", size = 130631 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/49/ee/14c54df452143b9ee9f0f29074d7ca5516a36edb0b4cc40c3f280131656f/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237", size = 718463 }, - { url = "https://files.pythonhosted.org/packages/4d/61/de363a97476e766574650d742205be468921a7b532aa2499fcd886b62530/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b", size = 719280 }, - { url = "https://files.pythonhosted.org/packages/6b/4e/1523cb902fd98355e2e9ea5e5eb237cbc5f3ad5f3075fa65087aa0ecb669/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed", size = 751239 }, - { url = "https://files.pythonhosted.org/packages/b7/33/5504b3a9a4464893c32f118a9cc045190a91637b119a9c881da1cf6b7a72/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180", size = 695802 }, - { url = "https://files.pythonhosted.org/packages/5c/20/8347dcabd41ef3a3cdc4f7b7a2aff3d06598c8779faa189cdbf878b626a4/PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68", size = 720527 }, - { url = "https://files.pythonhosted.org/packages/ad/0c/c804f5f922a9a6563bab712d8dcc70251e8af811fce4524d57c2c0fd49a4/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c", size = 736829 }, - { url = "https://files.pythonhosted.org/packages/51/16/6af8d6a6b210c8e54f1406a6b9481febf9c64a3109c541567e35a49aa2e7/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317", size = 764167 }, - { url = "https://files.pythonhosted.org/packages/75/e4/2c27590dfc9992f73aabbeb9241ae20220bd9452df27483b6e56d3975cc5/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85", size = 762952 }, - { url = "https://files.pythonhosted.org/packages/9b/97/ecc1abf4a823f5ac61941a9c00fe501b02ac3ab0e373c3857f7d4b83e2b6/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = 
"sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4", size = 735301 }, - { url = "https://files.pythonhosted.org/packages/45/73/0f49dacd6e82c9430e46f4a027baa4ca205e8b0a9dce1397f44edc23559d/PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e", size = 756638 }, - { url = "https://files.pythonhosted.org/packages/c3/93/9916574aa8c00aa06bbac729972eb1071d002b8e158bd0e83a3b9a20a1f7/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5", size = 739154 }, - { url = "https://files.pythonhosted.org/packages/95/0f/b8938f1cbd09739c6da569d172531567dbcc9789e0029aa070856f123984/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425", size = 766223 }, - { url = "https://files.pythonhosted.org/packages/b9/2b/614b4752f2e127db5cc206abc23a8c19678e92b23c3db30fc86ab731d3bd/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476", size = 767542 }, - { url = "https://files.pythonhosted.org/packages/d4/00/dd137d5bcc7efea1836d6264f049359861cf548469d18da90cd8216cf05f/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48", size = 731164 }, - { url = "https://files.pythonhosted.org/packages/c9/1f/4f998c900485e5c0ef43838363ba4a9723ac0ad73a9dc42068b12aaba4e4/PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b", size = 756611 }, - { url = "https://files.pythonhosted.org/packages/7c/9a/337322f27005c33bcb656c655fa78325b730324c78620e8328ae28b64d0c/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133", size = 733428 }, - { url = "https://files.pythonhosted.org/packages/a3/69/864fbe19e6c18ea3cc196cbe5d392175b4cf3d5d0ac1403ec3f2d237ebb5/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484", size = 763361 }, - { url = "https://files.pythonhosted.org/packages/04/24/b7721e4845c2f162d26f50521b825fb061bc0a5afcf9a386840f23ea19fa/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5", size = 759523 }, - { url = "https://files.pythonhosted.org/packages/2b/b2/e3234f59ba06559c6ff63c4e10baea10e5e7df868092bf9ab40e5b9c56b6/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc", size = 726660 }, - { url = "https://files.pythonhosted.org/packages/fe/0f/25911a9f080464c59fab9027482f822b86bf0608957a5fcc6eaac85aa515/PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652", size = 751597 }, - { url = "https://files.pythonhosted.org/packages/0e/9a/8cc68be846c972bda34f6c2a93abb644fb2476f4dcc924d52175786932c9/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290", size = 720891 }, - { url = 
"https://files.pythonhosted.org/packages/e9/6c/6e1b7f40181bc4805e2e07f4abc10a88ce4648e7e95ff1abe4ae4014a9b2/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12", size = 722614 }, - { url = "https://files.pythonhosted.org/packages/3d/32/e7bd8535d22ea2874cef6a81021ba019474ace0d13a4819c2a4bce79bd6a/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19", size = 737360 }, - { url = "https://files.pythonhosted.org/packages/d7/12/7322c1e30b9be969670b672573d45479edef72c9a0deac3bb2868f5d7469/PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e", size = 699006 }, - { url = "https://files.pythonhosted.org/packages/82/72/04fcad41ca56491995076630c3ec1e834be241664c0c09a64c9a2589b507/PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725", size = 723577 }, +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, ] [[package]] name = "regex" version = "2024.9.11" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/38/148df33b4dbca3bd069b963acab5e0fa1a9dbd6820f8c322d0dd6faeff96/regex-2024.9.11.tar.gz", hash = "sha256:6c188c307e8433bcb63dc1915022deb553b4203a70722fc542c363bf120a01fd", size = 399403 } -wheels = [ - { url = 
"https://files.pythonhosted.org/packages/88/87/1ce4a5357216b19b7055e7d3b0efc75a6e426133bf1e7d094321df514257/regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46989629904bad940bbec2106528140a218b4a36bb3042d8406980be1941429c", size = 783177 }, - { url = "https://files.pythonhosted.org/packages/3c/65/b9f002ab32f7b68e7d1dcabb67926f3f47325b8dbc22cc50b6a043e1d07c/regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a906ed5e47a0ce5f04b2c981af1c9acf9e8696066900bf03b9d7879a6f679fc8", size = 823193 }, - { url = "https://files.pythonhosted.org/packages/22/91/8339dd3abce101204d246e31bc26cdd7ec07c9f91598472459a3a902aa41/regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e9a091b0550b3b0207784a7d6d0f1a00d1d1c8a11699c1a4d93db3fbefc3ad35", size = 809950 }, - { url = "https://files.pythonhosted.org/packages/cb/19/556638aa11c2ec9968a1da998f07f27ec0abb9bf3c647d7c7985ca0b8eea/regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5ddcd9a179c0a6fa8add279a4444015acddcd7f232a49071ae57fa6e278f1f71", size = 782661 }, - { url = "https://files.pythonhosted.org/packages/d1/e9/7a5bc4c6ef8d9cd2bdd83a667888fc35320da96a4cc4da5fa084330f53db/regex-2024.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6b41e1adc61fa347662b09398e31ad446afadff932a24807d3ceb955ed865cc8", size = 772348 }, - { url = "https://files.pythonhosted.org/packages/f1/0b/29f2105bfac3ed08e704914c38e93b07c784a6655f8a015297ee7173e95b/regex-2024.9.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:ced479f601cd2f8ca1fd7b23925a7e0ad512a56d6e9476f79b8f381d9d37090a", size = 697460 }, - { url = "https://files.pythonhosted.org/packages/71/3a/52ff61054d15a4722605f5872ad03962b319a04c1ebaebe570b8b9b7dde1/regex-2024.9.11-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:635a1d96665f84b292e401c3d62775851aedc31d4f8784117b3c68c4fcd4118d", size = 769151 }, - { url = "https://files.pythonhosted.org/packages/97/07/37e460ab5ca84be8e1e197c3b526c5c86993dcc9e13cbc805c35fc2463c1/regex-2024.9.11-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:c0256beda696edcf7d97ef16b2a33a8e5a875affd6fa6567b54f7c577b30a137", size = 777478 }, - { url = "https://files.pythonhosted.org/packages/65/7b/953075723dd5ab00780043ac2f9de667306ff9e2a85332975e9f19279174/regex-2024.9.11-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:3ce4f1185db3fbde8ed8aa223fc9620f276c58de8b0d4f8cc86fd1360829edb6", size = 845373 }, - { url = "https://files.pythonhosted.org/packages/40/b8/3e9484c6230b8b6e8f816ab7c9a080e631124991a4ae2c27a81631777db0/regex-2024.9.11-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:09d77559e80dcc9d24570da3745ab859a9cf91953062e4ab126ba9d5993688ca", size = 845369 }, - { url = "https://files.pythonhosted.org/packages/b7/99/38434984d912edbd2e1969d116257e869578f67461bd7462b894c45ed874/regex-2024.9.11-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a22ccefd4db3f12b526eccb129390942fe874a3a9fdbdd24cf55773a1faab1a", size = 773935 }, - { url = "https://files.pythonhosted.org/packages/b1/51/91a5ebdff17f9ec4973cb0aa9d37635efec1c6868654bbc25d1543aca4ec/regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4cc92bb6db56ab0c1cbd17294e14f5e9224f0cc6521167ef388332604e92679", size = 791779 }, - { url = 
"https://files.pythonhosted.org/packages/07/4a/022c5e6f0891a90cd7eb3d664d6c58ce2aba48bff107b00013f3d6167069/regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d05ac6fa06959c4172eccd99a222e1fbf17b5670c4d596cb1e5cde99600674c4", size = 832605 }, - { url = "https://files.pythonhosted.org/packages/ac/1c/3793990c8c83ca04e018151ddda83b83ecc41d89964f0f17749f027fc44d/regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:040562757795eeea356394a7fb13076ad4f99d3c62ab0f8bdfb21f99a1f85664", size = 818556 }, - { url = "https://files.pythonhosted.org/packages/e9/5c/8b385afbfacb853730682c57be56225f9fe275c5bf02ac1fc88edbff316d/regex-2024.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6113c008a7780792efc80f9dfe10ba0cd043cbf8dc9a76ef757850f51b4edc50", size = 792808 }, - { url = "https://files.pythonhosted.org/packages/9b/8b/a4723a838b53c771e9240951adde6af58c829fb6a6a28f554e8131f53839/regex-2024.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8e5fb5f77c8745a60105403a774fe2c1759b71d3e7b4ca237a5e67ad066c7199", size = 781115 }, - { url = "https://files.pythonhosted.org/packages/83/5f/031a04b6017033d65b261259c09043c06f4ef2d4eac841d0649d76d69541/regex-2024.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:54d9ff35d4515debf14bc27f1e3b38bfc453eff3220f5bce159642fa762fe5d4", size = 778155 }, - { url = "https://files.pythonhosted.org/packages/fd/cd/4660756070b03ce4a66663a43f6c6e7ebc2266cc6b4c586c167917185eb4/regex-2024.9.11-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:df5cbb1fbc74a8305b6065d4ade43b993be03dbe0f8b30032cced0d7740994bd", size = 784614 }, - { url = "https://files.pythonhosted.org/packages/93/8d/65b9bea7df120a7be8337c415b6d256ba786cbc9107cebba3bf8ff09da99/regex-2024.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:7fb89ee5d106e4a7a51bce305ac4efb981536301895f7bdcf93ec92ae0d91c7f", size = 853744 }, - { url = "https://files.pythonhosted.org/packages/96/a7/fba1eae75eb53a704475baf11bd44b3e6ccb95b316955027eb7748f24ef8/regex-2024.9.11-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:a738b937d512b30bf75995c0159c0ddf9eec0775c9d72ac0202076c72f24aa96", size = 855890 }, - { url = "https://files.pythonhosted.org/packages/45/14/d864b2db80a1a3358534392373e8a281d95b28c29c87d8548aed58813910/regex-2024.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:e28f9faeb14b6f23ac55bfbbfd3643f5c7c18ede093977f1df249f73fd22c7b1", size = 781887 }, - { url = "https://files.pythonhosted.org/packages/ca/fa/521eb683b916389b4975337873e66954e0f6d8f91bd5774164a57b503185/regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee439691d8c23e76f9802c42a95cfeebf9d47cf4ffd06f18489122dbb0a7ad64", size = 795181 }, - { url = "https://files.pythonhosted.org/packages/28/db/63047feddc3280cc242f9c74f7aeddc6ee662b1835f00046f57d5630c827/regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a8f877c89719d759e52783f7fe6e1c67121076b87b40542966c02de5503ace42", size = 835842 }, - { url = "https://files.pythonhosted.org/packages/e3/94/86adc259ff8ec26edf35fcca7e334566c1805c7493b192cb09679f9c3dee/regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:23b30c62d0f16827f2ae9f2bb87619bc4fba2044911e2e6c2eb1af0161cdb766", size = 823533 }, - { url = 
"https://files.pythonhosted.org/packages/29/52/84662b6636061277cb857f658518aa7db6672bc6d1a3f503ccd5aefc581e/regex-2024.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85ab7824093d8f10d44330fe1e6493f756f252d145323dd17ab6b48733ff6c0a", size = 797037 }, - { url = "https://files.pythonhosted.org/packages/c3/2a/cd4675dd987e4a7505f0364a958bc41f3b84942de9efaad0ef9a2646681c/regex-2024.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8dee5b4810a89447151999428fe096977346cf2f29f4d5e29609d2e19e0199c9", size = 784106 }, - { url = "https://files.pythonhosted.org/packages/6f/75/3ea7ec29de0bbf42f21f812f48781d41e627d57a634f3f23947c9a46e303/regex-2024.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98eeee2f2e63edae2181c886d7911ce502e1292794f4c5ee71e60e23e8d26b5d", size = 782468 }, - { url = "https://files.pythonhosted.org/packages/d3/67/15519d69b52c252b270e679cb578e22e0c02b8dd4e361f2b04efcc7f2335/regex-2024.9.11-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:57fdd2e0b2694ce6fc2e5ccf189789c3e2962916fb38779d3e3521ff8fe7a822", size = 790324 }, - { url = "https://files.pythonhosted.org/packages/9c/71/eff77d3fe7ba08ab0672920059ec30d63fa7e41aa0fb61c562726e9bd721/regex-2024.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:d552c78411f60b1fdaafd117a1fca2f02e562e309223b9d44b7de8be451ec5e0", size = 860214 }, - { url = "https://files.pythonhosted.org/packages/81/11/e1bdf84a72372e56f1ea4b833dd583b822a23138a616ace7ab57a0e11556/regex-2024.9.11-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:a0b2b80321c2ed3fcf0385ec9e51a12253c50f146fddb2abbb10f033fe3d049a", size = 859420 }, - { url = "https://files.pythonhosted.org/packages/ea/75/9753e9dcebfa7c3645563ef5c8a58f3a47e799c872165f37c55737dadd3e/regex-2024.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:18406efb2f5a0e57e3a5881cd9354c1512d3bb4f5c45d96d110a66114d84d23a", size = 787333 }, - { url = "https://files.pythonhosted.org/packages/b9/54/9fe8f9aec5007bbbbce28ba3d2e3eaca425f95387b7d1e84f0d137d25237/regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb1ae19e64c14c7ec1995f40bd932448713d3c73509e82d8cd7744dc00e29e86", size = 795337 }, - { url = "https://files.pythonhosted.org/packages/b2/e7/6b2f642c3cded271c4f16cc4daa7231be544d30fe2b168e0223724b49a61/regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f47cd43a5bfa48f86925fe26fbdd0a488ff15b62468abb5d2a1e092a4fb10e85", size = 835848 }, - { url = "https://files.pythonhosted.org/packages/cd/9e/187363bdf5d8c0e4662117b92aa32bf52f8f09620ae93abc7537d96d3311/regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9d4a76b96f398697fe01117093613166e6aa8195d63f1b4ec3f21ab637632963", size = 823503 }, - { url = "https://files.pythonhosted.org/packages/f8/10/601303b8ee93589f879664b0cfd3127949ff32b17f9b6c490fb201106c4d/regex-2024.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0ea51dcc0835eea2ea31d66456210a4e01a076d820e9039b04ae8d17ac11dee6", size = 797049 }, - { url = "https://files.pythonhosted.org/packages/ef/1c/ea200f61ce9f341763f2717ab4daebe4422d83e9fd4ac5e33435fd3a148d/regex-2024.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7aaa315101c6567a9a45d2839322c51c8d6e81f67683d529512f5bcfb99c802", size = 784144 }, - { url = 
"https://files.pythonhosted.org/packages/d8/5c/d2429be49ef3292def7688401d3deb11702c13dcaecdc71d2b407421275b/regex-2024.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c57d08ad67aba97af57a7263c2d9006d5c404d721c5f7542f077f109ec2a4a29", size = 782483 }, - { url = "https://files.pythonhosted.org/packages/12/d9/cbc30f2ff7164f3b26a7760f87c54bf8b2faed286f60efd80350a51c5b99/regex-2024.9.11-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8404bf61298bb6f8224bb9176c1424548ee1181130818fcd2cbffddc768bed8", size = 790320 }, - { url = "https://files.pythonhosted.org/packages/19/1d/43ed03a236313639da5a45e61bc553c8d41e925bcf29b0f8ecff0c2c3f25/regex-2024.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:dd4490a33eb909ef5078ab20f5f000087afa2a4daa27b4c072ccb3cb3050ad84", size = 860435 }, - { url = "https://files.pythonhosted.org/packages/34/4f/5d04da61c7c56e785058a46349f7285ae3ebc0726c6ea7c5c70600a52233/regex-2024.9.11-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:eee9130eaad130649fd73e5cd92f60e55708952260ede70da64de420cdcad554", size = 859571 }, - { url = "https://files.pythonhosted.org/packages/12/7f/8398c8155a3c70703a8e91c29532558186558e1aea44144b382faa2a6f7a/regex-2024.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6a2644a93da36c784e546de579ec1806bfd2763ef47babc1b03d765fe560c9f8", size = 787398 }, - { url = "https://files.pythonhosted.org/packages/b4/21/feaa5b0d3e5e3bad659cd7d640e6b76cc0719504dbd9bc8f67cfa21bde82/regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c69ada171c2d0e97a4b5aa78fbb835e0ffbb6b13fc5da968c09811346564f0d3", size = 782747 }, - { url = "https://files.pythonhosted.org/packages/bb/89/93516f0aa3e8a9366df2cf79bb0290abdc7dbe5dd27373d9bea0978b7ba6/regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:02087ea0a03b4af1ed6ebab2c54d7118127fee8d71b26398e8e4b05b78963199", size = 822700 }, - { url = "https://files.pythonhosted.org/packages/d5/e7/79c04ccb81cee2831d9d4499274919b9153c1741ce8b3421d69cb0032f1b/regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:69dee6a020693d12a3cf892aba4808fe168d2a4cef368eb9bf74f5398bfd4ee8", size = 809327 }, - { url = "https://files.pythonhosted.org/packages/01/e6/a7256c99c312b68f01cfd4f8eae6e770906fffb3832ecb66f35ca5b86b96/regex-2024.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297f54910247508e6e5cae669f2bc308985c60540a4edd1c77203ef19bfa63ca", size = 781970 }, - { url = "https://files.pythonhosted.org/packages/18/c4/29e8b6ff2208775858b5d4a2caa6428d40b5fade95aee426de7e42ffff39/regex-2024.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ecea58b43a67b1b79805f1a0255730edaf5191ecef84dbc4cc85eb30bc8b63b9", size = 771885 }, - { url = "https://files.pythonhosted.org/packages/95/78/7acd8882ac335f1f5ae1756417739fda3053e0bcacea8716ae4a04e74553/regex-2024.9.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:eab4bb380f15e189d1313195b062a6aa908f5bd687a0ceccd47c8211e9cf0d4a", size = 696978 }, - { url = "https://files.pythonhosted.org/packages/cb/d2/1d44f9b4a3d33ff5773fd79bea53e992d00f81e0af6f1f4e2efac1e4d897/regex-2024.9.11-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0cbff728659ce4bbf4c30b2a1be040faafaa9eca6ecde40aaff86f7889f4ab39", size = 768655 }, - { url = 
"https://files.pythonhosted.org/packages/79/ba/92ef9d3b8f59cb3df9febef07098dfb4a43c3bdcf35b1084c2009b0a93bf/regex-2024.9.11-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:54c4a097b8bc5bb0dfc83ae498061d53ad7b5762e00f4adaa23bee22b012e6ba", size = 776922 }, - { url = "https://files.pythonhosted.org/packages/16/71/d964c0c9d447f04bbe6ab5eafd220208e7d52b9608e452e6fcad553b38e0/regex-2024.9.11-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:73d6d2f64f4d894c96626a75578b0bf7d9e56dcda8c3d037a2118fdfe9b1c664", size = 845014 }, - { url = "https://files.pythonhosted.org/packages/83/cb/a378cdc2468782eefefa50183bbeabc3357fb588d4109d845f0a56e68713/regex-2024.9.11-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:e53b5fbab5d675aec9f0c501274c467c0f9a5d23696cfc94247e1fb56501ed89", size = 844916 }, - { url = "https://files.pythonhosted.org/packages/b9/f0/82ea1565a6639270cfe96263002b3d91084a1db5048d9b6084f83bd5972d/regex-2024.9.11-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:0ffbcf9221e04502fc35e54d1ce9567541979c3fdfb93d2c554f0ca583a19b35", size = 773409 }, +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/regex-2024.9.11-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl" }, ] [[package]] name = "requests" version = "2.32.3" -source = { registry = "https://pypi.org/simple" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "certifi", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "charset-normalizer", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "idna", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "urllib3", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/63/70/2bf7780ad2d390a8d301ad0b550f1581eadbd9a20f896afe06353c2a2913/requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760", size = 131218 } wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/9b/335f9764261e915ed497fcdeb11df5dfd6f7bf257d4a6a2a686d80da4d54/requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6", size = 64928 }, + { url = "https://download.pytorch.org/whl/nightly/requests-2.32.3-py3-none-any.whl" }, ] [[package]] @@ -1152,59 +1104,38 @@ wheels = [ [[package]] name = "safetensors" version = "0.4.5" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/cb/46/a1c56ed856c6ac3b1a8b37abe5be0cac53219367af1331e721b04d122577/safetensors-0.4.5.tar.gz", hash = "sha256:d73de19682deabb02524b3d5d1f8b3aaba94c72f1bbfc7911b9b9d5d391c0310", size = 65702 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/8b/67/49556aeacc00df353767ed31d68b492fecf38c3f664c52692e4d92aa0032/safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6885016f34bef80ea1085b7e99b3c1f92cb1be78a49839203060f67b40aee761", size = 441382 }, - { url = "https://files.pythonhosted.org/packages/5d/ce/e9f4869a37bb11229e6cdb4e73a6ef23b4f360eee9dca5f7e40982779704/safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:133620f443450429322f238fda74d512c4008621227fccf2f8cf4a76206fea7c", size = 439001 }, - { url 
= "https://files.pythonhosted.org/packages/a0/27/aee8cf031b89c34caf83194ec6b7f2eed28d053fff8b6da6d00c85c56035/safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4fb3e0609ec12d2a77e882f07cced530b8262027f64b75d399f1504ffec0ba56", size = 478026 }, - { url = "https://files.pythonhosted.org/packages/da/33/1d9fc4805c623636e7d460f28eec92ebd1856f7a552df8eb78398a1ef4de/safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d0f1dd769f064adc33831f5e97ad07babbd728427f98e3e1db6902e369122737", size = 495545 }, - { url = "https://files.pythonhosted.org/packages/b9/df/6f766b56690709d22e83836e4067a1109a7d84ea152a6deb5692743a2805/safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c6d156bdb26732feada84f9388a9f135528c1ef5b05fae153da365ad4319c4c5", size = 435016 }, - { url = "https://files.pythonhosted.org/packages/90/fa/7bc3f18086201b1e55a42c88b822ae197d0158e12c54cd45c887305f1b7e/safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:9e347d77e2c77eb7624400ccd09bed69d35c0332f417ce8c048d404a096c593b", size = 456273 }, - { url = "https://files.pythonhosted.org/packages/3e/59/2ae50150d37a65c1c5f01aec74dc737707b8bbecdc76307e5a1a12c8a376/safetensors-0.4.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:9f556eea3aec1d3d955403159fe2123ddd68e880f83954ee9b4a3f2e15e716b6", size = 619669 }, - { url = "https://files.pythonhosted.org/packages/fe/43/10f0bb597aef62c9c154152e265057089f3c729bdd980e6c32c3ec2407a4/safetensors-0.4.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9483f42be3b6bc8ff77dd67302de8ae411c4db39f7224dec66b0eb95822e4163", size = 605212 }, - { url = "https://files.pythonhosted.org/packages/39/83/c4a7ce01d626e46ea2b45887f2e59b16441408031e2ce2f9fe01860c6946/safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:09dedf7c2fda934ee68143202acff6e9e8eb0ddeeb4cfc24182bef999efa9f42", size = 441093 }, - { url = "https://files.pythonhosted.org/packages/47/26/cc52de647e71bd9a0b0d78ead0d31d9c462b35550a817aa9e0cab51d6db4/safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:59b77e4b7a708988d84f26de3ebead61ef1659c73dcbc9946c18f3b1786d2688", size = 438960 }, - { url = "https://files.pythonhosted.org/packages/06/78/332538546775ee97e749867df2d58f2282d9c48a1681e4891eed8b94ec94/safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5d3bc83e14d67adc2e9387e511097f254bd1b43c3020440e708858c684cbac68", size = 478031 }, - { url = "https://files.pythonhosted.org/packages/d9/03/a3c8663f1ddda54e624ecf43fce651659b49e8e1603c52c3e464b442acfa/safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:39371fc551c1072976073ab258c3119395294cf49cdc1f8476794627de3130df", size = 494754 }, - { url = "https://files.pythonhosted.org/packages/e6/ee/69e498a892f208bd1da4104d4b9be887f8611bf4942144718b6738482250/safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a6c19feda32b931cae0acd42748a670bdf56bee6476a046af20181ad3fee4090", size = 435013 }, - { url = "https://files.pythonhosted.org/packages/a2/61/f0cfce984515b86d1260f556ba3b782158e2855e6a318446ac2613786fa9/safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a659467495de201e2f282063808a41170448c78bada1e62707b07a27b05e6943", size = 455984 }, - { url = 
"https://files.pythonhosted.org/packages/e7/a9/3e3b48fcaade3eb4e347d39ebf0bd44291db21a3e4507854b42a7cb910ac/safetensors-0.4.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:bad5e4b2476949bcd638a89f71b6916fa9a5cae5c1ae7eede337aca2100435c0", size = 619513 }, - { url = "https://files.pythonhosted.org/packages/80/23/2a7a1be24258c0e44c1d356896fd63dc0545a98d2d0184925fa09cd3ec76/safetensors-0.4.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a3a315a6d0054bc6889a17f5668a73f94f7fe55121ff59e0a199e3519c08565f", size = 604841 }, - { url = "https://files.pythonhosted.org/packages/d6/6c/7e04b7626809fc63f3698f4c50e43aff2864b40089aa4506c918a75b8eed/safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1524b54246e422ad6fb6aea1ac71edeeb77666efa67230e1faf6999df9b2e27f", size = 441134 }, - { url = "https://files.pythonhosted.org/packages/58/2b/ffe7c86a277e6c1595fbdf415cfe2903f253f574a5405e93fda8baaa582c/safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b3139098e3e8b2ad7afbca96d30ad29157b50c90861084e69fcb80dec7430461", size = 438467 }, - { url = "https://files.pythonhosted.org/packages/67/9c/f271bd804e08c7fda954d17b70ff281228a88077337a9e70feace4f4cc93/safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65573dc35be9059770808e276b017256fa30058802c29e1038eb1c00028502ea", size = 476566 }, - { url = "https://files.pythonhosted.org/packages/4c/ad/4cf76a3e430a8a26108407fa6cb93e6f80d996a5cb75d9540c8fe3862990/safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fd33da8e9407559f8779c82a0448e2133737f922d71f884da27184549416bfed", size = 492253 }, - { url = "https://files.pythonhosted.org/packages/d9/40/a6f75ea449a9647423ec8b6f72c16998d35aa4b43cb38536ac060c5c7bf5/safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3685ce7ed036f916316b567152482b7e959dc754fcc4a8342333d222e05f407c", size = 434769 }, - { url = "https://files.pythonhosted.org/packages/52/47/d4b49b1231abf3131f7bb0bc60ebb94b27ee33e0a1f9569da05f8ac65dee/safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:dde2bf390d25f67908278d6f5d59e46211ef98e44108727084d4637ee70ab4f1", size = 457166 }, - { url = "https://files.pythonhosted.org/packages/c3/cd/006468b03b0fa42ff82d795d47c4193e99001e96c3f08bd62ef1b5cab586/safetensors-0.4.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7469d70d3de970b1698d47c11ebbf296a308702cbaae7fcb993944751cf985f4", size = 619280 }, - { url = "https://files.pythonhosted.org/packages/22/4d/b6208d918e83daa84b424c0ac3191ae61b44b3191613a3a5a7b38f94b8ad/safetensors-0.4.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3a6ba28118636a130ccbb968bc33d4684c48678695dba2590169d5ab03a45646", size = 605390 }, - { url = "https://files.pythonhosted.org/packages/a4/c7/4fda8a0ebb96662550433378f4a74c677fa5fc4d0a43a7ec287d1df254a9/safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:585f1703a518b437f5103aa9cf70e9bd437cb78eea9c51024329e4fb8a3e3679", size = 441378 }, - { url = "https://files.pythonhosted.org/packages/14/31/9abb431f6209de9c80dab83e1112ebd769f1e32e7ab7ab228a02424a4693/safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:4b99fbf72e3faf0b2f5f16e5e3458b93b7d0a83984fe8d5364c60aa169f2da89", size = 438831 }, - { url = 
"https://files.pythonhosted.org/packages/37/37/99bfb195578a808b8d045159ee9264f8da58d017ac0701853dcacda14d4e/safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b17b299ca9966ca983ecda1c0791a3f07f9ca6ab5ded8ef3d283fff45f6bcd5f", size = 477112 }, - { url = "https://files.pythonhosted.org/packages/7d/05/fac3ef107e60d2a78532bed171a91669d4bb259e1236f5ea8c67a6976c75/safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:76ded72f69209c9780fdb23ea89e56d35c54ae6abcdec67ccb22af8e696e449a", size = 493373 }, - { url = "https://files.pythonhosted.org/packages/cf/7a/825800ee8c68214b4fd3506d5e19209338c69b41e01c6e14dd13969cc8b9/safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2783956926303dcfeb1de91a4d1204cd4089ab441e622e7caee0642281109db3", size = 435422 }, - { url = "https://files.pythonhosted.org/packages/5e/6c/7a3233c08bde558d6c33a41219119866cb596139a4673cc6c24024710ffd/safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d94581aab8c6b204def4d7320f07534d6ee34cd4855688004a4354e63b639a35", size = 457382 }, - { url = "https://files.pythonhosted.org/packages/a0/58/0b7bcba3788ff503990cf9278d611b56c029400612ba93e772c987b5aa03/safetensors-0.4.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:67e1e7cb8678bb1b37ac48ec0df04faf689e2f4e9e81e566b5c63d9f23748523", size = 619301 }, - { url = "https://files.pythonhosted.org/packages/82/cc/9c2cf58611daf1c83ce5d37f9de66353e23fcda36008b13fd3409a760aa3/safetensors-0.4.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:dbd280b07e6054ea68b0cb4b16ad9703e7d63cd6890f577cb98acc5354780142", size = 605580 }, - { url = "https://files.pythonhosted.org/packages/08/94/7760694760f1e5001bd62c93155b8b7ccb652d1f4d0161d1e72b5bf9581a/safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:139fbee92570ecea774e6344fee908907db79646d00b12c535f66bc78bd5ea2c", size = 442391 }, - { url = "https://files.pythonhosted.org/packages/03/1c/0db6e6e5cb293907b2242447b48cc09f31478aa02f08773155c2a2db22de/safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c36302c1c69eebb383775a89645a32b9d266878fab619819ce660309d6176c9b", size = 440015 }, - { url = "https://files.pythonhosted.org/packages/15/58/9658bf7ca3a4e77577fbd2c7afda4701c558db66b01daf7cd4d9dbd9781e/safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d641f5b8149ea98deb5ffcf604d764aad1de38a8285f86771ce1abf8e74c4891", size = 478099 }, - { url = "https://files.pythonhosted.org/packages/9e/fa/44d9723a988dd54f43a5fcfa6b4d3a721e9294bb55d1c3e539a88619f1b2/safetensors-0.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b4db6a61d968de73722b858038c616a1bebd4a86abe2688e46ca0cc2d17558f2", size = 497170 }, - { url = "https://files.pythonhosted.org/packages/5d/80/81ba44fc82afbf5ca553913ac49460e325dc5cf00c317b34c14d43ebd76b/safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b75a616e02f21b6f1d5785b20cecbab5e2bd3f6358a90e8925b813d557666ec1", size = 436076 }, - { url = "https://files.pythonhosted.org/packages/2e/ad/7880a359b0f93322689804bdbe1e9a3110652963478712933ff04a3d45c3/safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:788ee7d04cc0e0e7f944c52ff05f52a4415b312f5efd2ee66389fb7685ee030c", size = 456901 }, - { url = 
"https://files.pythonhosted.org/packages/89/4f/0b61e4add7ea9dfa8141d0bb1b8357e3a08730a020c3a287f0e889c386b5/safetensors-0.4.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:87bc42bd04fd9ca31396d3ca0433db0be1411b6b53ac5a32b7845a85d01ffc2e", size = 620159 }, - { url = "https://files.pythonhosted.org/packages/a9/60/544687daf8ce8dc9a74260992ac058d7e3f20c91eada5ca232898d005149/safetensors-0.4.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:4037676c86365a721a8c9510323a51861d703b399b78a6b4486a54a65a975fca", size = 605993 }, - { url = "https://files.pythonhosted.org/packages/ae/88/3068e1bb16f5e9f9068901de3cf7b3db270b9bfe6e7d51d4b55c1da0425d/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd8a1f6d2063a92cd04145c7fd9e31a1c7d85fbec20113a14b487563fdbc0597", size = 442311 }, - { url = "https://files.pythonhosted.org/packages/f7/15/a2bb77ebbaa76b61ec2e9f731fe4db7f9473fd855d881957c51b3a168892/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:951d2fcf1817f4fb0ef0b48f6696688a4e852a95922a042b3f96aaa67eedc920", size = 436678 }, - { url = "https://files.pythonhosted.org/packages/ec/79/9608c4546cdbfe3860dd7aa59e3562c9289113398b1a0bd89b68ce0a9d41/safetensors-0.4.5-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6ac85d9a8c1af0e3132371d9f2d134695a06a96993c2e2f0bbe25debb9e3f67a", size = 457316 }, - { url = "https://files.pythonhosted.org/packages/0f/23/b17b483f2857835962ad33e38014efd4911791187e177bc23b057d35bee8/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:e3cec4a29eb7fe8da0b1c7988bc3828183080439dd559f720414450de076fcab", size = 620565 }, - { url = "https://files.pythonhosted.org/packages/19/46/5d11dc300feaad285c2f1bd784ff3f689f5e0ab6be49aaf568f3a77019eb/safetensors-0.4.5-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:21742b391b859e67b26c0b2ac37f52c9c0944a879a25ad2f9f9f3cd61e7fda8f", size = 606660 }, - { url = "https://files.pythonhosted.org/packages/b3/ff/b26d78b6100a08e57a1986ab71a2f9f093ba9943626f4967cd514cd43de2/safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f0b6453c54c57c1781292c46593f8a37254b8b99004c68d6c3ce229688931a22", size = 442275 }, - { url = "https://files.pythonhosted.org/packages/71/29/6ac541358a07ec593ec9e88636908010bc9bf56c8018e0d25b4481adb64a/safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:adaa9c6dead67e2dd90d634f89131e43162012479d86e25618e821a03d1eb1dc", size = 437217 }, - { url = "https://files.pythonhosted.org/packages/2b/f8/258564b71fe95d0117356e6915b1c0128f1ec3031cf8522a28f9d2108b47/safetensors-0.4.5-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:73e7d408e9012cd17511b382b43547850969c7979efc2bc353f317abaf23c84c", size = 458132 }, - { url = "https://files.pythonhosted.org/packages/18/ac/510eebf3ac521fec3b0ea78e654e22d85de3406613209d20133b5b3cca33/safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:775409ce0fcc58b10773fdb4221ed1eb007de10fe7adbdf8f5e8a56096b6f0bc", size = 621171 }, - { url = "https://files.pythonhosted.org/packages/e0/c8/a02b635e39f3b904f52aff099505bdfbb40252d2d18a05e7fedc0bb64a28/safetensors-0.4.5-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:834001bed193e4440c4a3950a31059523ee5090605c907c66808664c932b549c", size = 607366 }, +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } +wheels = [ + { url = 
"https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl" }, 
+ { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/safetensors-0.4.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl" }, ] [[package]] @@ -1232,17 +1163,16 @@ wheels = [ [[package]] name = "setuptools" -version = "75.1.0" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/27/b8/f21073fde99492b33ca357876430822e4800cdf522011f18041351dfa74b/setuptools-75.1.0.tar.gz", hash = "sha256:d59a21b17a275fb872a9c3dae73963160ae079f1049ed956880cd7c09b120538", size = 1348057 } +version = "70.2.0" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ff/ae/f19306b5a221f6a436d8f2238d5b80925004093fa3edea59835b514d9057/setuptools-75.1.0-py3-none-any.whl", hash = "sha256:35ab7fd3bcd95e6b7fd704e4a1539513edad446c097797f2985e0e4b960772f2", size = 1248506 }, + { url = "https://download.pytorch.org/whl/nightly/setuptools-70.2.0-py3-none-any.whl" }, ] [[package]] name = "six" version = "1.16.0" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/six-1.16.0-py2.py3-none-any.whl" }, ] @@ -1264,7 +1194,7 @@ wheels = [ [[package]] name = "sympy" version = "1.13.1" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "mpmath", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] @@ -1274,29 +1204,33 @@ wheels = [ [[package]] name = "tensorrt-cu12" -version = "10.3.0" +version = "10.6.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/45/be/7147f7644b933fbba27fd02c4444bd17f3a68e3c34c8de31d31786e92d60/tensorrt-cu12-10.3.0.tar.gz", hash = "sha256:14f0e60f40713a658f9634fffb1a5a665c35feb019be48b2f49e25ac12d2d084", size = 18294 } +sdist = { url = "https://files.pythonhosted.org/packages/b8/ca/7fbe8f3f454b324e548b69697b3965b3a11e82bffaf6fb7901393fcdb2ff/tensorrt-cu12-10.6.0.tar.gz", hash = "sha256:ca9e611ff8a9424cf6431a5f649eacb35721f25335112d17dea248e9334b59a7", size = 18331 } [[package]] name = "tensorrt-cu12-bindings" -version = "10.3.0" +version = "10.6.0" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/68/eab45c46fdcafe08c6b21de4560fe2d3d845ce072d3e7743de4077c2d8c0/tensorrt_cu12_bindings-10.3.0-cp310-none-manylinux_2_17_x86_64.whl", hash = "sha256:1d6e4cf08ef1f54f6fd44a33cf6b253050af2fc6e9a1d92e40e1436a1d858eb0", size = 1108101 }, - { url = "https://files.pythonhosted.org/packages/4c/ce/47593af3fd15777ff48040da2901d539905c9bed3fc167d4368b0d4fcbf7/tensorrt_cu12_bindings-10.3.0-cp311-none-manylinux_2_17_x86_64.whl", hash = "sha256:59ace22d7f2ca1e9dcde2cb0cb5916912cb3cd5a9d72dd7852be0160d9b3a0ee", size = 1111069 }, - { url = "https://files.pythonhosted.org/packages/cd/1f/8215c8ff476bdc5f8d256413892ad48296df4277af077eefb9f7c0dcfeac/tensorrt_cu12_bindings-10.3.0-cp312-none-manylinux_2_17_x86_64.whl", hash = "sha256:f5c2582aeaa7f5628d2c4d4148a701ebe97be78f7ff3b46a617f0ee0cb5460f2", size = 1098829 }, - { url = 
"https://files.pythonhosted.org/packages/71/bf/32b901d844527fdfa5dbc7e57ac3ac10c48ce682254289f790a72faae162/tensorrt_cu12_bindings-10.3.0-cp39-none-manylinux_2_17_x86_64.whl", hash = "sha256:db337018c55043502eff993f165160044b4bebb935f01c8f8f93e4ee71481dc4", size = 1108759 }, + { url = "https://files.pythonhosted.org/packages/46/20/87ee19a4b5784e92703f208e55291f2fc0dc4d63c743f5e5c8cfc0eb34a5/tensorrt_cu12_bindings-10.6.0-cp310-none-manylinux_2_17_x86_64.whl", hash = "sha256:c804ecc791e7f7ec3bdc3eb6b18b271952981c6bba323e7642f99eb7975648a6", size = 1180107 }, + { url = "https://files.pythonhosted.org/packages/dd/65/166d719cfff5bf16247eb6702d7eca2bcc99a758dbdfb935013dbaa0060a/tensorrt_cu12_bindings-10.6.0-cp310-none-manylinux_2_31_aarch64.whl", hash = "sha256:39e6ee87a22c095b48478e62c8206646b84505e83d91e3b300cc551d164025d4", size = 1153044 }, + { url = "https://files.pythonhosted.org/packages/15/b9/c1a27ed050f1cf148eb3e5385204a984105a5310ded60f95bb098ca5e71f/tensorrt_cu12_bindings-10.6.0-cp311-none-manylinux_2_17_x86_64.whl", hash = "sha256:fd1f4b4b48ac9ab070f0e9fc012892a936e7651c9c43382304a8a081f640b723", size = 1181192 }, + { url = "https://files.pythonhosted.org/packages/d5/1d/7cffca7f0d35c79f0bc1462e13ef2b6c3bc5fcfbd774028299977e49d875/tensorrt_cu12_bindings-10.6.0-cp311-none-manylinux_2_31_aarch64.whl", hash = "sha256:14bcdc98d5f19e9b325563ce1ba2c74657270764e5591c75f9eb703ffa2f2d5d", size = 1151587 }, + { url = "https://files.pythonhosted.org/packages/d3/28/c70435e21daa52fa20b5a48fb8fc225a116c4441bdeea5b15edbade8ada2/tensorrt_cu12_bindings-10.6.0-cp312-none-manylinux_2_17_x86_64.whl", hash = "sha256:8a1bd5e7911f9875a3c3436cdb834e6aba7a94ab02c42e36883d025ce448835b", size = 1170744 }, + { url = "https://files.pythonhosted.org/packages/d0/03/3c2f981db3b4620bad291015c1ce04417227f99b0b9b23ffc6fe62ccdbf2/tensorrt_cu12_bindings-10.6.0-cp312-none-manylinux_2_31_aarch64.whl", hash = "sha256:5922c164c6448de8165cc709f0391b2c2e4b1fdddde3b80db06470c514268272", size = 1136506 }, + { url = "https://files.pythonhosted.org/packages/af/a1/41b188be56410b79f1dc74812871c30912fc39d7f8371c8789f17da68959/tensorrt_cu12_bindings-10.6.0-cp39-none-manylinux_2_17_x86_64.whl", hash = "sha256:d35bad461985a5b29b3ffb90ba3b006b6db03a948ed45397be769c077fb42556", size = 1180819 }, + { url = "https://files.pythonhosted.org/packages/9f/9f/fd47f33011f6575af2adcca37d8f3cdb1dac9a985c8fc245b6ad177bd432/tensorrt_cu12_bindings-10.6.0-cp39-none-manylinux_2_31_aarch64.whl", hash = "sha256:a115167cba6242c527a08ab2884cd1c467e3ba38aad3210912ef4d04cba82c68", size = 1153397 }, ] [[package]] name = "tensorrt-cu12-libs" -version = "10.3.0" +version = "10.6.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nvidia-cuda-runtime-cu12", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/80/10/91ccf3ba8edaf4d5f05dbe36aa533033714e0011422d57035a5491ba69c4/tensorrt_cu12_libs-10.3.0.tar.gz", hash = "sha256:d2f36838e2762b5ceb62f614157ba4764de2fa1f4fe5661c6cfc07e07e6e71da", size = 630 } +sdist = { url = "https://files.pythonhosted.org/packages/da/4e/8bb122944a9f7e42ff1d6abb526aeb603d5cfefa1f204bf405ae89d5b979/tensorrt_cu12_libs-10.6.0.tar.gz", hash = "sha256:c3c2192331da2867d094a3e237c5282fad95243d7e66b4499b6f3708a4fe14c3", size = 626 } [[package]] name = "tokenizers" @@ -1362,8 +1296,8 @@ wheels = [ [[package]] name = "torch" -version = "2.6.0.dev20240924+cu124" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +version = 
"2.6.0.dev20241126+cu126" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "filelock", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "fsspec", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, @@ -1378,6 +1312,7 @@ dependencies = [ { name = "nvidia-curand-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, { name = "nvidia-cusolver-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, { name = "nvidia-cusparse-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, + { name = "nvidia-cusparselt-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, { name = "nvidia-nccl-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, { name = "nvidia-nvjitlink-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, { name = "nvidia-nvtx-cu12", marker = "(platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'linux') or (platform_machine == 'x86_64' and platform_system == 'Linux' and sys_platform == 'windows')" }, @@ -1387,20 +1322,17 @@ dependencies = [ { name = "typing-extensions", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp310-cp310-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp310-cp310-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp311-cp311-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp311-cp311-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp312-cp312-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp312-cp312-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp313-cp313-linux_x86_64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp39-cp39-linux_aarch64.whl" }, - { url = "https://download.pytorch.org/whl/nightly/cu124/torch-2.6.0.dev20240924%2Bcu124-cp39-cp39-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241126%2Bcu126-cp310-cp310-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241126%2Bcu126-cp311-cp311-linux_x86_64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241126%2Bcu126-cp312-cp312-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241126%2Bcu126-cp313-cp313-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241126%2Bcu126-cp313-cp313t-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torch-2.6.0.dev20241126%2Bcu126-cp39-cp39-linux_x86_64.whl" }, ] [[package]] name = "torch-tensorrt" -version = "2.6.0.dev0+0de0b1651" +version = "2.6.0.dev0+441fce45b" source = { editable = "." } dependencies = [ { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, @@ -1449,11 +1381,11 @@ requires-dist = [ { name = "packaging", specifier = ">=23" }, { name = "rich", marker = "extra == 'monitoring-tools'", specifier = ">=13.7.1" }, { name = "rich", extras = ["jupyter"], marker = "extra == 'jupyter'", specifier = ">=13.7.1" }, - { name = "tensorrt-cu12", specifier = "==10.3.0" }, - { name = "tensorrt-cu12-bindings", specifier = "==10.3.0" }, - { name = "tensorrt-cu12-libs", specifier = "==10.3.0" }, - { name = "torch", specifier = ">=2.6.0.dev0,<2.7.0" }, - { name = "torchvision", marker = "extra == 'torchvision'" }, + { name = "tensorrt-cu12", specifier = "==10.6.0" }, + { name = "tensorrt-cu12-bindings", specifier = "==10.6.0" }, + { name = "tensorrt-cu12-libs", specifier = "==10.6.0" }, + { name = "torch", specifier = ">=2.6.0.dev0,<2.7.0", index = "https://download.pytorch.org/whl/nightly/cu126" }, + { name = "torchvision", marker = "extra == 'torchvision'", index = "https://download.pytorch.org/whl/nightly/cu126" }, { name = "typing-extensions", specifier = ">=4.7.0" }, ] @@ -1489,29 +1421,30 @@ wheels = [ [[package]] name = "torchvision" -version = "0.11.3" -source = { registry = "https://pypi.org/simple" } +version = "0.20.0.dev20241126+cu126" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "numpy", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "pillow", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, { name = "torch", marker = "sys_platform == 'linux' or sys_platform == 'windows'" }, ] wheels = [ - { url = "https://files.pythonhosted.org/packages/48/20/380758a94be49d38798a6cfd25824f72ec1f230b00c0014efb15903777c6/torchvision-0.11.3-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:8bc8a7db80c97ca254be362ba883a202192e361ba2f6dff7ff5bb010d4bfc23a", size = 14675721 }, - { url = "https://files.pythonhosted.org/packages/ac/b1/9702d02e233bec7ce231cc8be94489ee31084fb6d350703f0ed22086ebed/torchvision-0.11.3-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:eca0b0f7a0e462bdecf7926d89faae6dcd51da418ca0cf70e725981ed775a11b", size = 23199346 }, - { url = "https://files.pythonhosted.org/packages/ac/d3/913e25d7775c74f76d174a82eba45bf68e384dc78373598f6c2b3a727fed/torchvision-0.11.3-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:25e72231be8ce03467a77806d9c3f5fd34b9cd23b9543d3e999bf57622377532", size = 14674764 }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torchvision-0.20.0.dev20241126%2Bcu126-cp310-cp310-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torchvision-0.20.0.dev20241126%2Bcu126-cp311-cp311-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torchvision-0.20.0.dev20241126%2Bcu126-cp312-cp312-linux_x86_64.whl" }, + { url = 
"https://download.pytorch.org/whl/nightly/cu126/torchvision-0.20.0.dev20241126%2Bcu126-cp313-cp313-linux_x86_64.whl" }, + { url = "https://download.pytorch.org/whl/nightly/cu126/torchvision-0.20.0.dev20241126%2Bcu126-cp39-cp39-linux_x86_64.whl" }, ] [[package]] name = "tqdm" version = "4.66.5" -source = { registry = "https://pypi.org/simple" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } dependencies = [ { name = "colorama", marker = "(platform_system == 'Windows' and sys_platform == 'linux') or (platform_system == 'Windows' and sys_platform == 'windows')" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ - { url = "https://files.pythonhosted.org/packages/48/5d/acf5905c36149bbaec41ccf7f2b68814647347b72075ac0b1fe3022fdc73/tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd", size = 78351 }, + { url = "https://download.pytorch.org/whl/nightly/tqdm-4.66.5-py3-none-any.whl" }, ] [[package]] @@ -1547,7 +1480,7 @@ wheels = [ [[package]] name = "typing-extensions" version = "4.12.2" -source = { registry = "https://download.pytorch.org/whl/nightly/cu124" } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ { url = "https://download.pytorch.org/whl/nightly/typing_extensions-4.12.2-py3-none-any.whl" }, ] @@ -1568,10 +1501,9 @@ wheels = [ [[package]] name = "urllib3" version = "2.2.3" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/ed/63/22ba4ebfe7430b76388e7cd448d5478814d3032121827c12a2cc287e2260/urllib3-2.2.3.tar.gz", hash = "sha256:e7d814a81dad81e6caf2ec9fdedb284ecc9c73076b62654547cc64ccdcae26e9", size = 300677 } +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ce/d9/5f4c13cecde62396b0d3fe530a50ccea91e7dfc1ccf0e09c228841bb5ba8/urllib3-2.2.3-py3-none-any.whl", hash = "sha256:ca899ca043dcb1bafa3e262d73aa25c465bfb49e0bd9dd5d59f1d0acba2f8fac", size = 126338 }, + { url = "https://download.pytorch.org/whl/nightly/urllib3-2.2.3-py3-none-any.whl" }, ] [[package]] @@ -1608,9 +1540,8 @@ wheels = [ [[package]] name = "zipp" -version = "3.20.2" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/54/bf/5c0000c44ebc80123ecbdddba1f5dcd94a5ada602a9c225d84b5aaa55e86/zipp-3.20.2.tar.gz", hash = "sha256:bc9eb26f4506fda01b81bcde0ca78103b6e62f991b381fec825435c836edbc29", size = 24199 } +version = "3.19.2" +source = { registry = "https://download.pytorch.org/whl/nightly/cu126" } wheels = [ - { url = "https://files.pythonhosted.org/packages/62/8b/5ba542fa83c90e09eac972fc9baca7a88e7e7ca4b221a89251954019308b/zipp-3.20.2-py3-none-any.whl", hash = "sha256:a817ac80d6cf4b23bf7f2828b7cabf326f15a001bea8b1f9b49631780ba28350", size = 9200 }, + { url = "https://download.pytorch.org/whl/nightly/zipp-3.19.2-py3-none-any.whl" }, ] From 9bcab518f4e8cf50c68cda83f9e4fcd0782eae80 Mon Sep 17 00:00:00 2001 From: Bo Wang Date: Fri, 17 Jan 2025 18:19:04 +0000 Subject: [PATCH 8/8] update --- .../dynamo/automatic_plugin_generation.py | 274 ++++++++++++++++-- .../conversion/plugins/plugin_generator.py | 63 ++++ 2 files changed, 313 insertions(+), 24 deletions(-) diff --git 
a/examples/dynamo/automatic_plugin_generation.py b/examples/dynamo/automatic_plugin_generation.py
index 8bbb75ebfa..c0bd2aa764 100644
--- a/examples/dynamo/automatic_plugin_generation.py
+++ b/examples/dynamo/automatic_plugin_generation.py
@@ -83,30 +83,256 @@ def _(x: torch.Tensor, y: torch.Tensor, b: float=.2, a: int=2) -> torch.Tensor:
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
 from sympy import lambdify
 
-@trtp.register("torchtrt_ex::elementwise_mul")
-def _(x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int) -> Tuple[trtp.TensorDesc]:
-    from torch._subclasses.fake_tensor import FakeTensorMode
-    from torch.fx.experimental.symbolic_shapes import ShapeEnv
-    from sympy import lambdify
-    shape_env = ShapeEnv()
-    fake_mode = FakeTensorMode(shape_env=shape_env)
-    sample_x = {f"x{i}": 5 for i in range(x.ndim)}
-    sample_y = {f"y{i}": 5 for i in range(y.ndim)}
-    syms_x = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k,v in sample_x.items()]
-    syms_y = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k,v in sample_y.items()]
-    with FakeTensorMode() as fake_mode:
-        fake_x = torch.randn(syms_x)
-        fake_y = torch.randn(syms_y)
-        z = torch.ops.torchtrt_ex.elementwise_mul(fake_x, fake_y, b, a)
-
-    shape_calc_fns = [None] * x.ndim
-    for i in range(x.ndim):
-        shape_calc_fns[i] = lambdify((syms_x[i].node.expr, syms_y[i].node.expr), z.shape[i].node.expr, "math")
-
-    out_desc = x.like()
-    for i in range(out_desc.ndim):
-        out_desc.shape_expr[i] = shape_calc_fns[i](x.shape_expr[i], y.shape_expr[i])
-    return out_desc
+
+def generate_plugin(plugin_name: str):
+    """Generate and register a TensorRT plugin shape descriptor for a torch custom op."""
+    namespace, name = plugin_name.split("::")
+
+    # Retrieve the op packet, e.g. torch.ops.torchtrt_ex.elementwise_mul
+    torch_op = getattr(getattr(torch.ops, namespace), name)
+
+    def _generic_plugin_desc_creator(torch_op):
+        # Earlier draft, kept for reference: builds the descriptor function as source
+        # text line by line from the op schema, instead of the thin wrapper emitted
+        # by generate_signature() below.
+        schema = torch_op._schemas[""]
+
+        tensor_args = []
+        arg_list = []
+        func_body = []
+
+        func_body.append("shape_env = ShapeEnv()")
+        func_body.append("fake_mode = FakeTensorMode(shape_env=shape_env)")
+
+        for arg in schema.arguments:
+            arg_type = "trtp.TensorDesc" if arg.type.isSubtypeOf(torch._C.TensorType.get()) else arg.type
+            arg_list.append(f"{arg.name} : {arg_type}")
+
+            if arg.type.isSubtypeOf(torch._C.TensorType.get()):
+                tensor_args.append(arg)
+
+        for arg in tensor_args:
+            func_body.append(f"sample_{arg.name} = {{f'{arg.name}{{i}}': 5 for i in range({arg.name}.ndim)}}")
+
+        for arg in tensor_args:
+            func_body.append(f"syms_{arg.name} = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k, v in sample_{arg.name}.items()]")
+
+        func_body.append("with fake_mode:")
+
+        for arg in tensor_args:
+            func_body.append(f"    fake_{arg.name} = torch.randn(syms_{arg.name})")
+
+        running_args = []
+        for arg in schema.arguments:
+            if arg in tensor_args:
+                running_args.append(f"fake_{arg.name}")
+            else:
+                running_args.append(f"{arg.name}")
+        running_line_args = ", ".join(running_args)
+        func_body.append(f"    output = torch.ops.{namespace}.{name}({running_line_args})")
+
+        # Join the argument list to create the signature
+        input_signature = ", ".join(arg_list)
+
+        ret_list = []
+        for ret in schema.returns:
+            if ret.type.isSubtypeOf(torch._C.TensorType.get()):
+                ret_list.append("trtp.TensorDesc")
+            else:
+                raise Exception("Return type has to be Tensor for a TRT plugin")
+
+        ret_signature = "trtp.TensorDesc" if len(ret_list) == 1 else f"Tuple[{', '.join(ret_list)}]"
+
+        plugin_signature = f"def add_plugin_desc({input_signature}) -> {ret_signature}:"
+        body_str = "\n    ".join(func_body)
+        print(f"{plugin_signature}\n    {body_str}")
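+
+    # For the elementwise_mul schema above, (Tensor x, Tensor y, float b=.2, int a=2)
+    # -> Tensor, generate_signature() below should produce roughly:
+    #
+    #     args_input       == "x, y"
+    #     kwargs_input     == "b=b, a=a"
+    #     plugin_signature == "def add_plugin_desc(x, y, b, a):"
+    #
+    # (a sketch for orientation; the exact strings depend on the registered schema)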
+
+    def generate_signature(torch_op):
+        # Derive from the torch op schema the pieces needed to codegen a thin
+        # descriptor wrapper: the positional argument list, the keyword-forwarding
+        # list, and the wrapper's signature line.
+        schema = torch_op._schemas[""]
+
+        arg_list = []
+        args = []
+        kwargs = []
+
+        for arg in schema.arguments:
+            arg_list.append(arg.name)
+
+            if arg.default_value is None:
+                args.append(arg.name)
+            else:
+                # Forward the wrapper's own argument rather than baking in the schema
+                # default, so that non-default attribute values reach the descriptor.
+                kwargs.append(f"{arg.name}={arg.name}")
+
+        input_signature = ", ".join(arg_list)
+
+        for ret in schema.returns:
+            if not ret.type.isSubtypeOf(torch._C.TensorType.get()):
+                raise Exception("Return type has to be Tensor for a TRT plugin")
+
+        plugin_signature = f"def add_plugin_desc({input_signature}):"
+        args_input = ", ".join(args)
+        kwargs_input = ", ".join(kwargs)
+        return args_input, kwargs_input, plugin_signature
+
+    args_input, kwargs_input, plugin_signature = generate_signature(torch_op)
+
+    def _generic_plugin_desc(*args, **kwargs) -> trtp.TensorDesc:
+        # Trace the torch op under fake tensors whose dims are fresh dynamic symbols,
+        # then translate the symbolic output shape into TensorRT shape expressions.
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+        syms_args = []
+        for j, arg in enumerate(args):
+            # Keys must be valid identifiers for LocalSource; 5 is an arbitrary hint.
+            sample = {f"arg{j}_{i}": 5 for i in range(arg.ndim)}
+            syms_arg = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k, v in sample.items()]
+            syms_args.append(syms_arg)
+
+        # Reuse the mode that owns shape_env so the symbolic dims propagate.
+        with fake_mode:
+            fake_args = [torch.randn(syms_arg) for syms_arg in syms_args]
+            output = torch_op(*fake_args, **kwargs)
+
+        # We assume the output rank matches the rank of the first tensor input.
+        shape_calc_fns = [None] * args[0].ndim
+        for i in range(args[0].ndim):
+            input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
+            shape_calc_fns[i] = lambdify(input_node_expr, output.shape[i].node.expr, "math")
+
+        out_desc = args[0].like()
+        for i in range(out_desc.ndim):
+            input_shape_expr = [arg.shape_expr[i] for arg in args]
+            out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr)
+        return out_desc
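+
+    # `mksym`, used above, is assumed to be defined near the imports at the top of
+    # this example as a small helper over ShapeEnv; a minimal sketch of such a helper:
+    #
+    #     def mksym(shape_env, value, source, dynamic_dim):
+    #         return shape_env.create_symintnode(
+    #             shape_env.create_symbol(value, source, dynamic_dim=dynamic_dim),
+    #             hint=value,
+    #             source=source,
+    #         )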
+
+    # Codegen a thin wrapper whose signature mirrors the torch op schema and which
+    # delegates to _generic_plugin_desc for the actual shape inference.
+    codegen_plugin = f"""
+{plugin_signature}
+    return _generic_plugin_desc({args_input}, {kwargs_input})
+    """
+
+    # Print the generated source for inspection.
+    print(codegen_plugin)
+
+    plugin_code = compile(codegen_plugin, "<string>", "exec")
+
+    from types import FunctionType
+
+    # _generic_plugin_desc is a local of generate_plugin, so it has to be injected
+    # into the globals of the generated function explicitly.
+    plugin = FunctionType(
+        plugin_code.co_consts[0],
+        {**globals(), "_generic_plugin_desc": _generic_plugin_desc},
+        "plugin",
+    )
+
+    # Annotations are still hardcoded for elementwise_mul; deriving them from the
+    # schema is a TODO. Keys must match the generated parameter names.
+    plugin.__annotations__ = {"x": trtp.TensorDesc, "y": trtp.TensorDesc, "b": float, "a": int, "return": trtp.TensorDesc}
+
+    trtp.register(plugin_name)(plugin)
+
+    return plugin
+
+
+generate_plugin("torchtrt_ex::elementwise_mul")
 
 
 @trtp.impl("torchtrt_ex::elementwise_mul")
diff --git a/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py b/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py
index bc50a92998..a1a1dd2e66 100644
--- a/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py
+++ b/py/torch_tensorrt/dynamo/conversion/plugins/plugin_generator.py
@@ -99,6 +99,69 @@ def tensorrt_plugin_impl(
     return inner(fn)
 
+def _generate_plugin(
+    namespace: str,
+    op_name: str,
+):
+    # Shape descriptor: infer the output TensorDesc by tracing the torch op with
+    # fake tensors over symbolic dims (mksym is assumed to be available in this
+    # module as a ShapeEnv helper; see the example file).
+    @trtp.register(f"{namespace}::{op_name}")
+    def add_plugin_desc(x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int) -> Tuple[trtp.TensorDesc]:
+        from torch._dynamo.source import LocalSource
+        from torch._subclasses.fake_tensor import FakeTensorMode
+        from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
+        from sympy import lambdify
+        shape_env = ShapeEnv()
+        fake_mode = FakeTensorMode(shape_env=shape_env)
+        sample_x = {f"x{i}": 5 for i in range(x.ndim)}
+        sample_y = {f"y{i}": 5 for i in range(y.ndim)}
+        syms_x = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k, v in sample_x.items()]
+        syms_y = [mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) for k, v in sample_y.items()]
+        with fake_mode:
+            fake_x = torch.randn(syms_x)
+            fake_y = torch.randn(syms_y)
+            z = torch.ops.torchtrt_ex.elementwise_mul(fake_x, fake_y, b, a)
+
+        shape_calc_fns = [None] * x.ndim
+        for i in range(x.ndim):
+            shape_calc_fns[i] = lambdify((syms_x[i].node.expr, syms_y[i].node.expr), z.shape[i].node.expr, "math")
+
+        out_desc = x.like()
+        for i in range(out_desc.ndim):
+            out_desc.shape_expr[i] = shape_calc_fns[i](x.shape_expr[i], y.shape_expr[i])
+        return out_desc
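+
+    # Note on the autotune combination below (format as in the TensorRT python
+    # plugin samples, an assumption worth verifying): the first argument lists one
+    # dtype set per I/O tensor in order (inputs x and y, then the single output),
+    # the second the layout, and the optional third the tactic ids to sweep.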
+
+    # Type annotations can be omitted for autotune and impl definitions, but will be
+    # checked for consistency if added. The signature must mirror the registered
+    # descriptor above: (x, y, b, a) plus the outputs tuple.
+    @trtp.autotune(f"{namespace}::{op_name}")
+    def add_plugin_autotune(
+        x: trtp.TensorDesc, y: trtp.TensorDesc, b: float, a: int, outputs: Tuple[trtp.TensorDesc]
+    ) -> List[trtp.AutoTuneCombination]:
+        return [trtp.AutoTuneCombination("FP32|FP16, FP32|FP16, FP32|FP16", "LINEAR", [1, 2])]
+
+
+    @trtp.impl(f"{namespace}::{op_name}")
+    def add_plugin_impl(x: trtp.Tensor, y: trtp.Tensor, b: float, a: int, outputs: Tuple[trtp.Tensor], stream: int):
+        # TODO: derive the input list and the device from the torch op schema and the
+        # incoming tensors instead of hardcoding (x, y) and "cuda".
+        in_tensors = [torch.as_tensor(i, device="cuda") for i in (x, y)]
+        dest_tensors = [torch.as_tensor(o, device="cuda") for o in outputs]
+
+        stream = torch.cuda.ExternalStream(stream)
+        with torch.cuda.stream(stream):
+            out_tensors = torch.ops.torchtrt_ex.elementwise_mul(*in_tensors, b, a)
+            # The op returns a single tensor; normalize to a tuple before copying
+            # into the TRT-provided output buffers.
+            if isinstance(out_tensors, torch.Tensor):
+                out_tensors = (out_tensors,)
+            for d, o in zip(dest_tensors, out_tensors):
+                d.copy_(o)
+
+
+def generate_plugin(
+    plugin_id: str,
+    # capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
+):
+    plugin_ns, plugin_name = plugin_id.split("::")
+    _generate_plugin(plugin_ns, plugin_name)
+
 def _generate_plugin_converter(
     namespace: str,
     op_name: str,