Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 40a6bfd

Browse files
oveddandsmilkov
authored andcommitted
2d atrous convolution and atrous depthwise convolution (#794)
* Implemented atrous 1d and 2d convolution. * Updated backend gpu and cpu to do atrous convolution when there is a dilation rate. * Refactored backend_cpu.conv2d to be consistent with backend_gpu.conv2d, and also to make the way it does dilated convolution consistent. * Added tests for 1d and 2d convolution with dilation rates that show the effect on the filter when dilation rates are set. * Updated computeDefaultPad to account for a dilation rate. Implemented atrous convolution for depthwiseConv2D * Modified cpu depthwiseConv2D logic to be similar to that on the GPU, so that dilation can be easily applied. Still need to add more tests for depthwise atrous convolution Per feedback, changed order of parameters to match Tensorflow api. Added dataFormat parameter to conv1d, conv2d, and depthwiseConv2d, but did not yet implement that parameter; it defaults to a value and cannot be a different value until the functionality is implemented raising error when gradient is done with atrous convolution * Made ordering of parameters consistent in convolution methods with new dilation parameters * added depthwise conv tests
1 parent fbed961 commit 40a6bfd

File tree

9 files changed

+440
-53
lines changed

9 files changed

+440
-53
lines changed

src/kernels/backend_cpu.ts

Lines changed: 34 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -863,6 +863,8 @@ export class MathBackendCPU implements KernelBackend {
863863
conv2d(x: Tensor4D, filter: Tensor4D, convInfo: Conv2DInfo): Tensor4D {
864864
const filterHeight = convInfo.filterHeight;
865865
const filterWidth = convInfo.filterWidth;
866+
const dilationHeight = convInfo.dilationHeight;
867+
const dilationWidth = convInfo.dilationWidth;
866868
const padLeft = convInfo.padInfo.left;
867869
const padTop = convInfo.padInfo.top;
868870
const y = ops.buffer<Rank.R4>(convInfo.outShape, x.dtype);
@@ -871,17 +873,24 @@ export class MathBackendCPU implements KernelBackend {
871873
for (let d2 = 0; d2 < convInfo.outChannels; ++d2) {
872874
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
873875
const xRCorner = yR * convInfo.strideHeight - padLeft;
874-
const xRMin = Math.max(0, xRCorner);
875-
const xRMax = Math.min(convInfo.inHeight, filterHeight + xRCorner);
876876
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
877877
const xCCorner = yC * convInfo.strideWidth - padTop;
878-
const xCMin = Math.max(0, xCCorner);
879-
const xCMax = Math.min(convInfo.inWidth, filterWidth + xCCorner);
878+
880879
let dotProd = 0;
881-
for (let xR = xRMin; xR < xRMax; ++xR) {
882-
const wR = xR - xRCorner;
883-
for (let xC = xCMin; xC < xCMax; ++xC) {
884-
const wC = xC - xCCorner;
880+
for (let wR = 0; wR < filterHeight; wR++) {
881+
const xR = xRCorner + wR * dilationHeight;
882+
883+
if (xR < 0 || xR >= convInfo.inHeight) {
884+
continue;
885+
}
886+
887+
for (let wC = 0; wC < filterWidth; wC++) {
888+
const xC = xCCorner + wC * dilationWidth;
889+
890+
if (xC < 0 || xC >= convInfo.inWidth) {
891+
continue;
892+
}
893+
885894
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
886895
const pixel = x.get(b, xR, xC, d1);
887896
const weight = filter.get(wR, wC, d1, d2);
@@ -989,6 +998,8 @@ export class MathBackendCPU implements KernelBackend {
989998
Tensor4D {
990999
const filterHeight = convInfo.filterHeight;
9911000
const filterWidth = convInfo.filterWidth;
1001+
const dilationHeight = convInfo.dilationHeight;
1002+
const dilationWidth = convInfo.dilationWidth;
9921003
const padLeft = convInfo.padInfo.left;
9931004
const padTop = convInfo.padInfo.top;
9941005
const chMul = convInfo.outChannels / convInfo.inChannels;
@@ -998,18 +1009,24 @@ export class MathBackendCPU implements KernelBackend {
9981009
for (let d1 = 0; d1 < convInfo.inChannels; ++d1) {
9991010
for (let yR = 0; yR < convInfo.outHeight; ++yR) {
10001011
const xRCorner = yR * convInfo.strideHeight - padLeft;
1001-
const xRMin = Math.max(0, xRCorner);
1002-
const xRMax = Math.min(convInfo.inHeight, filterHeight + xRCorner);
10031012
for (let yC = 0; yC < convInfo.outWidth; ++yC) {
10041013
const xCCorner = yC * convInfo.strideWidth - padTop;
1005-
const xCMin = Math.max(0, xCCorner);
1006-
const xCMax = Math.min(convInfo.inWidth, filterWidth + xCCorner);
10071014
for (let q = 0; q < chMul; ++q) {
10081015
let dotProd = 0;
1009-
for (let xR = xRMin; xR < xRMax; ++xR) {
1010-
const wR = xR - xRCorner;
1011-
for (let xC = xCMin; xC < xCMax; ++xC) {
1012-
const wC = xC - xCCorner;
1016+
for (let wR = 0; wR < filterHeight; ++wR) {
1017+
const xR = xRCorner + wR * dilationHeight;
1018+
1019+
if (xR < 0 || xR >= convInfo.inHeight) {
1020+
continue;
1021+
}
1022+
1023+
for (let wC = 0; wC < filterWidth; ++wC) {
1024+
const xC = xCCorner + wC * dilationWidth;
1025+
1026+
if (xC < 0 || xC >= convInfo.inWidth) {
1027+
continue;
1028+
}
1029+
10131030
const pixel = x.get(b, xR, xC, d1);
10141031
const weight = filter.get(wR, wC, d1, q);
10151032
dotProd += pixel * weight;
@@ -1021,6 +1038,7 @@ export class MathBackendCPU implements KernelBackend {
10211038
}
10221039
}
10231040
}
1041+
10241042
return y.toTensor();
10251043
}
10261044

src/kernels/webgl/conv_gpu.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ export class Conv2DProgram implements GPGPUProgram {
2929
const padLeft = convInfo.padInfo.left;
3030
const strideHeight = convInfo.strideHeight;
3131
const strideWidth = convInfo.strideWidth;
32+
const dilationHeight = convInfo.dilationHeight;
33+
const dilationWidth = convInfo.dilationWidth;
3234
const filterHeight = convInfo.filterHeight;
3335
const filterWidth = convInfo.filterWidth;
3436

@@ -52,14 +54,14 @@ export class Conv2DProgram implements GPGPUProgram {
5254
// ? = to be determined. : = across all values in that axis.
5355
float dotProd = 0.0;
5456
for (int wR = 0; wR < ${filterHeight}; wR++) {
55-
int xR = xRCorner + wR;
57+
int xR = xRCorner + wR * ${dilationHeight};
5658
5759
if (xR < 0 || xR >= ${convInfo.inHeight}) {
5860
continue;
5961
}
6062
6163
for (int wC = 0; wC < ${filterWidth}; wC++) {
62-
int xC = xCCorner + wC;
64+
int xC = xCCorner + wC * ${dilationWidth};
6365
6466
if (xC < 0 || xC >= ${convInfo.inWidth}) {
6567
continue;

src/kernels/webgl/conv_gpu_depthwise.ts

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
3232
const padLeft = convInfo.padInfo.left;
3333
const strideHeight = convInfo.strideHeight;
3434
const strideWidth = convInfo.strideWidth;
35+
const dilationHeight = convInfo.dilationHeight;
36+
const dilationWidth = convInfo.dilationWidth;
3537
const filterHeight = convInfo.filterHeight;
3638
const filterWidth = convInfo.filterWidth;
3739
const channelMul = convInfo.outChannels / convInfo.inChannels;
@@ -56,14 +58,14 @@ export class DepthwiseConv2DProgram implements GPGPUProgram {
5658
float dotProd = 0.0;
5759
// TODO(dsmilkov): Flatten the two for loops and vec4 the operations.
5860
for (int wR = 0; wR < ${filterHeight}; wR++) {
59-
int xR = xRCorner + wR;
61+
int xR = xRCorner + wR * ${dilationHeight};
6062
6163
if (xR < 0 || xR >= ${xNumRows}) {
6264
continue;
6365
}
6466
6567
for (int wC = 0; wC < ${filterWidth}; wC++) {
66-
int xC = xCCorner + wC;
68+
int xC = xCCorner + wC * ${dilationWidth};
6769
6870
if (xC < 0 || xC >= ${xNumCols}) {
6971
continue;

src/ops/conv.ts

Lines changed: 68 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ export class ConvOps {
4040
* - For more info, see this guide:
4141
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
4242
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
43+
* @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
44+
* the data is stored in the order of [batch, in_width, in_channels]. Only
45+
* "NWC" is currently supported.
46+
* @param dilation The dilation rate in which we sample input values in
47+
* atrous convolution. Defaults to `1`. If it is greater than 1, then
48+
* stride must be `1`.
4349
* @param dimRoundingMode The rounding mode used when computing output
4450
* dimensions if pad is a number. If none is provided, it will not round
4551
* and error if the output is of fractional size.
@@ -48,6 +54,7 @@ export class ConvOps {
4854
@operation
4955
static conv1d<T extends Tensor2D|Tensor3D>(
5056
input: T, filter: Tensor3D, stride: number, pad: 'valid'|'same'|number,
57+
dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
5158
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
5259
let input3D = input as Tensor3D;
5360
let reshapedTo3D = false;
@@ -74,15 +81,27 @@ export class ConvOps {
7481
input3D.shape[2] === filter.shape[1],
7582
`Error in conv1d: depth of input (${input3D.shape[2]}) must match ` +
7683
`input depth for filter ${filter.shape[1]}.`);
84+
util.assert(
85+
eitherStridesOrDilationsAreOne(stride, dilation),
86+
'Error in conv1D: Either stride or dilation must be 1.' +
87+
`Got stride ${stride} and dilation '${dilation}'`);
88+
util.assert(
89+
dataFormat === 'NWC',
90+
`Error in conv1d: got dataFormat of ${
91+
dataFormat} but only NWC is currently supported.`);
7792

7893
const filter4D =
7994
filter.as4D(1, filter.shape[0], filter.shape[1], filter.shape[2]);
8095
const input4D =
8196
input3D.as4D(input3D.shape[0], 1, input3D.shape[1], input3D.shape[2]);
8297
const strides: [number, number] = [1, stride];
98+
const dilations: [number, number] = [1, dilation];
99+
100+
const conv2dDataFormat = 'NHWC';
83101

84-
const res =
85-
ConvOps.conv2d(input4D, filter4D, strides, pad, dimRoundingMode);
102+
const res = ConvOps.conv2d(
103+
input4D, filter4D, strides, pad, conv2dDataFormat, dilations,
104+
dimRoundingMode);
86105

87106
if (reshapedTo3D) {
88107
return res.as2D(res.shape[2], res.shape[3]) as T;
@@ -108,6 +127,15 @@ export class ConvOps {
108127
* - For more info, see this guide:
109128
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
110129
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
130+
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
131+
* "NHWC". Specify the data format of the input and output data. With the
132+
* default format "NHWC", the data is stored in the order of: [batch,
133+
* height, width, channels]. Only "NHWC" is currently supported.
134+
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
135+
* in which we sample input values across the height and width dimensions
136+
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
137+
* number, then `dilationHeight == dilationWidth`. If it is greater than
138+
* 1, then all values of `strides` must be 1.
111139
* @param dimRoundingMode The rounding mode used when computing output
112140
* dimensions if pad is a number. If none is provided, it will not round
113141
* and error if the output is of fractional size.
@@ -116,7 +144,9 @@ export class ConvOps {
116144
@operation
117145
static conv2d<T extends Tensor3D|Tensor4D>(
118146
x: T, filter: Tensor4D, strides: [number, number]|number,
119-
pad: 'valid'|'same'|number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {
147+
pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
148+
dilations: [number, number]|number = [1, 1],
149+
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
120150
let x4D = x as Tensor4D;
121151
let reshapedTo4D = false;
122152

@@ -142,13 +172,24 @@ export class ConvOps {
142172
x4D.shape[3] === filter.shape[2],
143173
`Error in conv2d: depth of input (${x4D.shape[3]}) must match ` +
144174
`input depth for filter ${filter.shape[2]}.`);
145-
146-
const dilations = 1;
175+
util.assert(
176+
eitherStridesOrDilationsAreOne(strides, dilations),
177+
'Error in conv2D: Either strides or dilations must be 1.' +
178+
`Got strides ${strides} and dilations '${dilations}'`);
179+
util.assert(
180+
dataFormat === 'NHWC',
181+
`Error in conv2d: got dataFormat of ${
182+
dataFormat} but only NHWC is currently supported.`);
147183

148184
const convInfo = conv_util.computeConv2DInfo(
149185
x4D.shape, filter.shape, strides, dilations, pad, dimRoundingMode);
150186

151187
const grad = (dy: Tensor4D) => {
188+
util.assert(
189+
tupleValuesAreOne(dilations),
190+
'Error in gradient of conv2D: dilation rates greater than 1 are not' +
191+
`yet supported in gradients. Got dilations '${dilations}'`);
192+
152193
return {
153194
x: () => ConvOps.conv2dDerInput(x4D.shape, dy, filter, strides, pad),
154195
filter: () =>
@@ -375,9 +416,13 @@ export class ConvOps {
375416
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
376417
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
377418
* in which we sample input values across the height and width dimensions
378-
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
419+
* in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
379420
* number, then `dilationHeight == dilationWidth`. If it is greater than
380421
* 1, then all values of `strides` must be 1.
422+
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
423+
* "NHWC". Specify the data format of the input and output data. With the
424+
* default format "NHWC", the data is stored in the order of: [batch,
425+
* height, width, channels]. Only "NHWC" is currently supported.
381426
* @param dimRoundingMode The rounding mode used when computing output
382427
* dimensions if pad is a number. If none is provided, it will not round
383428
* and error if the output is of fractional size.
@@ -386,7 +431,8 @@ export class ConvOps {
386431
@operation
387432
static depthwiseConv2d<T extends Tensor3D|Tensor4D>(
388433
input: T, filter: Tensor4D, strides: [number, number]|number,
389-
pad: 'valid'|'same'|number, dilations: [number, number]|number = [1, 1],
434+
pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
435+
dilations: [number, number]|number = [1, 1],
390436
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
391437
let input4D = input as Tensor4D;
392438
let reshapedTo4D = false;
@@ -410,11 +456,11 @@ export class ConvOps {
410456
if (dilations == null) {
411457
dilations = [1, 1];
412458
}
413-
const [dilationHeight, dilationWidth] = parseTupleParam(dilations);
414459
util.assert(
415-
dilationHeight === 1 && dilationWidth === 1,
416-
'Error in depthwiseConv2D: dilation rates greater than 1 are not yet ' +
417-
`supported. Got dilations '${dilations}'`);
460+
eitherStridesOrDilationsAreOne(strides, dilations),
461+
'Error in depthwiseConv2d: Either strides or dilations must be 1.' +
462+
`Got strides ${strides} and dilations '${dilations}'`);
463+
418464
if (dimRoundingMode != null) {
419465
util.assert(
420466
util.isInt(pad as number),
@@ -438,3 +484,14 @@ export class ConvOps {
438484
function parseTupleParam(param: number|[number, number]): [number, number] {
439485
return typeof param === 'number' ? [param, param] : param;
440486
}
487+
488+
function tupleValuesAreOne(param: number|[number, number]): boolean {
489+
const [dimA, dimB] = parseTupleParam(param);
490+
return dimA === 1 && dimB === 1;
491+
}
492+
493+
function eitherStridesOrDilationsAreOne(
494+
strides: number|[number, number],
495+
dilations: number|[number, number]): boolean {
496+
return tupleValuesAreOne(strides) || tupleValuesAreOne(dilations);
497+
}

0 commit comments

Comments
 (0)