-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathFeatureGridDevice.h
More file actions
171 lines (143 loc) · 6.88 KB
/
FeatureGridDevice.h
File metadata and controls
171 lines (143 loc) · 6.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: LicenseRef-NvidiaProprietary
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "FeatureGridMath.h"
#include "tin/tin_matrix.h"
#include <cuda_fp16.hpp>
namespace ntc::cuda
{
// Device-side helper that bilinearly samples a 2D grid of half-precision latent
// features (forward pass) and scatters gradients back into the same layout
// (backward pass).
//
// Addressing, as implemented below: the element offset of the first feature pair
// of texel (x, y) is (y * width + x) * FeatureGridMath::FeaturesPerGroup, and each
// subsequent feature pair is located `latentStride` elements further.
class FeatureGrid
{
public:
    // numFeatures  - number of features actually stored per texel; the sampling loops
    //                stop at this count even if NTC_MLP_FEATURES is larger.
    // width/height - grid dimensions in texels, used for coordinate scaling and wrapping.
    //                Both must be positive (they are used as modulo divisors).
    // latentStride - element distance between consecutive feature pairs of the same texel.
    __device__ FeatureGrid(int numFeatures, int width, int height, size_t latentStride)
        : m_latentWidth(width)
        , m_latentHeight(height)
        , m_numFeatures(numFeatures)
        , m_latentStride(latentStride)
    {
    }

    // Computes the four texel coordinates and bilinear weights for the normalized
    // coordinates (u, v), using wrap (repeat) addressing.
    //
    // Outputs:
    //   x0, y0, x1, y1 - wrapped integer texel coordinates, each in [0, size).
    //   weights        - { (x0,y0), (x1,y0), (x0,y1), (x1,y1) } weights, in that order.
    __device__ void SetupBilinearFilter(float u, float v, int& x0, int& y0, int& x1, int& y1, half weights[4])
    {
        // Texel centers sit at half-integer positions, hence the -0.5 shift.
        float x = u * m_latentWidth - 0.5f;
        float y = v * m_latentHeight - 0.5f;

        int x_b = int(floorf(x));
        int y_b = int(floorf(y));

        // Fractional position inside the texel quad; must be taken before wrapping.
        float dx = x - x_b;
        float dy = y - y_b;
        float dxn = 1.f - dx;
        float dyn = 1.f - dy;

        weights[0] = half(dxn * dyn);
        weights[1] = half(dx * dyn);
        weights[2] = half(dxn * dy);
        weights[3] = half(dx * dy);

        // Wrap addressing. A full modular wrap is used so that coordinates more than
        // one period outside [0, 1] still yield non-negative, in-range indices.
        x0 = WrapCoordinate(x_b, m_latentWidth);
        y0 = WrapCoordinate(y_b, m_latentHeight);
        x1 = (x0 + 1) % m_latentWidth;
        y1 = (y0 + 1) % m_latentHeight;
    }

    // Bilinearly samples up to NTC_MLP_FEATURES features (bounded by numFeatures) at
    // (u, v), remaps every value from [0,1] to [-1,1] and stores the results as packed
    // half2 elements into outputArray starting at element index `arrayOffset`.
    //
    // NOTE(review): the half2 loads assume `features + offset` is suitably aligned for
    // half2 access - confirm against the allocation and FeaturesPerGroup/latentStride.
    __device__ void Sample(float u, float v, const half* __restrict__ features,
        tin::HArray<NTC_MLP_INPUT_CHANNELS>& outputArray, int arrayOffset)
    {
        int x0, y0, x1, y1;
        half weights[4];
        SetupBilinearFilter(u, v, x0, y0, x1, y1, weights);

        size_t a00 = TexelOffset(x0, y0);
        size_t a01 = TexelOffset(x1, y0);
        size_t a10 = TexelOffset(x0, y1);
        size_t a11 = TexelOffset(x1, y1);

        #pragma unroll
        for (int featureIndex = 0; featureIndex < NTC_MLP_FEATURES; featureIndex += 2)
        {
            if (featureIndex >= m_numFeatures)
                break;

            // Two features are fetched at once; the source stays const-qualified.
            half2 x00 = *reinterpret_cast<const half2*>(features + a00);
            half2 x01 = *reinterpret_cast<const half2*>(features + a01);
            half2 x10 = *reinterpret_cast<const half2*>(features + a10);
            half2 x11 = *reinterpret_cast<const half2*>(features + a11);

            half2 d;
            d.x = x00.x * weights[0] + x01.x * weights[1] + x10.x * weights[2] + x11.x * weights[3];
            d.y = x00.y * weights[0] + x01.y * weights[1] + x10.y * weights[2] + x11.y * weights[3];

            // Convert from [0,1] to [-1,1], that works better as a network input
            d.x = d.x * half(2.0f) - half(1.0f);
            d.y = d.y * half(2.0f) - half(1.0f);

            outputArray.set_packed_element(d, (arrayOffset + featureIndex) / 2);

            // Advance to the next feature pair of the same four texels.
            a00 += m_latentStride;
            a01 += m_latentStride;
            a10 += m_latentStride;
            a11 += m_latentStride;
        }
    }

    // Atomically sets the bit for texel (x, y) in `gradientMask`. The bit index is
    // y * width + x + maskOffsetInBits; bits are packed 32 per uint32_t word.
    __device__ void MarkGradientMask(int x, int y, uint32_t* gradientMask, size_t maskOffsetInBits)
    {
        // Widen before multiplying, matching the feature address math, so that the
        // product cannot overflow `int` on large grids.
        size_t bitIndex = size_t(y) * size_t(m_latentWidth) + size_t(x) + maskOffsetInBits;
        size_t wordIndex = bitIndex >> 5;
        uint32_t wordMask = 1u << (bitIndex & 31);
        atomicOr(gradientMask + wordIndex, wordMask);
    }

    // Backward pass of Sample(): distributes the gradients stored in `outputGradients`
    // (starting at element index `arrayOffset`) onto the four texels that contributed
    // to the sample at (u, v), weighted by the same bilinear weights, and marks those
    // texels in `gradientMask`.
    //
    // GRID_GRAD_TYPE selects the accumulation path: `float` uses one scalar atomic add
    // per feature, any other type is treated as half and uses packed half2 atomic adds.
    template<typename GRID_GRAD_TYPE>
    __device__ void SampleBackward(float u, float v,
        const tin::HArray<NTC_MLP_INPUT_CHANNELS>& outputGradients, int arrayOffset,
        GRID_GRAD_TYPE* __restrict__ gradients, uint32_t* gradientMask, size_t maskOffsetInBits)
    {
        int x0, y0, x1, y1;
        half weights[4];
        SetupBilinearFilter(u, v, x0, y0, x1, y1, weights);

        MarkGradientMask(x0, y0, gradientMask, maskOffsetInBits);
        MarkGradientMask(x1, y0, gradientMask, maskOffsetInBits);
        MarkGradientMask(x0, y1, gradientMask, maskOffsetInBits);
        MarkGradientMask(x1, y1, gradientMask, maskOffsetInBits);

        size_t a00 = TexelOffset(x0, y0);
        size_t a01 = TexelOffset(x1, y0);
        size_t a10 = TexelOffset(x0, y1);
        size_t a11 = TexelOffset(x1, y1);

        #pragma unroll
        for (int featureIndex = 0; featureIndex < NTC_MLP_FEATURES; featureIndex += 2)
        {
            if (featureIndex >= m_numFeatures)
                break;

            half2 outputGrad = outputGradients.get_packed_element((arrayOffset + featureIndex) / 2);

            // Derivative of the [0,1] -> [-1,1] remapping applied in Sample().
            outputGrad.x *= half(2.0f);
            outputGrad.y *= half(2.0f);

            // The gradient type is known at compile time, so only the matching
            // branch is instantiated.
            if constexpr (std::is_same<GRID_GRAD_TYPE, float>::value)
            {
                tin::_atomic_addf(&gradients[a00 + 0], float(outputGrad.x * weights[0]));
                tin::_atomic_addf(&gradients[a00 + 1], float(outputGrad.y * weights[0]));
                tin::_atomic_addf(&gradients[a01 + 0], float(outputGrad.x * weights[1]));
                tin::_atomic_addf(&gradients[a01 + 1], float(outputGrad.y * weights[1]));
                tin::_atomic_addf(&gradients[a10 + 0], float(outputGrad.x * weights[2]));
                tin::_atomic_addf(&gradients[a10 + 1], float(outputGrad.y * weights[2]));
                tin::_atomic_addf(&gradients[a11 + 0], float(outputGrad.x * weights[3]));
                tin::_atomic_addf(&gradients[a11 + 1], float(outputGrad.y * weights[3]));
            }
            else
            {
                tin::_atomic_addh2((half2*)&gradients[a00], half2{outputGrad.x * weights[0], outputGrad.y * weights[0]});
                tin::_atomic_addh2((half2*)&gradients[a01], half2{outputGrad.x * weights[1], outputGrad.y * weights[1]});
                tin::_atomic_addh2((half2*)&gradients[a10], half2{outputGrad.x * weights[2], outputGrad.y * weights[2]});
                tin::_atomic_addh2((half2*)&gradients[a11], half2{outputGrad.x * weights[3], outputGrad.y * weights[3]});
            }

            // Advance to the next feature pair of the same four texels.
            a00 += m_latentStride;
            a01 += m_latentStride;
            a10 += m_latentStride;
            a11 += m_latentStride;
        }
    }

private:
    // Maps an arbitrary integer coordinate into [0, size) with repeat semantics.
    // `size` must be positive.
    __device__ static int WrapCoordinate(int coord, int size)
    {
        int wrapped = coord % size;
        return (wrapped < 0) ? wrapped + size : wrapped;
    }

    // Element offset of the first feature pair of texel (x, y); x and y must already
    // be wrapped to non-negative values.
    __device__ size_t TexelOffset(int x, int y) const
    {
        return (size_t(y) * size_t(m_latentWidth) + size_t(x)) * FeatureGridMath::FeaturesPerGroup;
    }

    int m_latentWidth;
    int m_latentHeight;
    int m_numFeatures;
    size_t m_latentStride;
};
} // namespace ntc::cuda