|
1 | | -// RUN: sdy_opt %s -sdy-sink-data-flow-edges | FileCheck %s |
| 1 | +// RUN: sdy_opt %s -split-input-file -sdy-sink-data-flow-edges | FileCheck %s |
2 | 2 |
|
3 | 3 | sdy.mesh @mesh = <["a"=2, "b"=2, "c"=2]> |
4 | 4 | sdy.mesh @other_mesh = <["c"=4]> |
@@ -249,3 +249,202 @@ func.func @manual_computation_origin_debug_info(%arg0: tensor<32x32x32xf32>) -> |
249 | 249 | %2 = sdy.data_flow_edge %1 sharding=<@mesh, [{"a", ?}, {"b", ?}, {?}]> {sdy.origin_sharding = {a = "mc_0_input: 0", b = "mc_0_output: 0"}} : tensor<32x32x32xf32> |
250 | 250 | return %2 : tensor<32x32x32xf32> |
251 | 251 | } |
| 252 | + |
| 253 | +// ----- |
| 254 | + |
| 255 | +// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32>) |
| 256 | +func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 257 | + // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0 |
| 258 | + // CHECK-NEXT: return %[[NEGATE]] |
| 259 | + %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32> |
| 260 | + %1 = stablehlo.negate %0: tensor<8xf32> |
| 261 | + return %1 : tensor<8xf32> |
| 262 | +} |
| 263 | + |
| 264 | +// CHECK-LABEL: func @simple_call_graph_on_func_with_single_argument(%arg0: tensor<8xf32>) |
| 265 | +func.func @simple_call_graph_on_func_with_single_argument(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 266 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 267 | + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]]) |
| 268 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]] |
| 269 | + // CHECK-NEXT: return %[[ABS1]] |
| 270 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 271 | + %1 = call @bar(%0) : (tensor<8xf32>) -> (tensor<8xf32>) |
| 272 | + %2 = sdy.func_data_flow_edge %1 : tensor<8xf32> |
| 273 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 274 | + return %3 : tensor<8xf32> |
| 275 | +} |
| 276 | + |
| 277 | +// ----- |
| 278 | + |
| 279 | +// CHECK-LABEL: @bar(%arg0: tensor<8xf32>) |
| 280 | +func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 281 | + // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0 |
| 282 | + // CHECK-NEXT: return %[[NEGATE]] |
| 283 | + %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32> |
| 284 | + %1 = stablehlo.negate %0: tensor<8xf32> |
| 285 | + return %1 : tensor<8xf32> |
| 286 | +} |
| 287 | + |
| 288 | +// CHECK-LABEL: @multiple_calls_on_same_func(%arg0: tensor<8xf32>) |
| 289 | +func.func @multiple_calls_on_same_func(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 290 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 291 | + // CHECK-NEXT: %[[CALL0:.*]] = call @bar(%[[ABS0]]) |
| 292 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL0]] |
| 293 | + // CHECK-NEXT: %[[CALL1:.*]] = call @bar(%[[ABS1]]) |
| 294 | + // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL1]] |
| 295 | + // CHECK-NEXT: return %[[ABS2]] |
| 296 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 297 | + %1 = call @bar(%0) : (tensor<8xf32>) -> (tensor<8xf32>) |
| 298 | + %2 = sdy.func_data_flow_edge %1 : tensor<8xf32> |
| 299 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 300 | + %4 = call @bar(%3) : (tensor<8xf32>) -> (tensor<8xf32>) |
| 301 | + %5 = sdy.func_data_flow_edge %4 : tensor<8xf32> |
| 302 | + %6 = stablehlo.abs %5 : tensor<8xf32> |
| 303 | + return %6 : tensor<8xf32> |
| 304 | +} |
| 305 | + |
| 306 | +// ----- |
| 307 | + |
| 308 | +// CHECK-LABEL: @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) |
| 309 | +func.func private @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { |
| 310 | + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 |
| 311 | + // CHECK-NEXT: return %[[ADD]] |
| 312 | + %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32> |
| 313 | + %1 = sdy.func_data_flow_edge %arg1 : tensor<8xf32> |
| 314 | + %2 = stablehlo.add %0, %1: tensor<8xf32> |
| 315 | + return %2 : tensor<8xf32> |
| 316 | +} |
| 317 | + |
| 318 | +// CHECK-LABEL: @simple_call_graph_on_func_with_multiple_argument(%arg0: tensor<8xf32>) |
| 319 | +func.func @simple_call_graph_on_func_with_multiple_argument(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 320 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 321 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %arg0 |
| 322 | + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]], %[[ABS1]]) |
| 323 | + // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL]] |
| 324 | + // CHECK-NEXT: return %[[ABS2]] |
| 325 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 326 | + %1 = stablehlo.abs %arg0 : tensor<8xf32> |
| 327 | + %2 = call @bar(%0, %1) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>) |
| 328 | + %3 = sdy.func_data_flow_edge %2 : tensor<8xf32> |
| 329 | + %4 = stablehlo.abs %3 : tensor<8xf32> |
| 330 | + return %4 : tensor<8xf32> |
| 331 | +} |
| 332 | + |
| 333 | +// ----- |
| 334 | + |
| 335 | +// CHECK-LABEL: @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) |
| 336 | +func.func private @bar(%arg0: tensor<8xf32>, %arg1: tensor<8xf32>) -> tensor<8xf32> { |
| 337 | + // CHECK-NEXT: %[[ADD:.*]] = stablehlo.add %arg0, %arg1 |
| 338 | + // CHECK-NEXT: return %[[ADD]] |
| 339 | + %0 = sdy.func_data_flow_edge %arg0 : tensor<8xf32> |
| 340 | + %1 = sdy.func_data_flow_edge %arg1 : tensor<8xf32> |
| 341 | + %2 = stablehlo.add %0, %1: tensor<8xf32> |
| 342 | + return %2 : tensor<8xf32> |
| 343 | +} |
| 344 | + |
| 345 | +// CHECK-LABEL: @simple_call_graph_on_func_with_multiple_argument_same_operand(%arg0: tensor<8xf32>) |
| 346 | +func.func @simple_call_graph_on_func_with_multiple_argument_same_operand(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 347 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 348 | + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]], %[[ABS0]]) |
| 349 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]] |
| 350 | + // CHECK-NEXT: return %[[ABS1]] |
| 351 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 352 | + %1 = call @bar(%0, %0) : (tensor<8xf32>, tensor<8xf32>) -> (tensor<8xf32>) |
| 353 | + %2 = sdy.func_data_flow_edge %1 : tensor<8xf32> |
| 354 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 355 | + return %3 : tensor<8xf32> |
| 356 | +} |
| 357 | + |
| 358 | +// ----- |
| 359 | + |
| 360 | +sdy.mesh @mesh = <["a"=2]> |
| 361 | + |
| 362 | +// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) |
| 363 | +func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 364 | + // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0 |
| 365 | + // CHECK-NEXT: return %[[NEGATE]] |
| 366 | + %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32> |
| 367 | + %1 = stablehlo.negate %0: tensor<8xf32> |
| 368 | + return %1 : tensor<8xf32> |
| 369 | +} |
| 370 | + |
| 371 | +// CHECK-LABEL: func @simple_call_graph_on_func_with_sharded_argument(%arg0: tensor<8xf32>) |
| 372 | +func.func @simple_call_graph_on_func_with_sharded_argument(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 373 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 374 | + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} |
| 375 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]] |
| 376 | + // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL]] |
| 377 | + // CHECK-NEXT: return %[[ABS1]] |
| 378 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 379 | + %1 = call @bar(%0) : (tensor<8xf32>) -> (tensor<8xf32>) |
| 380 | + %2 = sdy.func_data_flow_edge %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32> |
| 381 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 382 | + %4 = stablehlo.abs %2 : tensor<8xf32> |
| 383 | + return %3 : tensor<8xf32> |
| 384 | +} |
| 385 | + |
| 386 | +// ----- |
| 387 | + |
| 388 | +sdy.mesh @mesh = <["a"=2]> |
| 389 | + |
| 390 | +// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) |
| 391 | +func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 392 | + // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0 |
| 393 | + // CHECK-NEXT: return %[[NEGATE]] |
| 394 | + %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32> |
| 395 | + %1 = stablehlo.negate %0: tensor<8xf32> |
| 396 | + return %1 : tensor<8xf32> |
| 397 | +} |
| 398 | + |
| 399 | +// CHECK-LABEL: func @func_data_flow_edge_has_sharding_call_does_not(%arg0: tensor<8xf32>) |
| 400 | +func.func @func_data_flow_edge_has_sharding_call_does_not(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 401 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 402 | + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%[[ABS0]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}]>]>} |
| 403 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]] |
| 404 | + // CHECK-NEXT: %[[ABS2:.*]] = stablehlo.abs %[[CALL]] |
| 405 | + // CHECK-NEXT: return %[[ABS1]] |
| 406 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 407 | + %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<8xf32>) -> (tensor<8xf32>) |
| 408 | + %2 = sdy.func_data_flow_edge %1 : tensor<8xf32> |
| 409 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 410 | + %4 = stablehlo.abs %2 : tensor<8xf32> |
| 411 | + return %3 : tensor<8xf32> |
| 412 | +} |
| 413 | + |
| 414 | +// ----- |
| 415 | + |
| 416 | +sdy.mesh @mesh = <["a"=2]> |
| 417 | + |
| 418 | +// CHECK-LABEL: func private @bar(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) |
| 419 | +func.func private @bar(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 420 | + // CHECK-NEXT: %[[NEGATE:.*]] = stablehlo.negate %arg0 |
| 421 | + // CHECK-NEXT: return %[[NEGATE]] |
| 422 | + %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32> |
| 423 | + %1 = stablehlo.negate %0: tensor<8xf32> |
| 424 | + return %1 : tensor<8xf32> |
| 425 | +} |
| 426 | + |
| 427 | +// CHECK-LABEL: func private @foo(%arg0: tensor<8xf32> {sdy.sharding = #sdy.sharding<@mesh, [{"a"}]>}) |
| 428 | +func.func private @foo(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 429 | + // CHECK-NEXT: %[[CALL:.*]] = call @bar(%arg0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{?}]>]>} |
| 430 | + // CHECK-NEXT: %[[ABS:.*]] = stablehlo.abs %[[CALL]] |
| 431 | + // CHECK-NEXT: return %[[ABS]] |
| 432 | + %0 = sdy.func_data_flow_edge %arg0 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32> |
| 433 | + %1 = call @bar(%0) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : (tensor<8xf32>) -> (tensor<8xf32>) |
| 434 | + %2 = sdy.func_data_flow_edge %1 : tensor<8xf32> |
| 435 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 436 | + return %3 : tensor<8xf32> |
| 437 | +} |
| 438 | + |
| 439 | +// CHECK-LABEL: func @main_calls_foo_calls_bar(%arg0: tensor<8xf32>) |
| 440 | +func.func @main_calls_foo_calls_bar(%arg0: tensor<8xf32>) -> tensor<8xf32> { |
| 441 | + // CHECK-NEXT: %[[ABS0:.*]] = stablehlo.abs %arg0 |
| 442 | + // CHECK-NEXT: %[[CALL:.*]] = call @foo(%[[ABS0]]) {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} |
| 443 | + // CHECK-NEXT: %[[ABS1:.*]] = stablehlo.abs %[[CALL]] |
| 444 | + // CHECK-NEXT: return %[[ABS1]] |
| 445 | + %0 = stablehlo.abs %arg0 : tensor<8xf32> |
| 446 | + %1 = call @foo(%0) : (tensor<8xf32>) -> (tensor<8xf32>) |
| 447 | + %2 = sdy.func_data_flow_edge %1 {sdy.sharding = #sdy.sharding_per_value<[<@mesh, [{"a"}]>]>} : tensor<8xf32> |
| 448 | + %3 = stablehlo.abs %2 : tensor<8xf32> |
| 449 | + return %3 : tensor<8xf32> |
| 450 | +} |
0 commit comments