I have a program that calls `np.random` many times. Now I want the user to be able to pass an argument `gpu=True/False`.
How can I override `np.random` so that it returns `cm.CUDAMatrix(np.random.uniform(low=low, high=high, size=size))`
without ending up in infinite recursion? Or is there a better way to use cudamat with only small code changes?
Thanks for your help.
If you need more code, please comment.
class FeedForwardNetwork:
    """Feed-forward network with a single hidden layer.

    Only construction is shown here: the constructor records the layer
    sizes and dropout settings and draws the two weight matrices.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=False,
                 dropout_prop=0.5, gpu=True, *, init_range=0.01, seed=1):
        """Store the configuration and draw the initial weight matrices.

        Args:
            input_dim: number of input units.
            hidden_dim: number of hidden units.
            output_dim: number of output units.
            dropout: dropout flag, stored as ``self.dropout``.
            dropout_prop: dropout proportion, stored as ``self.dropout_prop``.
            gpu: stored as ``self.gpu``. This constructor always produces
                NumPy arrays regardless of its value.
            init_range: weights are drawn uniformly from
                ``[-init_range, init_range)``. The default ``0.01`` reproduces
                the original behaviour. Pass ``None`` to use the
                Glorot/Xavier uniform bound ``sqrt(6 / (fan_in + fan_out))``
                for each layer instead.
            seed: if not ``None``, passed to ``np.random.seed`` before the
                weights are drawn. Pass ``None`` to leave the global RNG alone.
        """
        if seed is not None:
            # This reseeds NumPy's *global* generator, so any later np.random
            # call elsewhere in the program is affected as well.
            np.random.seed(seed)
        self.input_layer = np.array([])
        self.hidden_layer = np.array([])
        self.output_layer = np.array([])
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.dropout_prop = dropout_prop
        # Previously accepted but discarded; keep it so other code can branch on it.
        self.gpu = gpu
        # Draw order (input->hidden, then hidden->output) matches the original,
        # so a fixed seed yields the same matrices as before.
        self.weights_input_hidden = self._init_weights(input_dim, hidden_dim, init_range)
        self.weights_hidden_output = self._init_weights(hidden_dim, output_dim, init_range)

    @staticmethod
    def _init_weights(fan_in, fan_out, init_range):
        """Draw a (fan_in, fan_out) matrix uniformly from [-r, r).

        ``r`` is ``init_range``, or the Glorot bound when ``init_range`` is None.
        """
        if init_range is None:
            # 6.0 keeps this a float division even under Python 2.
            init_range = math.sqrt(6.0 / (fan_in + fan_out))
        return np.random.uniform(low=-init_range, high=init_range,
                                 size=(fan_in, fan_out))
class FeedForwardNetwork:
    """Feed-forward network with a single hidden layer.

    Adds ``np_random`` for drawing fresh weight matrices, optionally passing
    them through a wrapper such as ``cm.CUDAMatrix``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout=False,
                 dropout_prop=0.5, gpu=True, *, init_range=0.01, seed=1):
        """Store the configuration and draw the initial weight matrices.

        Args:
            input_dim: number of input units.
            hidden_dim: number of hidden units.
            output_dim: number of output units.
            dropout: dropout flag, stored as ``self.dropout``.
            dropout_prop: dropout proportion, stored as ``self.dropout_prop``.
            gpu: stored as ``self.gpu``. This constructor always produces
                NumPy arrays regardless of its value.
            init_range: weights are drawn uniformly from
                ``[-init_range, init_range)``. The default ``0.01`` reproduces
                the original behaviour. Pass ``None`` to use the
                Glorot/Xavier uniform bound ``sqrt(6 / (fan_in + fan_out))``
                for each layer instead.
            seed: if not ``None``, passed to ``np.random.seed`` before the
                weights are drawn. Pass ``None`` to leave the global RNG alone.
        """
        if seed is not None:
            # This reseeds NumPy's *global* generator, so any later np.random
            # call elsewhere in the program is affected as well.
            np.random.seed(seed)
        self.input_layer = np.array([])
        self.hidden_layer = np.array([])
        self.output_layer = np.array([])
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.dropout = dropout
        self.dropout_prop = dropout_prop
        # Previously accepted but discarded; keep it so other code can branch on it.
        self.gpu = gpu
        # Draw order (input->hidden, then hidden->output) matches the original,
        # so a fixed seed yields the same matrices as before.
        self.weights_input_hidden = self._init_weights(input_dim, hidden_dim, init_range)
        self.weights_hidden_output = self._init_weights(hidden_dim, output_dim, init_range)

    @staticmethod
    def _init_weights(fan_in, fan_out, init_range):
        """Draw a (fan_in, fan_out) matrix uniformly from [-r, r).

        ``r`` is ``init_range``, or the Glorot bound when ``init_range`` is None.
        """
        if init_range is None:
            # 6.0 keeps this a float division even under Python 2.
            init_range = math.sqrt(6.0 / (fan_in + fan_out))
        return np.random.uniform(low=-init_range, high=init_range,
                                 size=(fan_in, fan_out))

    def np_random(self, gpu, *, low=-0.01, high=0.01, wrap=None):
        """Draw a fresh uniformly-initialised weight matrix.

        Despite its name, ``gpu`` does not select a device. A truthy value
        draws an ``(input_dim, hidden_dim)`` matrix; a falsy value draws a
        ``(hidden_dim, output_dim)`` matrix. The name is kept for backward
        compatibility.

        Args:
            gpu: shape selector (see above).
            low: lower bound of the uniform distribution (inclusive).
            high: upper bound of the uniform distribution (exclusive).
            wrap: optional callable applied to the drawn NumPy array before
                it is returned, e.g. ``cm.CUDAMatrix`` to get a GPU matrix.
                The draw itself always goes through ``np.random``, so
                wrapping cannot recurse into this method.

        Returns:
            The NumPy array, or ``wrap(array)`` when ``wrap`` is given.
        """
        if gpu:
            shape = (self.input_dim, self.hidden_dim)
        else:
            shape = (self.hidden_dim, self.output_dim)
        weights = np.random.uniform(low=low, high=high, size=shape)
        return weights if wrap is None else wrap(weights)
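Then you can call it on your instance (note that, as written, the boolean argument only selects which weight shape is drawn — input→hidden for `True`, hidden→output for `False` — and does not by itself move anything to the GPU):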
# Example sizes; substitute your own constructor arguments.
instance = FeedForwardNetwork(input_dim=3, hidden_dim=4, output_dim=2)
# Call once per flag value. Writing `True/False` literally would evaluate
# 1 / 0 and raise ZeroDivisionError.
weights_input_hidden = instance.np_random(True)    # shape (input_dim, hidden_dim)
weights_hidden_output = instance.np_random(False)  # shape (hidden_dim, output_dim)
The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please cite the site URL or the original address. For any questions, please contact: yoyou2525@163.com.