-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtensor.zig
More file actions
103 lines (85 loc) · 3.37 KB
/
tensor.zig
File metadata and controls
103 lines (85 loc) · 3.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
const std = @import("std");
const Error = error{dimension_mismatch};
/// Returns `val` multiplied by two. `T` may be any numeric type for
/// which `*` is defined (ints trap on overflow in safe build modes).
pub fn double(comptime T: type, val: T) T {
    const factor: T = 2;
    return val * factor;
}
/// Generic n-dimensional tensor view over a flat slice of `T`.
/// The product of `shape` always equals `data.len` (enforced by `init`).
pub fn Tensor(comptime T: type) type {
    return struct {
        /// Flat element storage: pointer + length = 16 bytes on 64-bit.
        data: []T,
        /// Dimension sizes: pointer + length = 16 bytes on 64-bit.
        shape: []const usize,
        /// Cached total element count (always == data.len). 8 bytes.
        len: usize,

        /// Wraps `data` in a tensor with the given `shape`.
        /// Returns `error.dimension_mismatch` when the product of the
        /// shape's dimensions differs from `data.len`.
        /// NOTE(review): the input slice is stored via `@constCast`, so
        /// callers must not pass truly read-only memory (e.g. a `const`
        /// array) if they later call a mutating method such as `apply`.
        pub fn init(data: []const T, shape: []const usize) error{dimension_mismatch}!Tensor(T) {
            var prod: usize = 1;
            for (shape) |s| {
                prod *= s;
            }
            if (prod != data.len) {
                return error.dimension_mismatch;
            }
            return Tensor(T){
                .data = @constCast(data),
                .len = data.len,
                .shape = shape,
            };
        }

        /// Number of elements spanned per step along dimension `idx`,
        /// computed as total length divided by that dimension's size.
        pub fn stride(self: Tensor(T), idx: usize) usize {
            return self.data.len / self.shape[idx];
        }

        /// Frees both slices. Only valid when `data` and `shape` were
        /// allocated with `allocator`.
        pub fn deinit(self: Tensor(T), allocator: std.mem.Allocator) void {
            allocator.free(self.data);
            allocator.free(self.shape);
        }

        /// Returns a new view of the same data with a different shape.
        /// Caller owns all memory and must take care of freeing it.
        /// Returns `error.dimension_mismatch` if `shape` does not cover
        /// exactly `data.len` elements.
        pub fn reshape(self: Tensor(T), shape: []const usize) error{dimension_mismatch}!Tensor(T) {
            return try Tensor(T).init(self.data, shape);
        }

        /// Applies `F(T, elem)` to every element IN PLACE and returns
        /// `self` for chaining.
        // TODO: decide if this should overwrite the data or return a new one
        pub fn apply(self: Tensor(T), F: fn (comptime type, T) T) Tensor(T) {
            for (0..self.len) |i| {
                self.data[i] = F(T, self.data[i]);
            }
            return self;
        }

        /// Sum of all elements. Uses 8-wide SIMD for the bulk of the
        /// data, then a scalar loop for the tail.
        /// NOTE(review): `@splat(0.0)` restricts this to float `T` —
        /// TODO confirm whether integer tensors should be supported.
        // TODO: add some tests for data.len < 8
        pub fn sum(self: Tensor(T)) T {
            // https://developer.apple.com/documentation/accelerate/simd/double-precision_floating-point_vectors
            // it seems like on an M3 max, we can get 8 doubles in a SIMD register
            // so we can sum 8 doubles at a time
            const num_vecs: usize = self.len / 8;
            // sum the SIMD vectors together
            var simd_sum: @Vector(8, T) = @splat(0.0);
            for (0..num_vecs) |i| {
                const vec: @Vector(8, T) = self.data[i * 8 ..][0..8].*;
                simd_sum += vec;
            }
            // horizontal-sum the accumulated SIMD vector
            var ret: T = @reduce(.Add, simd_sum);
            // add the tail elements that did not fill a full 8-wide vector
            for (self.data[num_vecs * 8 ..]) |item| {
                ret += item;
            }
            return ret;
        }

        /// Deep equality: same length, same shape, bitwise-equal data.
        pub fn equal(self: Tensor(T), other: Tensor(T)) bool {
            if (self.len != other.len) {
                return false;
            }
            // BUG FIX: shape elements are `usize`, not `T`; the original
            // `std.mem.eql(T, self.shape, ...)` failed to compile for any
            // tensor where T != usize. (`eql` also checks lengths, so the
            // separate shape.len comparison was redundant.)
            if (!std.mem.eql(usize, self.shape, other.shape)) {
                return false;
            }
            return std.mem.eql(T, self.data, other.data);
        }
    };
}
test "apply" {
    const shape: [2]usize = [2]usize{ 10, 1 };
    // BUG FIX: must be `var`, not `const` — `apply` mutates the backing
    // buffer in place through the `@constCast` taken in `init`, and
    // writing to const (read-only) memory is undefined behavior.
    var buffer: [10]f32 = [_]f32{1} ** 10;
    // BUG FIX: slice the full 2-element shape; the original `shape[0..1]`
    // used only {10}, which passed the product check by coincidence.
    const a = try Tensor(f32).init(buffer[0..10], shape[0..2]);
    const expected: [10]f32 = [_]f32{2} ** 10;
    const b = a.apply(double);
    try std.testing.expectEqualDeep(expected[0..10], b.data);
}