go home Home | Main Page | Topics | Namespace List | Class Hierarchy | Alphabetical List | Data Structures | File List | Namespace Members | Data Fields | Globals | Related Pages
Loading...
Searching...
No Matches
ImpactTensorUtils.h
Go to the documentation of this file.
1/*=========================================================================
2 *
3 * Copyright UMC Utrecht and contributors
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0.txt
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 *=========================================================================*/
18
38
39#ifndef ImpactTensorUtils_h
40#define ImpactTensorUtils_h
41
42#include <torch/torch.h>
43#include <vector>
44#include <functional>
45#include <exception>
46#include "ImpactLoss.h"
48#include <random>
49
51{
52
61template <typename TImage, typename TInterpolator>
62torch::Tensor
63ImageToTensor(typename TImage::ConstPointer image,
64 typename TInterpolator::Pointer interpolator,
65 const std::vector<float> & voxelSize,
66 const std::function<typename TImage::PointType(const typename TImage::PointType &)> & transformPoint);
67
77template <typename TImage, typename TFeatureImage>
78typename TFeatureImage::Pointer
79TensorToImage(typename TImage::ConstPointer image, torch::Tensor layers);
80
98template <typename TImage, typename FeaturesMaps, typename InterpolatorType, typename FeaturesImageType>
99std::vector<FeaturesMaps>
101 typename TImage::ConstPointer image,
102 typename InterpolatorType::Pointer interpolator,
103 const std::vector<itk::ImpactModelConfiguration> & modelsConfiguration,
104 torch::Device device,
105 std::vector<unsigned int> pca,
106 std::vector<torch::Tensor> & principalComponents,
107 const std::function<void(typename TImage::ConstPointer, torch::Tensor &, const std::string &)> & writeInputImage,
108 const std::function<typename TImage::PointType(const typename TImage::PointType &)> & transformPoint = nullptr);
109
123std::vector<torch::Tensor>
124GetModelOutputsExample(std::vector<itk::ImpactModelConfiguration> & modelsConfig,
125 const std::string & modelType,
126 torch::Device device);
127
140std::vector<std::vector<float>>
142 std::mt19937 & randomGenerator,
143 unsigned int dimension);
144
145template <typename ImagePointType>
146using ImagesPatchValuesEvaluator = std::function<
147 torch::Tensor(const ImagePointType &, const std::vector<std::vector<float>> &, const std::vector<int64_t> &)>;
148
164template <class ImagePointType>
165std::vector<torch::Tensor>
166GenerateOutputs(const std::vector<itk::ImpactModelConfiguration> & modelConfig,
167 const std::vector<ImagePointType> & fixedPoints,
168 const std::vector<std::vector<std::vector<std::vector<float>>>> & patchIndex,
169 const std::vector<torch::Tensor> subsetsOfFeatures,
170 torch::Device device,
171 const ImpactTensorUtils::ImagesPatchValuesEvaluator<ImagePointType> & imagesPatchValuesEvaluator);
172
173template <typename ImagePointType>
174using ImagesPatchValuesAndJacobiansEvaluator = std::function<torch::Tensor(const ImagePointType &,
175 torch::Tensor &,
176 const std::vector<std::vector<float>> &,
177 const std::vector<int64_t> &,
178 int)>;
179
197template <typename ImagePointType>
198std::vector<torch::Tensor>
199GenerateOutputsAndJacobian(const std::vector<itk::ImpactModelConfiguration> & modelConfig,
200 const std::vector<ImagePointType> & fixedPoints,
201 const std::vector<std::vector<std::vector<std::vector<float>>>> & patchIndex,
202 std::vector<torch::Tensor> subsetsOfFeatures,
203 std::vector<torch::Tensor> fixedOutputsTensor,
204 torch::Device device,
205 std::vector<std::unique_ptr<ImpactLoss::Loss>> & losses,
207 imagesPatchValuesAndJacobiansEvaluator);
208
209} // namespace ImpactTensorUtils
210
211
212#ifndef ITK_MANUAL_INSTANTIATION
213# include "ImpactTensorUtils.hxx"
214#endif
215
216#endif // end #ifndef ImpactTensorUtils_h
std::vector< torch::Tensor > GetModelOutputsExample(std::vector< itk::ImpactModelConfiguration > &modelsConfig, const std::string &modelType, torch::Device device)
Tests the configuration of each model by generating outputs from dummy input.
std::function< torch::Tensor(const ImagePointType &, torch::Tensor &, const std::vector< std::vector< float > > &, const std::vector< int64_t > &, int)> ImagesPatchValuesAndJacobiansEvaluator
TFeatureImage::Pointer TensorToImage(typename TImage::ConstPointer image, torch::Tensor layers)
Converts a tensor (C×D×H×W) to a multi-channel ITK image. Converts the given tensor to an ITK image,...
torch::Tensor ImageToTensor(typename TImage::ConstPointer image, typename TInterpolator::Pointer interpolator, const std::vector< float > &voxelSize, const std::function< typename TImage::PointType(const typename TImage::PointType &)> &transformPoint)
Converts an ITK image to a Torch tensor using physical spacing.
std::vector< FeaturesMaps > GetFeaturesMaps(typename TImage::ConstPointer image, typename InterpolatorType::Pointer interpolator, const std::vector< itk::ImpactModelConfiguration > &modelsConfiguration, torch::Device device, std::vector< unsigned int > pca, std::vector< torch::Tensor > &principalComponents, const std::function< void(typename TImage::ConstPointer, torch::Tensor &, const std::string &)> &writeInputImage, const std::function< typename TImage::PointType(const typename TImage::PointType &)> &transformPoint=nullptr)
Applies one or more models to an image to extract feature maps.
std::vector< torch::Tensor > GenerateOutputs(const std::vector< itk::ImpactModelConfiguration > &modelConfig, const std::vector< ImagePointType > &fixedPoints, const std::vector< std::vector< std::vector< std::vector< float > > > > &patchIndex, const std::vector< torch::Tensor > subsetsOfFeatures, torch::Device device, const ImpactTensorUtils::ImagesPatchValuesEvaluator< ImagePointType > &imagesPatchValuesEvaluator)
Computes feature outputs for all patches using each model.
std::function< torch::Tensor(const ImagePointType &, const std::vector< std::vector< float > > &, const std::vector< int64_t > &)> ImagesPatchValuesEvaluator
std::vector< torch::Tensor > GenerateOutputsAndJacobian(const std::vector< itk::ImpactModelConfiguration > &modelConfig, const std::vector< ImagePointType > &fixedPoints, const std::vector< std::vector< std::vector< std::vector< float > > > > &patchIndex, std::vector< torch::Tensor > subsetsOfFeatures, std::vector< torch::Tensor > fixedOutputsTensor, torch::Device device, std::vector< std::unique_ptr< ImpactLoss::Loss > > &losses, const ImpactTensorUtils::ImagesPatchValuesAndJacobiansEvaluator< ImagePointType > &imagesPatchValuesAndJacobiansEvaluator)
Computes feature outputs and their spatial Jacobians for image registration.
std::vector< std::vector< float > > GetPatchIndex(const itk::ImpactModelConfiguration &modelConfiguration, std::mt19937 &randomGenerator, unsigned int dimension)
Computes patch index offsets around a center point based on model configuration.


Generated on 1774142652 for elastix by doxygen 1.15.0 elastix logo