Skip to content

Split dot operation on K. #10325

Open
copybara-service[bot] wants to merge 1 commit into
master from
test_921199785
Open

Split dot operation on K. #10325
copybara-service[bot] wants to merge 1 commit into
master from
test_921199785

Conversation

@copybara-service

Copy link
Copy Markdown
Contributor

Split dot operation on K.

Currently, split_k is set to the full length of the K dimension, so we don't expect major performance differences.

Sample IR for f32[110,10240,2560].

Before:

null = constant_buffer([], 0, {}) {
 v3 = allocate(automatic, 4, {
   {},
   {[0, 15], 4, <>},
   {[0, 10239], 64, <>},
   {[(buffer_min(out2, 0) / 16), (buffer_max(out2, 0) / 16)], <>, <>}
 }) {
  out2.d0 = loop(parallel, [buffer_min(out2, 0), buffer_max(out2, 0)], 32) {
   closure {in0, in1, out2, v3, null, out2.d0} in {
    v3.out2.d0 = crop_dim(v3, 3, [(out2.d0 / 16), ((out2.d0 / 16) + 1)]) {
     call(pack_b, {in1}, {v3.out2.d0}, {})
    }
    out2.out2.d0 = crop_dim(out2, 0, [out2.d0, (out2.d0 + 31)]) {
     call(dot num_k_dims=1, {in0, v3, null}, {out2.out2.d0}, {})
    }
   }
  }
 }
}

After, with k_split=1024:

null = constant_buffer([], 0, {}) {
 reduction = allocate(automatic, 0, {
   {[0, 10239], 0, <>}
 }) {
  k0#0 = loop(serial, [0, 10239], 1024) {
   reduction = crop_dim(reduction, 0, [k0#0, (k0#0 + 1023)]) {
    v3 = allocate(automatic, 4, {
      {},
      {[0, 15], 4, <>},
      {[k0#0, buffer_max(reduction, 0)], 64, <>},
      {[(buffer_min(out2, 0) / 16), (buffer_max(out2, 0) / 16)], <>, <>}
    }) {
     d0#0 = loop(parallel, [buffer_min(out2, 0), buffer_max(out2, 0)], 32) {
      closure {in0, in1, out2, v3, null, reduction, d0#0} in {
       v3.d0#0 = crop_dim(v3, 3, [(d0#0 / 16), ((d0#0 / 16) + 1)]) {
        call(pack_b, {in1}, {v3.d0#0}, {})
       }
       out2.d0#0 = crop_dim(out2, 0, [d0#0, (d0#0 + 31)]) {
        call(dot num_k_dims=1, {in0, v3, null}, {out2.d0#0, reduction}, {})
       }
      }
     }
    }
   }
  }
 }
}

Currently, split_k is set to the full length of the K dimension, so we don't expect major performance differences.

Sample IR for `f32[110,10240,2560]`.

Before:
```
null = constant_buffer([], 0, {}) {
 v3 = allocate(automatic, 4, {
   {},
   {[0, 15], 4, <>},
   {[0, 10239], 64, <>},
   {[(buffer_min(out2, 0) / 16), (buffer_max(out2, 0) / 16)], <>, <>}
 }) {
  out2.d0 = loop(parallel, [buffer_min(out2, 0), buffer_max(out2, 0)], 32) {
   closure {in0, in1, out2, v3, null, out2.d0} in {
    v3.out2.d0 = crop_dim(v3, 3, [(out2.d0 / 16), ((out2.d0 / 16) + 1)]) {
     call(pack_b, {in1}, {v3.out2.d0}, {})
    }
    out2.out2.d0 = crop_dim(out2, 0, [out2.d0, (out2.d0 + 31)]) {
     call(dot num_k_dims=1, {in0, v3, null}, {out2.out2.d0}, {})
    }
   }
  }
 }
}
```

After, with k_split=1024:
```
null = constant_buffer([], 0, {}) {
 reduction = allocate(automatic, 0, {
   {[0, 10239], 0, <>}
 }) {
  k0#0 = loop(serial, [0, 10239], 1024) {
   reduction = crop_dim(reduction, 0, [k0#0, (k0#0 + 1023)]) {
    v3 = allocate(automatic, 4, {
      {},
      {[0, 15], 4, <>},
      {[k0#0, buffer_max(reduction, 0)], 64, <>},
      {[(buffer_min(out2, 0) / 16), (buffer_max(out2, 0) / 16)], <>, <>}
    }) {
     d0#0 = loop(parallel, [buffer_min(out2, 0), buffer_max(out2, 0)], 32) {
      closure {in0, in1, out2, v3, null, reduction, d0#0} in {
       v3.d0#0 = crop_dim(v3, 3, [(d0#0 / 16), ((d0#0 / 16) + 1)]) {
        call(pack_b, {in1}, {v3.d0#0}, {})
       }
       out2.d0#0 = crop_dim(out2, 0, [d0#0, (d0#0 + 31)]) {
        call(dot num_k_dims=1, {in0, v3, null}, {out2.d0#0, reduction}, {})
       }
      }
     }
    }
   }
  }
 }
}
```
PiperOrigin-RevId: 921199785
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant