Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit 25a532b

Browse files
committed
Per feedback, changed order of parameters to match
Tensorflow api. Added dataFormat parameter to conv1d, conv2d, and depthwiseConv2d, but did not yet implement that parameter; it defaults to a value and cannot be a different value until the functionality is implemented
1 parent 321481d commit 25a532b

File tree

5 files changed

+106
-61
lines changed

5 files changed

+106
-61
lines changed

src/math.ts

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -344,32 +344,35 @@ export class NDArrayMath {
344344
/** @deprecated */
345345
conv1d<T extends Tensor2D|Tensor3D>(
346346
input: T, filter: Tensor3D, bias: Tensor1D|null, stride: number,
347-
dilation: number, pad: 'valid'|'same'|number,
347+
pad: 'valid'|'same'|number, dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
348348
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
349349
if (bias != null) {
350350
util.assert(
351351
bias.rank === 1,
352352
`Error in conv1d: bias must be rank 1, but got rank ` +
353353
`${bias.rank}.`);
354354
}
355-
const res =
356-
ops.conv1d(input, filter, stride, dilation, pad, dimRoundingMode);
355+
const res = ops.conv1d(
356+
input, filter, stride, pad, dataFormat, dilation, dimRoundingMode);
357357
return res.add(bias) as T;
358358
}
359359

360360
/** @deprecated */
361361
conv2d<T extends Tensor3D|Tensor4D>(
362362
x: T, filter: Tensor4D, bias: Tensor1D|null,
363-
strides: [number, number]|number, dilations: [number, number]|number,
364-
pad: 'valid'|'same'|number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {
363+
strides: [number, number]|number, pad: 'valid'|'same'|number,
364+
dataFormat: 'NHWC'|'NCHW' = 'NHWC',
365+
dilations: [number, number]|number = [1, 1],
366+
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
365367
if (bias != null) {
366368
util.assert(
367369
bias.rank === 1,
368370
`Error in conv2d: bias must be rank 1, but got rank ` +
369371
`${bias.rank}.`);
370372
}
371373

372-
const res = ops.conv2d(x, filter, strides, dilations, pad, dimRoundingMode);
374+
const res = ops.conv2d(
375+
x, filter, strides, pad, dataFormat, dilations, dimRoundingMode);
373376
return res.add(bias) as T;
374377
}
375378

src/ops/conv.ts

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,22 @@ export class ConvOps {
4040
* - For more info, see this guide:
4141
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
4242
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
43+
* @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
44+
* the data is stored in the order of [batch, in_width, in_channels]. Only
45+
* "NWC" is currently supported.
46+
* @param dilation The dilation rate in which we sample input values in
47+
* atrous convolution. Defaults to `1`. If it is greater than 1, then
48+
* stride must be `1`.
4349
* @param dimRoundingMode The rounding mode used when computing output
4450
* dimensions if pad is a number. If none is provided, it will not round
4551
* and error if the output is of fractional size.
4652
*/
4753
@doc({heading: 'Operations', subheading: 'Convolution'})
4854
@operation
4955
static conv1d<T extends Tensor2D|Tensor3D>(
50-
input: T, filter: Tensor3D, stride: number, dilation: number,
51-
pad: 'valid'|'same'|number, dimRoundingMode?: 'floor'|'round'|'ceil'): T {
56+
input: T, filter: Tensor3D, stride: number, pad: 'valid'|'same'|number,
57+
dataFormat: 'NWC'|'NCW' = 'NWC', dilation = 1,
58+
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
5259
let input3D = input as Tensor3D;
5360
let reshapedTo3D = false;
5461
if (input.rank === 2) {
@@ -78,6 +85,10 @@ export class ConvOps {
7885
eitherStridesOrDilationsAreOne(stride, dilation),
7986
'Error in conv1D: Either stride or dilation must be 1.' +
8087
`Got stride ${stride} and dilation '${dilation}'`);
88+
util.assert(
89+
dataFormat === 'NWC',
90+
`Error in conv1d: got dataFormat of ${
91+
dataFormat} but only NWC is currently supported.`);
8192

8293
const filter4D =
8394
filter.as4D(1, filter.shape[0], filter.shape[1], filter.shape[2]);
@@ -86,8 +97,11 @@ export class ConvOps {
8697
const strides: [number, number] = [1, stride];
8798
const dilations: [number, number] = [1, dilation];
8899

100+
const conv2dDataFormat = 'NHWC';
101+
89102
const res = ConvOps.conv2d(
90-
input4D, filter4D, strides, dilations, pad, dimRoundingMode);
103+
input4D, filter4D, strides, pad, conv2dDataFormat, dilations,
104+
dimRoundingMode);
91105

92106
if (reshapedTo3D) {
93107
return res.as2D(res.shape[2], res.shape[3]) as T;
@@ -105,11 +119,6 @@ export class ConvOps {
105119
* `[filterHeight, filterWidth, inDepth, outDepth]`.
106120
* @param strides The strides of the convolution: `[strideHeight,
107121
* strideWidth]`.
108-
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
109-
* in which we sample input values across the height and width dimensions
110-
* in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
111-
* number, then `dilationHeight == dilationWidth`. If it is greater than
112-
* 1, then all values of `strides` must be 1.
113122
* @param pad The type of padding algorithm.
114123
* - `same` and stride 1: output will be of same size as input,
115124
* regardless of filter size.
@@ -118,6 +127,15 @@ export class ConvOps {
118127
* - For more info, see this guide:
119128
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
120129
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
130+
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
131+
* "NHWC". Specify the data format of the input and output data. With the
132+
* default format "NHWC", the data is stored in the order of: [batch,
133+
* height, width, channels]. Only "NHWC" is currently supported.
134+
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
135+
* in which we sample input values across the height and width dimensions
136+
* in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
137+
* number, then `dilationHeight == dilationWidth`. If it is greater than
138+
* 1, then all values of `strides` must be 1.
121139
* @param dimRoundingMode The rounding mode used when computing output
122140
* dimensions if pad is a number. If none is provided, it will not round
123141
* and error if the output is of fractional size.
@@ -126,7 +144,8 @@ export class ConvOps {
126144
@operation
127145
static conv2d<T extends Tensor3D|Tensor4D>(
128146
x: T, filter: Tensor4D, strides: [number, number]|number,
129-
dilations: [number, number]|number = [1, 1], pad: 'valid'|'same'|number,
147+
pad: 'valid'|'same'|number, dataFormat: 'NHWC'|'NCHW' = 'NHWC',
148+
dilations: [number, number]|number = [1, 1],
130149
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
131150
let x4D = x as Tensor4D;
132151
let reshapedTo4D = false;
@@ -157,6 +176,10 @@ export class ConvOps {
157176
eitherStridesOrDilationsAreOne(strides, dilations),
158177
'Error in conv2D: Either strides or dilations must be 1.' +
159178
`Got strides ${strides} and dilations '${dilations}'`);
179+
util.assert(
180+
dataFormat === 'NHWC',
181+
`Error in conv2d: got dataFormat of ${
182+
dataFormat} but only NHWC is currently supported.`);
160183

161184
const convInfo = conv_util.computeConv2DInfo(
162185
x4D.shape, filter.shape, strides, dilations, pad, dimRoundingMode);
@@ -391,6 +414,10 @@ export class ConvOps {
391414
* in atrous convolution. Defaults to `[1, 1]`. If `rate` is a single
392415
* number, then `dilationHeight == dilationWidth`. If it is greater than
393416
* 1, then all values of `strides` must be 1.
417+
* @param dataFormat: An optional string from: "NHWC", "NCHW". Defaults to
418+
* "NHWC". Specify the data format of the input and output data. With the
419+
* default format "NHWC", the data is stored in the order of: [batch,
420+
* height, width, channels]. Only "NHWC" is currently supported.
394421
* @param dimRoundingMode The rounding mode used when computing output
395422
* dimensions if pad is a number. If none is provided, it will not round
396423
* and error if the output is of fractional size.
@@ -400,6 +427,7 @@ export class ConvOps {
400427
static depthwiseConv2d<T extends Tensor3D|Tensor4D>(
401428
input: T, filter: Tensor4D, strides: [number, number]|number,
402429
pad: 'valid'|'same'|number, dilations: [number, number]|number = [1, 1],
430+
dataFormat: 'NHWC'|'NCHW' = 'NHWC',
403431
dimRoundingMode?: 'floor'|'round'|'ceil'): T {
404432
let input4D = input as Tensor4D;
405433
let reshapedTo4D = false;

src/ops/conv1d_test.ts

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,13 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
2828
const fSize = 1;
2929
const pad = 'same';
3030
const stride = 1;
31+
const dataFormat = 'NWC';
3132
const dilation = 1;
3233

3334
const x = dl.tensor3d([1, 2, 3, 4], inputShape);
3435
const w = dl.tensor3d([3], [fSize, inputDepth, outputDepth]);
3536

36-
const result = dl.conv1d(x, w, stride, dilation, pad);
37+
const result = dl.conv1d(x, w, stride, pad, dataFormat, dilation);
3738

3839
expect(result.shape).toEqual([2, 2, 1]);
3940
expectArraysClose(result, [3, 6, 9, 12]);
@@ -46,12 +47,13 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
4647
const fSize = 2;
4748
const pad = 'valid';
4849
const stride = 1;
50+
const dataFormat = 'NWC';
4951
const dilation = 1;
5052

5153
const x = dl.tensor2d([1, 2, 3, 4], inputShape);
5254
const w = dl.tensor3d([2, 1], [fSize, inputDepth, outputDepth]);
5355

54-
const result = dl.conv1d(x, w, stride, dilation, pad);
56+
const result = dl.conv1d(x, w, stride, pad, dataFormat, dilation);
5557

5658
expect(result.shape).toEqual([3, 1]);
5759
expectArraysClose(result, [4, 7, 10]);
@@ -65,19 +67,20 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
6567
const fSizeDilated = 3;
6668
const pad = 'valid';
6769
const stride = 1;
70+
const dataFormat = 'NWC';
6871
const dilation = 2;
6972
const dilationWEffective = 1;
7073

7174
const x = dl.tensor2d([1, 2, 3, 4], inputShape);
7275
const w = dl.tensor3d([2, 1], [fSize, inputDepth, outputDepth]);
7376
// adding a dilation rate is equivalent to using a filter
7477
// with 0s for the dilation rate
75-
const wDilated = dl.tensor3d(
76-
[2, 0, 1], [fSizeDilated, inputDepth, outputDepth]);
78+
const wDilated =
79+
dl.tensor3d([2, 0, 1], [fSizeDilated, inputDepth, outputDepth]);
7780

78-
const result = dl.conv1d(x, w, stride, dilation, pad);
79-
const expectedResult = dl.conv1d(
80-
x, wDilated, stride, dilationWEffective, pad);
81+
const result = dl.conv1d(x, w, stride, pad, dataFormat, dilation);
82+
const expectedResult =
83+
dl.conv1d(x, wDilated, stride, pad, dataFormat, dilationWEffective);
8184

8285
expect(result.shape).toEqual(expectedResult.shape);
8386
expectArraysClose(result, expectedResult);
@@ -91,20 +94,21 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
9194
const fSizeDilated = 7;
9295
const pad = 'valid';
9396
const stride = 1;
97+
const dataFormat = 'NWC';
9498
const dilation = 3;
9599
const dilationWEffective = 1;
96100

97101
const x = dl.tensor2d(
98-
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], inputShape);
102+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], inputShape);
99103
const w = dl.tensor3d([3, 2, 1], [fSize, inputDepth, outputDepth]);
100104
// adding a dilation rate is equivalent to using a filter
101105
// with 0s for the dilation rate
102106
const wDilated = dl.tensor3d(
103-
[3, 0, 0, 2, 0, 0, 1], [fSizeDilated, inputDepth, outputDepth]);
107+
[3, 0, 0, 2, 0, 0, 1], [fSizeDilated, inputDepth, outputDepth]);
104108

105-
const result = dl.conv1d(x, w, stride, dilation, pad);
106-
const expectedResult = dl.conv1d(
107-
x, wDilated, stride, dilationWEffective, pad);
109+
const result = dl.conv1d(x, w, stride, pad, dataFormat, dilation);
110+
const expectedResult =
111+
dl.conv1d(x, wDilated, stride, pad, dataFormat, dilationWEffective);
108112

109113
expect(result.shape).toEqual(expectedResult.shape);
110114
expectArraysClose(result, expectedResult);
@@ -116,27 +120,31 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
116120
const fSize = 2;
117121
const pad = 0;
118122
const stride = 1;
123+
const dataFormat = 'NWC';
119124
const dilation = 1;
120125

121126
// tslint:disable-next-line:no-any
122127
const x: any = dl.tensor2d([1, 2, 3, 4], [2, 2]);
123128
const w = dl.tensor3d([3, 1], [fSize, inputDepth, outputDepth]);
124129

125-
expect(() => dl.conv1d(x, w, stride, dilation, pad)).toThrowError();
130+
expect(() => dl.conv1d(x, w, stride, pad, dataFormat, dilation))
131+
.toThrowError();
126132
});
127133

128134
it('throws when weights is not rank 3', () => {
129135
const inputDepth = 1;
130136
const inputShape: [number, number, number] = [2, 2, inputDepth];
131137
const pad = 0;
132138
const stride = 1;
139+
const dataFormat = 'NWC';
133140
const dilation = 1;
134141

135142
const x = dl.tensor3d([1, 2, 3, 4], inputShape);
136143
// tslint:disable-next-line:no-any
137144
const w: any = dl.tensor4d([3, 1, 5, 0], [2, 2, 1, 1]);
138145

139-
expect(() => dl.conv1d(x, w, stride, dilation, pad)).toThrowError();
146+
expect(() => dl.conv1d(x, w, stride, pad, dataFormat, dilation))
147+
.toThrowError();
140148
});
141149

142150
it('throws when x depth does not match weight depth', () => {
@@ -147,12 +155,14 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
147155
const fSize = 2;
148156
const pad = 0;
149157
const stride = 1;
158+
const dataFormat = 'NWC';
150159
const dilation = 1;
151160

152161
const x = dl.tensor3d([1, 2, 3, 4], inputShape);
153162
const w = dl.randomNormal<Rank.R3>([fSize, wrongInputDepth, outputDepth]);
154163

155-
expect(() => dl.conv1d(x, w, stride, dilation, pad)).toThrowError();
164+
expect(() => dl.conv1d(x, w, stride, pad, dataFormat, dilation))
165+
.toThrowError();
156166
});
157167

158168
it('throws when both stride and dilation are greater than 1', () => {
@@ -162,11 +172,13 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
162172
const fSize = 1;
163173
const pad = 'same';
164174
const stride = 2;
175+
const dataFormat = 'NWC';
165176
const dilation = 2;
166177

167178
const x = dl.tensor3d([1, 2, 3, 4], inputShape);
168179
const w = dl.tensor3d([3], [fSize, inputDepth, outputDepth]);
169180

170-
expect(() => dl.conv1d(x, w, stride, dilation, pad)).toThrowError();
181+
expect(() => dl.conv1d(x, w, stride, pad, dataFormat, dilation))
182+
.toThrowError();
171183
});
172184
});

0 commit comments

Comments
 (0)