Line data Source code
1 : /**
2 : * \file TypedBaseFilter.hxx
3 : */
4 :
5 : #include "TypedBaseFilter.h"
6 : #include <ATK/Core/Utilities.h>
7 :
8 : #include <boost/mp11/algorithm.hpp>
9 : #include <boost/mp11/list.hpp>
10 :
11 : #include <complex>
12 : #include <cstdint>
13 : #include <type_traits>
14 :
15 : namespace ATK
16 : {
17 : namespace Utilities
18 : {
19 : using ConversionTypes = boost::mp11::mp_list<std::int16_t, std::int32_t, int64_t, float, double, std::complex<float>, std::complex<double> > ;
20 :
21 : template<typename Vector, typename DataType>
22 23576 : void convert_scalar_array(ATK::BaseFilter* filter, unsigned int port, DataType* converted_input, gsl::index size, int type)
23 : {
24 : if constexpr(boost::mp11::mp_empty<Vector>::value)
25 : {
26 1 : throw RuntimeError("Cannot convert types for these filters");
27 : }
28 : else if constexpr(std::is_arithmetic<boost::mp11::mp_front<Vector>>::value)
29 : {
30 23573 : if (type != 0)
31 : {
32 18833 : convert_scalar_array<boost::mp11::mp_pop_front<Vector>, DataType>(filter, port, converted_input, size, type - 1);
33 : }
34 : else
35 : {
36 : using InputOriginalType = boost::mp11::mp_front<Vector>;
37 4740 : InputOriginalType* original_input_array = static_cast<ATK::TypedBaseFilter<InputOriginalType>*>(filter)->get_output_array(port);
38 4740 : ATK::ConversionUtilities<InputOriginalType, DataType>::convert_array(original_input_array, converted_input, size);
39 : }
40 : }
41 : else // This is in case we add arithmetic types after the non arithmetic ones (should not happen)
42 : {
43 2 : convert_scalar_array<boost::mp11::mp_pop_front<Vector>, DataType>(filter, port, converted_input, size, type - 1);
44 : }
45 23568 : }
46 :
47 :
48 : template<typename Vector, typename DataType>
49 2 : void convert_complex_array(ATK::BaseFilter* filter, unsigned int port, DataType* converted_input, gsl::index size, int type)
50 : {
51 : if constexpr(boost::mp11::mp_empty<Vector>::value)
52 : {
53 0 : throw RuntimeError("Can't convert types");
54 : }
55 : else
56 : {
57 2 : assert(type >= 0);
58 2 : if (type != 0)
59 : {
60 1 : convert_complex_array<boost::mp11::mp_pop_front<Vector>, DataType>(filter, port, converted_input, size, type - 1);
61 : }
62 : else
63 : {
64 : using InputOriginalType = boost::mp11::mp_front<Vector>;
65 1 : InputOriginalType* original_input_array = static_cast<ATK::TypedBaseFilter<InputOriginalType>*>(filter)->get_output_array(port);
66 1 : ATK::ConversionUtilities<InputOriginalType, DataType>::convert_array(original_input_array, converted_input, size);
67 : }
68 : }
69 2 : }
70 :
71 : template<typename Vector, typename DataType>
72 4742 : void convert_array(ATK::BaseFilter* filter, gsl::index port, DataType* converted_input, gsl::index size, int type)
73 : {
74 : if constexpr(std::is_arithmetic<DataType>::value)
75 : {
76 4741 : convert_scalar_array<Vector, DataType>(filter, port, converted_input, size, type);
77 : }
78 : else if constexpr(boost::mp11::mp_contains<Vector, DataType>::value)
79 : {
80 1 : convert_complex_array<Vector, DataType>(filter, port, converted_input, size, type);
81 : }
82 : else
83 : {
84 : assert(dynamic_cast<ATK::OutputArrayInterface<DataType>*>(filter));
85 : // For SIMD, you shouldn't call this, but adapt input/output delays so that there is no copy from one filter to another.
86 : DataType* original_input_array = dynamic_cast<ATK::TypedBaseFilter<DataType>*>(filter)->get_output_array(port);
87 : ATK::ConversionUtilities<DataType, DataType>::convert_array(original_input_array, converted_input, size);
88 : }
89 4741 : }
90 :
91 : template<typename Vector, typename DataType>
92 4742 : int get_type()
93 : {
94 : if constexpr(boost::mp11::mp_contains<Vector, DataType>::value)
95 : {
96 4742 : return boost::mp11::mp_find<ConversionTypes, DataType>::value;
97 : }
98 : else
99 : {
100 : return -1;
101 : }
102 : }
103 : }
104 :
105 : template<typename DataType_, typename DataType__>
106 2291 : TypedBaseFilter<DataType_, DataType__>::TypedBaseFilter(gsl::index nb_input_ports, gsl::index nb_output_ports)
107 2291 : :Parent(nb_input_ports, nb_output_ports), converted_inputs_delay(nb_input_ports), converted_inputs(nb_input_ports, nullptr), converted_inputs_size(nb_input_ports, 0), converted_in_delays(nb_input_ports, 0), direct_filters(nb_input_ports, nullptr), outputs_delay(nb_output_ports), outputs(nb_output_ports, nullptr), outputs_size(nb_output_ports, 0), out_delays(nb_output_ports, 0), default_input(nb_input_ports, TypeTraits<DataType_>::Zero()), default_output(nb_output_ports, TypeTraits<DataType__>::Zero())
108 : {
109 2291 : }
110 :
111 : template<typename DataType_, typename DataType__>
112 5 : void TypedBaseFilter<DataType_, DataType__>::set_nb_input_ports(gsl::index nb_ports)
113 : {
114 5 : if(nb_ports == nb_input_ports)
115 : {
116 1 : return;
117 : }
118 4 : Parent::set_nb_input_ports(nb_ports);
119 4 : converted_inputs_delay = std::vector<AlignedVector>(nb_ports);
120 4 : converted_inputs.assign(nb_ports, nullptr);
121 4 : converted_inputs_size.assign(nb_ports, 0);
122 4 : converted_in_delays.assign(nb_ports, 0);
123 4 : direct_filters.assign(nb_ports, nullptr);
124 4 : default_input.assign(nb_ports, TypeTraits<DataTypeInput>::Zero());
125 : }
126 :
127 : template<typename DataType_, typename DataType__>
128 17 : void TypedBaseFilter<DataType_, DataType__>::set_nb_output_ports(gsl::index nb_ports)
129 : {
130 17 : if(nb_ports == nb_output_ports)
131 : {
132 1 : return;
133 : }
134 16 : Parent::set_nb_output_ports(nb_ports);
135 16 : outputs_delay = std::vector<AlignedOutVector>(nb_ports);
136 16 : outputs.assign(nb_ports, nullptr);
137 16 : outputs_size.assign(nb_ports, 0);
138 16 : out_delays.assign(nb_ports, 0);
139 16 : default_output.assign(nb_ports, TypeTraits<DataTypeOutput>::Zero());
140 : }
141 :
142 : template<typename DataType_, typename DataType__>
143 1 : void TypedBaseFilter<DataType_, DataType__>::process_impl(gsl::index size) const
144 : {
145 : // Nothing to do by default
146 1 : }
147 :
148 : template<typename DataType_, typename DataType__>
149 49207 : void TypedBaseFilter<DataType_, DataType__>::prepare_process(gsl::index size)
150 : {
151 49207 : convert_inputs(size);
152 49206 : }
153 :
154 : template<typename DataType_, typename DataType__>
155 4742 : int TypedBaseFilter<DataType_, DataType__>::get_type() const
156 : {
157 4742 : return Utilities::get_type<Utilities::ConversionTypes, DataType__>();
158 : }
159 :
160 : template<typename DataType_, typename DataType__>
161 89774 : DataType__* TypedBaseFilter<DataType_, DataType__>::get_output_array(gsl::index port) const
162 : {
163 89774 : return outputs[port];
164 : }
165 :
166 : template<typename DataType_, typename DataType__>
167 2 : gsl::index TypedBaseFilter<DataType_, DataType__>::get_output_array_size() const
168 : {
169 2 : return outputs_size.front();
170 : }
171 :
172 : template<typename DataType_, typename DataType__>
173 49207 : void TypedBaseFilter<DataType_, DataType__>::convert_inputs(gsl::index size)
174 : {
175 99365 : for(gsl::index i = 0; i < nb_input_ports; ++i)
176 : {
177 : // if the input delay is smaller than the preceding filter output delay, we may have overlap
178 : // if the types are identical and if the type is not -1 (an unknown type)
179 : // if we have overlap, don't copy anything at all
180 50159 : if((input_delay <= connections[i].second->get_output_delay()) && (direct_filters[i] != nullptr))
181 : {
182 45417 : converted_inputs[i] = direct_filters[i]->get_output_array(connections[i].first);
183 45417 : converted_inputs_size[i] = size;
184 45417 : converted_in_delays[i] = input_delay;
185 45417 : continue;
186 : }
187 4742 : auto input_size = converted_inputs_size[i];
188 4742 : auto in_delay = converted_in_delays[i];
189 4742 : if(input_size < size || in_delay < input_delay)
190 : {
191 : // TODO Properly align the beginning of the data, not depending on input delay
192 4502 : AlignedVector temp(input_delay + size, TypeTraits<DataTypeInput>::Zero());
193 4502 : if(input_size == 0)
194 : {
195 11903 : for(unsigned int j = 0; j < input_delay; ++j)
196 : {
197 11558 : temp[j] = default_input[i];
198 : }
199 : }
200 : else
201 : {
202 4157 : const auto input_ptr = converted_inputs[i];
203 20720701 : for(gsl::index j = 0; j < in_delay; ++j)
204 : {
205 20716500 : temp[j] = input_ptr[last_size + j - in_delay];
206 : }
207 : }
208 :
209 4502 : converted_inputs_delay[i] = std::move(temp);
210 4502 : converted_inputs[i] = converted_inputs_delay[i].data() + input_delay;
211 4502 : converted_inputs_size[i] = size;
212 4502 : converted_in_delays[i] = input_delay;
213 : }
214 : else
215 : {
216 240 : auto my_last_size = static_cast<int64_t>(last_size) * input_sampling_rate / output_sampling_rate;
217 240 : const auto input_ptr = converted_inputs[i];
218 1143 : for(gsl::index j = 0; j < input_delay; ++j)
219 : {
220 903 : input_ptr[j - input_delay] = input_ptr[my_last_size + j - input_delay];
221 : }
222 : }
223 4742 : Utilities::convert_array<Utilities::ConversionTypes, DataTypeInput>(connections[i].second, connections[i].first, converted_inputs[i], size, connections[i].second->get_type());
224 : }
225 49206 : }
226 :
227 : template<typename DataType_, typename DataType__>
228 49206 : void TypedBaseFilter<DataType_, DataType__>::prepare_outputs(gsl::index size)
229 : {
230 98867 : for(gsl::index i = 0; i < nb_output_ports; ++i)
231 : {
232 49661 : auto output_size = outputs_size[i];
233 49661 : auto out_delay = out_delays[i];
234 49661 : if(output_size < size || out_delay < output_delay)
235 : {
236 : // TODO Properly align the beginning of the data, not depending on output delay
237 47340 : AlignedOutVector temp(output_delay + size, TypeTraits<DataTypeOutput>::Zero());
238 47340 : if(output_size == 0)
239 : {
240 2594 : for(gsl::index j = 0; j < output_delay; ++j)
241 : {
242 1086 : temp[j] = default_output[i];
243 : }
244 : }
245 : else
246 : {
247 45832 : const auto output_ptr = outputs[i];
248 46572 : for(gsl::index j = 0; j < static_cast<int>(out_delay); ++j)
249 : {
250 740 : temp[j] = output_ptr[last_size + j - out_delay];
251 : }
252 : }
253 :
254 47340 : outputs_delay[i] = std::move(temp);
255 47340 : outputs[i] = outputs_delay[i].data() + output_delay;
256 47340 : outputs_size[i] = size;
257 47340 : out_delays[i] = output_delay;
258 : }
259 : else
260 : {
261 2321 : const auto output_ptr = outputs[i];
262 3178 : for(gsl::index j = 0; j < static_cast<int>(output_delay); ++j)
263 : {
264 857 : output_ptr[j - output_delay] = output_ptr[last_size + j - output_delay];
265 : }
266 : }
267 : }
268 49206 : }
269 :
270 : template<typename DataType_, typename DataType__>
271 2258 : void TypedBaseFilter<DataType_, DataType__>::full_setup()
272 : {
273 : // Reset input arrays
274 2258 : converted_inputs_delay = std::vector<AlignedVector>(nb_input_ports);
275 2258 : converted_inputs.assign(nb_input_ports, nullptr);
276 2258 : converted_inputs_size.assign(nb_input_ports, 0);
277 2258 : converted_in_delays.assign(nb_input_ports, 0);
278 :
279 : // Reset output arrays
280 2258 : outputs_delay = std::vector<AlignedOutVector>(nb_output_ports);
281 2258 : outputs.assign(nb_output_ports, nullptr);
282 2258 : outputs_size.assign(nb_output_ports, 0);
283 2258 : out_delays.assign(nb_output_ports, 0);
284 :
285 2258 : Parent::full_setup();
286 2258 : }
287 :
288 : template<typename DataType_, typename DataType__>
289 1404 : void TypedBaseFilter<DataType_, DataType__>::set_input_port(gsl::index input_port, gsl::not_null<BaseFilter*> filter, gsl::index output_port)
290 : {
291 1404 : set_input_port(input_port, *filter, output_port);
292 1401 : }
293 :
294 : template<typename DataType_, typename DataType__>
295 1411 : void TypedBaseFilter<DataType_, DataType__>::set_input_port(gsl::index input_port, BaseFilter& filter, gsl::index output_port)
296 : {
297 1411 : Parent::set_input_port(input_port, filter, output_port);
298 1408 : converted_inputs_size[input_port] = 0;
299 1408 : converted_in_delays[input_port] = 0;
300 1408 : direct_filters[input_port] = dynamic_cast<OutputArrayInterface<DataType_>*>(&filter);
301 1408 : }
302 : }
|