Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 8ace2ee

Browse files
committed
Implemented atrous convolution for depthwiseConv2D
* Modified the CPU depthwiseConv2D logic to mirror the GPU implementation, so that dilation can be applied easily. More tests for depthwise atrous convolution still need to be added.
1 parent ea09a10 commit 8ace2ee

File tree

4 files changed

+59
-20
lines changed

4 files changed

+59
-20
lines changed

src/kernels/backend_cpu.ts

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -869,19 +869,17 @@ export class MathBackendCPU implements KernelBackend {
869869
for (let wR = 0; wR < filterHeight; wR++) {
870870
const xR = xRCorner + wR * dilationHeight;
871871

872-
if (xR < 0 || xR >= convInfo.inHeight)
873-
continue;
872+
if (xR < 0 || xR >= convInfo.inHeight) continue;
874873

875-
for(let wC = 0; wC < filterWidth; wC++) {
874+
for (let wC = 0; wC < filterWidth; wC++) {
876875
const xC = xCCorner + wC * dilationWidth;
877876

878877
if (xC < 0 || xC >= convInfo.inWidth) {
879878
continue;
880879
}
881880

882881
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
883-
const pixel = x.get(
884-
b, xR, xC, d1);
882+
const pixel = x.get(b, xR, xC, d1);
885883
const weight = filter.get(wR, wC, d1, d2);
886884
dotProd += pixel * weight;
887885
}
@@ -987,6 +985,8 @@ export class MathBackendCPU implements KernelBackend {
987985
Tensor4D {
988986
const filterHeight = convInfo.filterHeight;
989987
const filterWidth = convInfo.filterWidth;
988+
const dilationHeight = convInfo.dilationHeight;
989+
const dilationWidth = convInfo.dilationWidth;
990990
const padLeft = convInfo.padInfo.left;
991991
const padTop = convInfo.padInfo.top;
992992
const chMul = convInfo.outChannels / convInfo.inChannels;
@@ -996,18 +996,20 @@ export class MathBackendCPU implements KernelBackend {
996996
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
997997
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
998998
const xRCorner = yR * convInfo.strideHeight - padLeft;
999-
const xRMin = Math.max(0, xRCorner);
1000-
const xRMax = Math.min(convInfo.inHeight, filterHeight + xRCorner);
1001999
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
10021000
const xCCorner = yC * convInfo.strideWidth - padTop;
1003-
const xCMin = Math.max(0, xCCorner);
1004-
const xCMax = Math.min(convInfo.inWidth, filterWidth + xCCorner);
10051001
for (let q = 0; q < chMul; ++q) {
10061002
let dotProd = 0;
1007-
for (let xR = xRMin; xR < xRMax; ++xR) {
1008-
const wR = xR - xRCorner;
1009-
for (let xC = xCMin; xC < xCMax; ++xC) {
1010-
const wC = xC - xCCorner;
1003+
for (let wR = 0; wR < filterHeight; ++wR) {
1004+
const xR = xRCorner + wR * dilationHeight;
1005+
1006+
if (xR < 0 || xR >= convInfo.inHeight) continue;
1007+
1008+
for (let wC = 0; wC < filterWidth; ++wC) {
1009+
const xC = xCCorner + wC * dilationWidth;
1010+
1011+
if (xC < 0 || xC >= convInfo.inWidth) continue;
1012+
10111013
const pixel = x.get(b, xR, xC, d1);
10121014
const weight = filter.get(wR, wC, d1, q);
10131015
dotProd += pixel * weight;
@@ -1019,6 +1021,7 @@ export class MathBackendCPU implements KernelBackend {
10191021
}
10201022
}
10211023
}
1024+
10221025
return y.toTensor();
10231026
}
10241027

src/kernels/webgl/conv_gpu_depthwise.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
3232
const padLeft = convInfo.padInfo.left;
3333
const strideHeight = convInfo.strideHeight;
3434
const strideWidth = convInfo.strideWidth;
35+
const dilationHeight = convInfo.dilationHeight;
36+
const dilationWidth = convInfo.dilationWidth;
3537
const filterHeight = convInfo.filterHeight;
3638
const filterWidth = convInfo.filterWidth;
3739
const channelMul = convInfo.outChannels / convInfo.inChannels;
@@ -56,14 +58,14 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
5658
float dotProd = 0.0;
5759
// TODO(dsmilkov): Flatten the two for loops and vec4 the operations.
5860
for (int wR = 0; wR < ${filterHeight}; wR++) {
59-
int xR = xRCorner + wR;
61+
int xR = xRCorner + wR * ${dilationHeight};
6062
6163
if (xR < 0 || xR >= ${xNumRows}) {
6264
continue;
6365
}
6466
6567
for (int wC = 0; wC < ${filterWidth}; wC++) {
66-
int xC = xCCorner + wC;
68+
int xC = xCCorner + wC * ${dilationWidth};
6769
6870
if (xC < 0 || xC >= ${xNumCols}) {
6971
continue;

src/ops/conv.ts

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -423,11 +423,11 @@ export class ConvOps {
423423
if (dilations == null) {
424424
dilations = [1, 1];
425425
}
426-
const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
427426
util.assert(
428-
dilationHeight === 1 && dilationWidth === 1,
429-
'Error in depthwiseConv2D: dilation rates greater than 1 are not yet ' +
430-
`supported. Got dilations '${dilations}'`);
427+
eitherStridesOrDilationsAreOne(strides, dilations),
428+
'Error in depthwiseConv2d: Either strides or dilations must be 1.' +
429+
`Got strides ${strides} and dilations '${dilations}'`);
430+
431431
if (dimRoundingMode != null) {
432432
util.assert(
433433
util.isInt(pad as number),

src/ops/conv2d_depthwise_test.ts

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ import {ALL_ENVS, describeWithFlags, expectArraysClose} from '../test_util';
2020
import {Rank} from '../types';
2121

2222
describeWithFlags('depthwiseConv2D', ALL_ENVS, () => {
23-
it('input=1x3x3x1,f=2,s=1,p=valid,chMul=1', () => {
23+
it('input=1x3x3x1,f=2,s=1,d=1,p=valid,chMul=1', () => {
2424
const fSize = 2;
2525
const pad = 'valid';
2626
const stride = 1;
@@ -44,6 +44,40 @@ describeWithFlags('depthwiseConv2D', ALL_ENVS, () => {
4444
expectArraysClose(result, expected);
4545
});
4646

47+
it('input=1x3x3x1,f=2,s=1,d=2,p=valid,chMul=1', () => {
48+
const fSize = 2;
49+
const pad = 'valid';
50+
const stride = 1;
51+
const dilation = 2;
52+
const chMul = 1;
53+
const inDepth = 1;
54+
55+
const x = dl.tensor4d(
56+
[
57+
0.230664, 0.987388, 0.0685208, 0.419224, 0.887861, 0.731641,
58+
0.0741907, 0.409265, 0.351377
59+
],
60+
[1, 3, 3, inDepth]);
61+
const w = dl.tensor4d(
62+
[0.303873, 0.229223, 0.144333, 0.803373],
63+
[fSize, fSize, inDepth, chMul],
64+
);
65+
// Applying a dilation rate is equivalent to using a larger filter
66+
// with zeros inserted between the original filter taps.
67+
const fSizeDilated = fSize + (fSize - 1) * (dilation - 1);
68+
const wDilated = dl.tensor4d(
69+
[0.303873, 0, 0.229223, 0, 0, 0, 0.144333, 0, 0.803373],
70+
[fSizeDilated, fSizeDilated, inDepth, chMul],
71+
);
72+
73+
const result = dl.depthwiseConv2d(x, w, stride, pad, dilation);
74+
75+
const expectedResult = dl.depthwiseConv2d(x, wDilated, stride, pad);
76+
77+
expect(result.shape).toEqual(expectedResult.shape);
78+
expectArraysClose(result, expectedResult);
79+
});
80+
4781
it('input=1x3x3x2,f=2,s=1,p=same,chMul=1', () => {
4882
const fSize = 2;
4983
const pad = 'same';

0 commit comments

Comments
 (0)