@@ -40,6 +40,12 @@ export class ConvOps {
40
40
* - For more info, see this guide:
41
41
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
42
42
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
43
+ * @param dataFormat An optional string from "NWC", "NCW". Defaults to "NWC",
44
+ * in which the data is stored in the order of [batch, in_width, in_channels]. Only
45
+ * "NWC" is currently supported.
46
+ * @param dilation The dilation rate in which we sample input values in
47
+ * atrous convolution. Defaults to `1`. If it is greater than 1, then
48
+ * stride must be `1`.
43
49
* @param dimRoundingMode The rounding mode used when computing output
44
50
* dimensions if pad is a number. If none is provided, it will not round
45
51
* and error if the output is of fractional size.
@@ -48,6 +54,7 @@ export class ConvOps {
48
54
@operation
49
55
static conv1d < T extends Tensor2D | Tensor3D > (
50
56
input : T , filter : Tensor3D , stride : number , pad : 'valid' | 'same' | number ,
57
+ dataFormat : 'NWC' | 'NCW' = 'NWC' , dilation = 1 ,
51
58
dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
52
59
let input3D = input as Tensor3D ;
53
60
let reshapedTo3D = false ;
@@ -74,15 +81,27 @@ export class ConvOps {
74
81
input3D . shape [ 2 ] === filter . shape [ 1 ] ,
75
82
`Error in conv1d: depth of input (${ input3D . shape [ 2 ] } ) must match ` +
76
83
`input depth for filter ${ filter . shape [ 1 ] } .` ) ;
84
+ util . assert (
85
+ eitherStridesOrDilationsAreOne ( stride , dilation ) ,
86
+ 'Error in conv1D: Either stride or dilation must be 1.' +
87
+ `Got stride ${ stride } and dilation '${ dilation } '` ) ;
88
+ util . assert (
89
+ dataFormat === 'NWC' ,
90
+ `Error in conv1d: got dataFormat of ${
91
+ dataFormat } but only NWC is currently supported.`) ;
77
92
78
93
const filter4D =
79
94
filter . as4D ( 1 , filter . shape [ 0 ] , filter . shape [ 1 ] , filter . shape [ 2 ] ) ;
80
95
const input4D =
81
96
input3D . as4D ( input3D . shape [ 0 ] , 1 , input3D . shape [ 1 ] , input3D . shape [ 2 ] ) ;
82
97
const strides : [ number , number ] = [ 1 , stride ] ;
98
+ const dilations : [ number , number ] = [ 1 , dilation ] ;
99
+
100
+ const conv2dDataFormat = 'NHWC' ;
83
101
84
- const res =
85
- ConvOps . conv2d ( input4D , filter4D , strides , pad , dimRoundingMode ) ;
102
+ const res = ConvOps . conv2d (
103
+ input4D , filter4D , strides , pad , conv2dDataFormat , dilations ,
104
+ dimRoundingMode ) ;
86
105
87
106
if ( reshapedTo3D ) {
88
107
return res . as2D ( res . shape [ 2 ] , res . shape [ 3 ] ) as T ;
@@ -108,6 +127,15 @@ export class ConvOps {
108
127
* - For more info, see this guide:
109
128
* [https://www.tensorflow.org/api_guides/python/nn#Convolution](
110
129
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
130
+ * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
131
+ * "NHWC". Specify the data format of the input and output data. With the
132
+ * default format "NHWC", the data is stored in the order of: [batch,
133
+ * height, width, channels]. Only "NHWC" is currently supported.
134
+ * @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
135
+ * in which we sample input values across the height and width dimensions
136
+ * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
137
+ * number, then `dilationHeight == dilationWidth`. If it is greater than
138
+ * 1, then all values of `strides` must be 1.
111
139
* @param dimRoundingMode The rounding mode used when computing output
112
140
* dimensions if pad is a number. If none is provided, it will not round
113
141
* and error if the output is of fractional size.
@@ -116,7 +144,9 @@ export class ConvOps {
116
144
@operation
117
145
static conv2d < T extends Tensor3D | Tensor4D > (
118
146
x : T , filter : Tensor4D , strides : [ number , number ] | number ,
119
- pad : 'valid' | 'same' | number , dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
147
+ pad : 'valid' | 'same' | number , dataFormat : 'NHWC' | 'NCHW' = 'NHWC' ,
148
+ dilations : [ number , number ] | number = [ 1 , 1 ] ,
149
+ dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
120
150
let x4D = x as Tensor4D ;
121
151
let reshapedTo4D = false ;
122
152
@@ -142,13 +172,24 @@ export class ConvOps {
142
172
x4D . shape [ 3 ] === filter . shape [ 2 ] ,
143
173
`Error in conv2d: depth of input (${ x4D . shape [ 3 ] } ) must match ` +
144
174
`input depth for filter ${ filter . shape [ 2 ] } .` ) ;
145
-
146
- const dilations = 1 ;
175
+ util . assert (
176
+ eitherStridesOrDilationsAreOne ( strides , dilations ) ,
177
+ 'Error in conv2D: Either strides or dilations must be 1.' +
178
+ `Got strides ${ strides } and dilations '${ dilations } '` ) ;
179
+ util . assert (
180
+ dataFormat === 'NHWC' ,
181
+ `Error in conv2d: got dataFormat of ${
182
+ dataFormat } but only NHWC is currently supported.`) ;
147
183
148
184
const convInfo = conv_util . computeConv2DInfo (
149
185
x4D . shape , filter . shape , strides , dilations , pad , dimRoundingMode ) ;
150
186
151
187
const grad = ( dy : Tensor4D ) => {
188
+ util . assert (
189
+ tupleValuesAreOne ( dilations ) ,
190
+ 'Error in gradient of conv2D: dilation rates greater than 1 are not' +
191
+ `yet supported in gradients. Got dilations '${ dilations } '` ) ;
192
+
152
193
return {
153
194
x : ( ) => ConvOps . conv2dDerInput ( x4D . shape , dy , filter , strides , pad ) ,
154
195
filter : ( ) =>
@@ -375,9 +416,13 @@ export class ConvOps {
375
416
* https://www.tensorflow.org/api_guides/python/nn#Convolution)
376
417
* @param dilations The dilation rates: `[dilationHeight, dilationWidth]`
377
418
* in which we sample input values across the height and width dimensions
378
- * in atrous convolution. Defaults to `[1, 1]`. If `dilations ` is a single
419
+ * in atrous convolution. Defaults to `[1, 1]`. If `dilations` is a single
379
420
* number, then `dilationHeight == dilationWidth`. If it is greater than
380
421
* 1, then all values of `strides` must be 1.
422
+ * @param dataFormat An optional string from: "NHWC", "NCHW". Defaults to
423
+ * "NHWC". Specify the data format of the input and output data. With the
424
+ * default format "NHWC", the data is stored in the order of: [batch,
425
+ * height, width, channels]. Only "NHWC" is currently supported.
381
426
* @param dimRoundingMode The rounding mode used when computing output
382
427
* dimensions if pad is a number. If none is provided, it will not round
383
428
* and error if the output is of fractional size.
@@ -386,7 +431,8 @@ export class ConvOps {
386
431
@operation
387
432
static depthwiseConv2d < T extends Tensor3D | Tensor4D > (
388
433
input : T , filter : Tensor4D , strides : [ number , number ] | number ,
389
- pad : 'valid' | 'same' | number , dilations : [ number , number ] | number = [ 1 , 1 ] ,
434
+ pad : 'valid' | 'same' | number , dataFormat : 'NHWC' | 'NCHW' = 'NHWC' ,
435
+ dilations : [ number , number ] | number = [ 1 , 1 ] ,
390
436
dimRoundingMode ?: 'floor' | 'round' | 'ceil' ) : T {
391
437
let input4D = input as Tensor4D ;
392
438
let reshapedTo4D = false ;
@@ -410,11 +456,11 @@ export class ConvOps {
410
456
if ( dilations == null ) {
411
457
dilations = [ 1 , 1 ] ;
412
458
}
413
- const [ dilationHeight , dilationWidth ] = parseTupleParam ( dilations ) ;
414
459
util . assert (
415
- dilationHeight === 1 && dilationWidth === 1 ,
416
- 'Error in depthwiseConv2D: dilation rates greater than 1 are not yet ' +
417
- `supported. Got dilations '${ dilations } '` ) ;
460
+ eitherStridesOrDilationsAreOne ( strides , dilations ) ,
461
+ 'Error in depthwiseConv2d: Either strides or dilations must be 1.' +
462
+ `Got strides ${ strides } and dilations '${ dilations } '` ) ;
463
+
418
464
if ( dimRoundingMode != null ) {
419
465
util . assert (
420
466
util . isInt ( pad as number ) ,
@@ -438,3 +484,14 @@ export class ConvOps {
438
484
/**
 * Normalizes a scalar-or-pair parameter into a two-element tuple.
 *
 * @param param Either a single number or a `[first, second]` pair.
 * @returns `[param, param]` when given a number; otherwise the pair that was
 *     passed in (the same array instance, not a copy).
 */
function parseTupleParam(param: number|[number, number]): [number, number] {
  if (typeof param === 'number') {
    return [param, param];
  }
  return param;
}
487
+
488
/**
 * Reports whether a scalar-or-pair parameter is `1` in both positions.
 *
 * @param param Either a single number or a `[first, second]` pair; it is
 *     normalized with `parseTupleParam` before being checked.
 * @returns `true` only when both normalized values are strictly equal to 1.
 */
function tupleValuesAreOne(param: number|[number, number]): boolean {
  const [first, second] = parseTupleParam(param);
  if (first !== 1) {
    return false;
  }
  return second === 1;
}
492
+
493
/**
 * Checks the constraint that strides and dilations are not both non-unit.
 *
 * @param strides A single stride or a pair of strides.
 * @param dilations A single dilation rate or a pair of dilation rates.
 * @returns `true` when every stride value is 1 or every dilation value is 1,
 *     as determined by `tupleValuesAreOne`.
 */
function eitherStridesOrDilationsAreOne(
    strides: number|[number, number],
    dilations: number|[number, number]): boolean {
  // Strides are checked first; dilations are only inspected when needed.
  if (tupleValuesAreOne(strides)) {
    return true;
  }
  return tupleValuesAreOne(dilations);
}
0 commit comments