@@ -28,12 +28,13 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
28
28
const fSize = 1 ;
29
29
const pad = 'same' ;
30
30
const stride = 1 ;
31
+ const dataFormat = 'NWC' ;
31
32
const dilation = 1 ;
32
33
33
34
const x = dl . tensor3d ( [ 1 , 2 , 3 , 4 ] , inputShape ) ;
34
35
const w = dl . tensor3d ( [ 3 ] , [ fSize , inputDepth , outputDepth ] ) ;
35
36
36
- const result = dl . conv1d ( x , w , stride , dilation , pad ) ;
37
+ const result = dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) ;
37
38
38
39
expect ( result . shape ) . toEqual ( [ 2 , 2 , 1 ] ) ;
39
40
expectArraysClose ( result , [ 3 , 6 , 9 , 12 ] ) ;
@@ -46,12 +47,13 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
46
47
const fSize = 2 ;
47
48
const pad = 'valid' ;
48
49
const stride = 1 ;
50
+ const dataFormat = 'NWC' ;
49
51
const dilation = 1 ;
50
52
51
53
const x = dl . tensor2d ( [ 1 , 2 , 3 , 4 ] , inputShape ) ;
52
54
const w = dl . tensor3d ( [ 2 , 1 ] , [ fSize , inputDepth , outputDepth ] ) ;
53
55
54
- const result = dl . conv1d ( x , w , stride , dilation , pad ) ;
56
+ const result = dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) ;
55
57
56
58
expect ( result . shape ) . toEqual ( [ 3 , 1 ] ) ;
57
59
expectArraysClose ( result , [ 4 , 7 , 10 ] ) ;
@@ -65,19 +67,20 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
65
67
const fSizeDilated = 3 ;
66
68
const pad = 'valid' ;
67
69
const stride = 1 ;
70
+ const dataFormat = 'NWC' ;
68
71
const dilation = 2 ;
69
72
const dilationWEffective = 1 ;
70
73
71
74
const x = dl . tensor2d ( [ 1 , 2 , 3 , 4 ] , inputShape ) ;
72
75
const w = dl . tensor3d ( [ 2 , 1 ] , [ fSize , inputDepth , outputDepth ] ) ;
73
76
// adding a dilation rate is equivalent to using a filter
74
77
// with 0s for the dilation rate
75
- const wDilated = dl . tensor3d (
76
- [ 2 , 0 , 1 ] , [ fSizeDilated , inputDepth , outputDepth ] ) ;
78
+ const wDilated =
79
+ dl . tensor3d ( [ 2 , 0 , 1 ] , [ fSizeDilated , inputDepth , outputDepth ] ) ;
77
80
78
- const result = dl . conv1d ( x , w , stride , dilation , pad ) ;
79
- const expectedResult = dl . conv1d (
80
- x , wDilated , stride , dilationWEffective , pad ) ;
81
+ const result = dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) ;
82
+ const expectedResult =
83
+ dl . conv1d ( x , wDilated , stride , pad , dataFormat , dilationWEffective ) ;
81
84
82
85
expect ( result . shape ) . toEqual ( expectedResult . shape ) ;
83
86
expectArraysClose ( result , expectedResult ) ;
@@ -91,20 +94,21 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
91
94
const fSizeDilated = 7 ;
92
95
const pad = 'valid' ;
93
96
const stride = 1 ;
97
+ const dataFormat = 'NWC' ;
94
98
const dilation = 3 ;
95
99
const dilationWEffective = 1 ;
96
100
97
101
const x = dl . tensor2d (
98
- [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 ] , inputShape ) ;
102
+ [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 , 9 , 10 , 11 , 12 , 13 , 14 ] , inputShape ) ;
99
103
const w = dl . tensor3d ( [ 3 , 2 , 1 ] , [ fSize , inputDepth , outputDepth ] ) ;
100
104
// adding a dilation rate is equivalent to using a filter
101
105
// with 0s for the dilation rate
102
106
const wDilated = dl . tensor3d (
103
- [ 3 , 0 , 0 , 2 , 0 , 0 , 1 ] , [ fSizeDilated , inputDepth , outputDepth ] ) ;
107
+ [ 3 , 0 , 0 , 2 , 0 , 0 , 1 ] , [ fSizeDilated , inputDepth , outputDepth ] ) ;
104
108
105
- const result = dl . conv1d ( x , w , stride , dilation , pad ) ;
106
- const expectedResult = dl . conv1d (
107
- x , wDilated , stride , dilationWEffective , pad ) ;
109
+ const result = dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) ;
110
+ const expectedResult =
111
+ dl . conv1d ( x , wDilated , stride , pad , dataFormat , dilationWEffective ) ;
108
112
109
113
expect ( result . shape ) . toEqual ( expectedResult . shape ) ;
110
114
expectArraysClose ( result , expectedResult ) ;
@@ -116,27 +120,31 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
116
120
const fSize = 2 ;
117
121
const pad = 0 ;
118
122
const stride = 1 ;
123
+ const dataFormat = 'NWC' ;
119
124
const dilation = 1 ;
120
125
121
126
// tslint:disable-next-line:no-any
122
127
const x : any = dl . tensor2d ( [ 1 , 2 , 3 , 4 ] , [ 2 , 2 ] ) ;
123
128
const w = dl . tensor3d ( [ 3 , 1 ] , [ fSize , inputDepth , outputDepth ] ) ;
124
129
125
- expect ( ( ) => dl . conv1d ( x , w , stride , dilation , pad ) ) . toThrowError ( ) ;
130
+ expect ( ( ) => dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) )
131
+ . toThrowError ( ) ;
126
132
} ) ;
127
133
128
134
it ( 'throws when weights is not rank 3' , ( ) => {
129
135
const inputDepth = 1 ;
130
136
const inputShape : [ number , number , number ] = [ 2 , 2 , inputDepth ] ;
131
137
const pad = 0 ;
132
138
const stride = 1 ;
139
+ const dataFormat = 'NWC' ;
133
140
const dilation = 1 ;
134
141
135
142
const x = dl . tensor3d ( [ 1 , 2 , 3 , 4 ] , inputShape ) ;
136
143
// tslint:disable-next-line:no-any
137
144
const w : any = dl . tensor4d ( [ 3 , 1 , 5 , 0 ] , [ 2 , 2 , 1 , 1 ] ) ;
138
145
139
- expect ( ( ) => dl . conv1d ( x , w , stride , dilation , pad ) ) . toThrowError ( ) ;
146
+ expect ( ( ) => dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) )
147
+ . toThrowError ( ) ;
140
148
} ) ;
141
149
142
150
it ( 'throws when x depth does not match weight depth' , ( ) => {
@@ -147,12 +155,14 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
147
155
const fSize = 2 ;
148
156
const pad = 0 ;
149
157
const stride = 1 ;
158
+ const dataFormat = 'NWC' ;
150
159
const dilation = 1 ;
151
160
152
161
const x = dl . tensor3d ( [ 1 , 2 , 3 , 4 ] , inputShape ) ;
153
162
const w = dl . randomNormal < Rank . R3 > ( [ fSize , wrongInputDepth , outputDepth ] ) ;
154
163
155
- expect ( ( ) => dl . conv1d ( x , w , stride , dilation , pad ) ) . toThrowError ( ) ;
164
+ expect ( ( ) => dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) )
165
+ . toThrowError ( ) ;
156
166
} ) ;
157
167
158
168
it ( 'throws when both stride and dilation are greater than 1' , ( ) => {
@@ -162,11 +172,13 @@ describeWithFlags('conv1d', ALL_ENVS, () => {
162
172
const fSize = 1 ;
163
173
const pad = 'same' ;
164
174
const stride = 2 ;
175
+ const dataFormat = 'NWC' ;
165
176
const dilation = 2 ;
166
177
167
178
const x = dl . tensor3d ( [ 1 , 2 , 3 , 4 ] , inputShape ) ;
168
179
const w = dl . tensor3d ( [ 3 ] , [ fSize , inputDepth , outputDepth ] ) ;
169
180
170
- expect ( ( ) => dl . conv1d ( x , w , stride , dilation , pad ) ) . toThrowError ( ) ;
181
+ expect ( ( ) => dl . conv1d ( x , w , stride , pad , dataFormat , dilation ) )
182
+ . toThrowError ( ) ;
171
183
} ) ;
172
184
} ) ;
0 commit comments