Open3D (C++ API)  0.18.0
Loading...
Searching...
No Matches
RaggedTensor.h
Go to the documentation of this file.
1// ----------------------------------------------------------------------------
2// - Open3D: www.open3d.org -
3// ----------------------------------------------------------------------------
4// Copyright (c) 2018-2023 www.open3d.org
5// SPDX-License-Identifier: MIT
6// ----------------------------------------------------------------------------
7
#pragma once

#include <torch/custom_class.h>
#include <torch/script.h>

#include <limits>
#include <string>
#include <utility>
#include <vector>
12
14
19struct RaggedTensor : torch::CustomClassHolder {
20public:
22
24 RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
25 : _values(values), _row_splits(row_splits) {}
26
34 c10::intrusive_ptr<RaggedTensor> FromRowSplits(torch::Tensor values,
35 torch::Tensor row_splits,
36 bool validate = true) const;
37
39 torch::Tensor GetValues() const;
40
42 torch::Tensor GetRowSplits() const;
43
45 std::string ToString() const;
46
52 torch::Tensor GetItem(int key) const;
53
57 int64_t Len() const;
58
60 c10::intrusive_ptr<RaggedTensor> Clone() const;
61
62 c10::intrusive_ptr<RaggedTensor> Concat(
63 c10::intrusive_ptr<RaggedTensor> r_tensor, int64_t axis) const;
64
65 template <typename T>
66 c10::intrusive_ptr<RaggedTensor> Add(T value) const {
67 return FromRowSplits(_values + value, _row_splits, false);
68 }
69
70 template <typename T>
71 c10::intrusive_ptr<RaggedTensor> Add_(T value) {
72 _values += value;
73 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
74 }
75
76 template <typename T>
77 c10::intrusive_ptr<RaggedTensor> Sub(T value) const {
78 return FromRowSplits(_values - value, _row_splits, false);
79 }
80
81 template <typename T>
82 c10::intrusive_ptr<RaggedTensor> Sub_(T value) {
83 _values -= value;
84 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
85 }
86
87 template <typename T>
88 c10::intrusive_ptr<RaggedTensor> Mul(T value) const {
89 return FromRowSplits(_values * value, _row_splits, false);
90 }
91
92 template <typename T>
93 c10::intrusive_ptr<RaggedTensor> Mul_(T value) {
94 _values *= value;
95 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
96 }
97
98 template <typename T>
99 c10::intrusive_ptr<RaggedTensor> Div(T value) const {
100 return FromRowSplits(_values / value, _row_splits, false);
101 }
102
103 template <typename T>
104 c10::intrusive_ptr<RaggedTensor> Div_(T value) {
105 _values /= value;
106 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
107 }
108
109 template <typename T>
110 c10::intrusive_ptr<RaggedTensor> FloorDiv(T value) const {
111 return FromRowSplits(_values.floor_divide(value), _row_splits, false);
112 }
113
114 template <typename T>
115 c10::intrusive_ptr<RaggedTensor> FloorDiv_(T value) {
116 _values.floor_divide_(value);
117 return c10::make_intrusive<RaggedTensor>(_values, _row_splits);
118 }
119
120private:
121 torch::Tensor _values, _row_splits;
122};
123
124static auto registry =
125 torch::class_<RaggedTensor>("my_classes", "RaggedTensor")
126 .def(torch::init<>())
127 .def("from_row_splits", &RaggedTensor::FromRowSplits)
128 .def("get_values", &RaggedTensor::GetValues)
129 .def("get_row_splits", &RaggedTensor::GetRowSplits)
130 .def("__repr__",
131 [](const c10::intrusive_ptr<RaggedTensor>& self) {
132 return self->ToString();
133 })
134 .def("__str__",
135 [](const c10::intrusive_ptr<RaggedTensor>& self) {
136 return self->ToString();
137 })
138 .def("__getitem__",
139 [](const c10::intrusive_ptr<RaggedTensor>& self,
140 int64_t key) { return self->GetItem(key); })
141 .def("__len__", &RaggedTensor::Len)
142 .def("clone", &RaggedTensor::Clone)
143 .def("concat", &RaggedTensor::Concat)
144
145 .def("add",
146 [](const c10::intrusive_ptr<RaggedTensor>& self,
147 torch::Tensor value) { return self->Add(value); })
148 .def("add_",
149 [](const c10::intrusive_ptr<RaggedTensor>& self,
150 torch::Tensor value) { return self->Add_(value); })
151 .def("__add__",
152 [](const c10::intrusive_ptr<RaggedTensor>& self,
153 torch::Tensor value) { return self->Add(value); })
154 .def("__iadd__",
155 [](const c10::intrusive_ptr<RaggedTensor>& self,
156 torch::Tensor value) { return self->Add_(value); })
157
158 .def("sub",
159 [](const c10::intrusive_ptr<RaggedTensor>& self,
160 torch::Tensor value) { return self->Sub(value); })
161 .def("sub_",
162 [](const c10::intrusive_ptr<RaggedTensor>& self,
163 torch::Tensor value) { return self->Sub_(value); })
164 .def("__sub__",
165 [](const c10::intrusive_ptr<RaggedTensor>& self,
166 torch::Tensor value) { return self->Sub(value); })
167 .def("__isub__",
168 [](const c10::intrusive_ptr<RaggedTensor>& self,
169 torch::Tensor value) { return self->Sub_(value); })
170
171 .def("mul",
172 [](const c10::intrusive_ptr<RaggedTensor>& self,
173 torch::Tensor value) { return self->Mul(value); })
174 .def("mul_",
175 [](const c10::intrusive_ptr<RaggedTensor>& self,
176 torch::Tensor value) { return self->Mul_(value); })
177 .def("__mul__",
178 [](const c10::intrusive_ptr<RaggedTensor>& self,
179 torch::Tensor value) { return self->Mul(value); })
180 .def("__imul__",
181 [](const c10::intrusive_ptr<RaggedTensor>& self,
182 torch::Tensor value) { return self->Mul_(value); })
183
184 .def("div",
185 [](const c10::intrusive_ptr<RaggedTensor>& self,
186 torch::Tensor value) { return self->Div(value); })
187 .def("div_",
188 [](const c10::intrusive_ptr<RaggedTensor>& self,
189 torch::Tensor value) { return self->Div_(value); })
190 .def("__truediv__",
191 [](const c10::intrusive_ptr<RaggedTensor>& self,
192 torch::Tensor value) { return self->Div(value); })
193 .def("__itruediv__",
194 [](const c10::intrusive_ptr<RaggedTensor>& self,
195 torch::Tensor value) { return self->Div_(value); })
196 .def("__floordiv__",
197 [](const c10::intrusive_ptr<RaggedTensor>& self,
198 torch::Tensor value) { return self->FloorDiv(value); })
199 .def("__ifloordiv__",
200 [](const c10::intrusive_ptr<RaggedTensor>& self,
201 torch::Tensor value) {
202 return self->FloorDiv_(value);
203 });
Definition RaggedTensor.h:19
c10::intrusive_ptr< RaggedTensor > Div_(T value)
Definition RaggedTensor.h:104
RaggedTensor(torch::Tensor values, torch::Tensor row_splits)
Constructor for creating RaggedTensor with values and row_splits.
Definition RaggedTensor.h:24
RaggedTensor()
Definition RaggedTensor.h:21
c10::intrusive_ptr< RaggedTensor > FloorDiv_(T value)
Definition RaggedTensor.h:115
c10::intrusive_ptr< RaggedTensor > FloorDiv(T value) const
Definition RaggedTensor.h:110
torch::Tensor GetValues() const
Returns _values tensor.
Definition RaggedTensor.cpp:39
c10::intrusive_ptr< RaggedTensor > Sub_(T value)
Definition RaggedTensor.h:82
c10::intrusive_ptr< RaggedTensor > Add(T value) const
Definition RaggedTensor.h:66
c10::intrusive_ptr< RaggedTensor > Add_(T value)
Definition RaggedTensor.h:71
c10::intrusive_ptr< RaggedTensor > Mul_(T value)
Definition RaggedTensor.h:93
c10::intrusive_ptr< RaggedTensor > FromRowSplits(torch::Tensor values, torch::Tensor row_splits, bool validate=true) const
Definition RaggedTensor.cpp:12
int64_t Len() const
Definition RaggedTensor.cpp:54
c10::intrusive_ptr< RaggedTensor > Clone() const
Copy Tensor to the same device.
Definition RaggedTensor.cpp:56
torch::Tensor GetRowSplits() const
Returns _row_splits tensor.
Definition RaggedTensor.cpp:40
c10::intrusive_ptr< RaggedTensor > Sub(T value) const
Definition RaggedTensor.h:77
c10::intrusive_ptr< RaggedTensor > Div(T value) const
Definition RaggedTensor.h:99
std::string ToString() const
Returns string representation.
Definition RaggedTensor.cpp:42
c10::intrusive_ptr< RaggedTensor > Mul(T value) const
Definition RaggedTensor.h:88
c10::intrusive_ptr< RaggedTensor > Concat(c10::intrusive_ptr< RaggedTensor > r_tensor, int64_t axis) const
Definition RaggedTensor.cpp:60
torch::Tensor GetItem(int key) const
Definition RaggedTensor.cpp:49