This repository was archived by the owner on Feb 4, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathPytorch3DObjectGenerator.py
More file actions
113 lines (78 loc) · 3.11 KB
/
Pytorch3DObjectGenerator.py
File metadata and controls
113 lines (78 loc) · 3.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import numpy as np
import torch
from pytorch3d.renderer import (
    OpenGLPerspectiveCameras, look_at_view_transform, look_at_rotation,
    RasterizationSettings, MeshRenderer, MeshRasterizer, BlendParams,
    SoftSilhouetteShader, HardPhongShader, PointLights
)
# File: Pytorch3DObjectGenerator
# Classes:
# ConvertMesh
# Functions:
# createRenderer
# saveImage
# openImage
class ConvertMesh:
    """Render a mesh with a renderer from a movable camera.

    The camera position is a 3-element sequence ``(distance, elevation,
    azimuth)``. It is passed positionally to pytorch3d's
    ``look_at_view_transform`` on every render.
    """

    def __init__(self, mesh, renderer, initial_camera_position=None):
        """Store the mesh, the renderer and the starting camera position.

        Inputs:
            mesh: object rendered by ``renderImage``; must provide ``clone()``.
            renderer: callable invoked as ``renderer(meshes_world=..., R=..., T=...)``.
            initial_camera_position: ``(distance, elevation, azimuth)``.
                Defaults to ``np.array([3.0, 50.0, 0.0])``.
        """
        super().__init__()
        self.mesh = mesh
        self.renderer = renderer
        if initial_camera_position is None:
            # Build a fresh array per instance so that in-place edits to one
            # object's camera position cannot leak into other instances.
            initial_camera_position = np.array([3.0, 50.0, 0.0])
        # Units are distance, elevation, azimuth (angle)
        self.camera_position = initial_camera_position

    def set_camera_position(self, camera_position):
        """Replace the camera position with ``(distance, elevation, azimuth)``."""
        self.camera_position = camera_position

    def get_camera_position(self):
        """Return the current camera position object (not a copy)."""
        return self.camera_position

    def renderImage(self):
        """Render the mesh from the current camera position.

        Output: whatever ``self.renderer`` returns for the given R/T.
        """
        distance, elevation, azimuth = (
            self.camera_position[0],
            self.camera_position[1],
            self.camera_position[2],
        )
        R, T = look_at_view_transform(distance, elevation, azimuth)
        # A clone of the mesh is handed to the renderer rather than the stored
        # mesh itself.
        image = self.renderer(meshes_world=self.mesh.clone(), R=R, T=T)
        return image
def createRenderer(image_size, faces_per_pixel, lights_location):
    """Build a Phong-shaded mesh renderer with a single point light.

    Inputs:
        image_size: forwarded to ``RasterizationSettings(image_size=...)``.
        faces_per_pixel: forwarded to ``RasterizationSettings(faces_per_pixel=...)``.
        lights_location: position of the point light; wrapped in a
            one-element tuple for ``PointLights(location=...)``.
    Output: a ``MeshRenderer`` pairing a ``MeshRasterizer`` with a
        ``HardPhongShader``, both sharing one ``OpenGLPerspectiveCameras``.
    """
    # One camera object is shared by rasterizer and shader.
    # ConvertMesh.renderImage supplies R and T on each call.
    camera = OpenGLPerspectiveCameras()
    settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,  # no blur: hard rasterization edges
        faces_per_pixel=faces_per_pixel,
    )
    point_light = PointLights(location=(lights_location,))
    rasterizer = MeshRasterizer(cameras=camera, raster_settings=settings)
    shader = HardPhongShader(cameras=camera, lights=point_light)
    return MeshRenderer(rasterizer=rasterizer, shader=shader)
def saveImage(image, path):
    """Save the first three channels of a rendered image tensor to ``path``.

    Inputs:
        image: 4-D tensor whose last axis holds the color channels; only
            channels 0-2 are kept. Presumably (1, H, W, C) renderer output
            with values in [0, 1] -- TODO confirm against callers.
        path: destination passed to PIL's ``Image.save``.
    Process: detaches the tensor, moves it to CPU, clips it to [0, 1],
        scales it to 0-255 as uint8 and writes it with PIL.
    """
    # detach()/cpu(): Tensor.numpy() raises for tensors that require grad
    # or live on a GPU, both common for renderer output.
    rgb = image[:, :, :, 0:3].detach().cpu().numpy().squeeze()
    # Clip before the uint8 cast. Shading can push values slightly above 1.0,
    # and casting out-of-range floats to uint8 overflows (platform-dependent)
    # instead of saturating at 0/255.
    pixels = np.uint8(np.clip(rgb, 0.0, 1.0) * 255)
    # NOTE(review): `Image` (PIL.Image) is not imported by this file's visible
    # imports -- confirm it is in scope.
    Image.fromarray(pixels).save(path)
def openImage(path):
    """Load an image file as a float tensor with a leading batch axis.

    Inputs:
        path: image file readable by PIL.
    Output: ``torch.Tensor`` of shape (1, C, W, H) holding the raw pixel
        values (0-255 for 8-bit modes; not normalized). Single-band images
        are returned with C == 1.
    """
    # The context manager guarantees the file handle is closed. load() alone
    # leaves it open for some formats (e.g. multi-frame images).
    # NOTE(review): `Image` (PIL.Image) is not imported by this file's visible
    # imports -- confirm it is in scope.
    with Image.open(path) as image:
        image.load()
        pixels = np.array(image, dtype='f8')
    if pixels.ndim == 2:
        # Single-band modes (e.g. 'L') load as (H, W). Add a channel axis so
        # transpose(0, 2) below has a third dimension to swap with.
        pixels = pixels[:, :, np.newaxis]
    # NOTE(review): transpose(0, 2) maps (H, W, C) to (C, W, H), so the spatial
    # axes end up swapped relative to the usual (C, H, W) layout. Kept as-is
    # because callers may depend on it -- confirm whether permute(2, 0, 1)
    # was intended.
    imageData = torch.Tensor(pixels).transpose(0, 2).unsqueeze(0)
    return imageData